问题描述
我想在 tensorflow 中实现一个自定义的激活函数。这个激活函数的想法是它应该学习它的线性程度。使用以下函数。
tanh(x*w)/w for w!= 0
x for w = 0
应该学习参数w。但是我不知道如何在 tensorflow 中实现这一点。
解决方法
激活函数只是模型的一部分,所以这里是您描述的函数的代码。
import tensorflow as tf
from tensorflow.keras import Model
class MyModel(Model):
def __init__(self):
super().__init__()
# Some layers
self.W = tf.Variable(tf.constant([[0.1,0.1],[0.1,0.1]]))
def call(self,x):
# Some transformations with your layers
x = tf.where(x==0,x,tf.tanh(self.W*x)/self.W)
return x
所以,对于非零矩阵 MyModel()(tf.constant([[1.0,2.0],[3.0,4.0]]))
它返回
<tf.Tensor: shape=(2,2),dtype=float32,numpy=
array([[0.9966799,1.9737529],[2.913126,3.79949 ]],dtype=float32)>
对于零矩阵 MyModel()(tf.constant([[0.0,0.0],[0.0,0.0]]))
它返回零
<tf.Tensor: shape=(2,numpy=
array([[0.,0.],[0.,0.]],dtype=float32)>