问题描述
我尝试预训练一个 seq-to-seq 编码器解码器模型,该模型非常接近此处列出的结构 (https://www.tensorflow.org/tutorials/text/nmt_with_attention),但使用不同的数据集。我计划将此模型作为 GAN 模型中的生成器,这意味着我需要以可微分的方式从模型中采样文本(即我不能使用任何 argmax 函数)。我注意到 Gumbel Sampling 是一个潜在的解决方案,并且 Tensorflow 在 tfp.distributions.RelaxedOneHotCategorical 类中实现了 Gumbel Sampling。但是,当我在模型中实现该类时,训练似乎不再收敛。解码器类的写法如下:
class Decoder(tf.keras.Model):
def __init__(self,vocab_size,embedding_dim,dec_units,batch_sz,temp):
super(Decoder,self).__init__()
self.batch_sz = batch_sz
self.dec_units = dec_units
self.embedding = tf.keras.layers.Embedding(vocab_size,embedding_dim)
self.gru = tf.keras.layers.GRU(self.dec_units,return_sequences=True,return_state=True,recurrent_initializer='glorot_uniform')
self.fc = tf.keras.layers.Dense(vocab_size)
self.gumbel = tfp.layers.distributionLambda(lambda logits: tfp.distributions.RelaxedOneHotCategorical(logits=logits,temperature=temp),convert_to_tensor_fn=lambda s: s.sample())
self.temp = temp
self.attention = BahdanauAttention(self.dec_units)
def __call__(self,x,hidden,enc_output):
context_vector,attention_weights = self.attention(hidden,enc_output)
x = self.embedding(x)
x = tf.concat([tf.expand_dims(context_vector,1),x],axis=-1)
output,state = self.gru(x)
output = tf.reshape(output,(-1,output.shape[2]))
x = self.fc(output)
x = self.gumbel(x)
x_hard = tf.stop_gradient(tf.cast(tf.equal(x,tf.reduce_max(x,1,keepdims=True)),x.dtype))
x = tf.stop_gradient(x_hard - x) + x
return x,state,attention_weights
call 方法的最后几行使用 argmax 为前向传递生成单热输出,但由于 stop_gradient(x_hard - x)
被调用,Tensorflow 应忽略后向传递中的 argmax 并计算梯度 w.r.t.到 RelaxedOneHotCategorical。但是,我仍然遇到训练收敛问题
x_hard = tf.stop_gradient(tf.cast(tf.equal(x,x.dtype))
x = tf.stop_gradient(x_hard - x) + x
省略了行。
即使我尝试像下面这样的简单小例子,我看到返回的梯度仍然是 0。
@tf.function
def test(a):
b = tf.nn.sigmoid(a)
c = tfp.distributions.RelaxedOneHotCategorical(logits=b,temperature=.1).sample()
return tf.gradients(c,[a,b])
a = tf.Variable(np.array([.1,.2,.7]))
test(a)
有人知道如何解决这个问题吗?作为参考,我使用的是 Tensorflow 2.3 和 Tensorflow Probability 0.11.1
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)