问题描述
如果输入的标签超出范围,是否有办法避免tfp.distributions.Categorical.log_prob
引发错误?
我将一批样本传递给 log_prob
方法,其中一些样本的值为 n_categories + 1
,这是当您从全零概率分布中采样时得到的回退值。我的 probs
批次中的一些概率分布全为零**。
dec_output,h_state,c_state = self.decoder(dec_inp,[h_state,c_state])
probs = self.attention(enc_output,dec_output,pointer_mask,len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action
在这些情况下,我不在乎从 log_prob
获得什么价值,因为稍后我将掩盖它而不使用它。不确定是否可以以某种方式实现 fallback
值。如果没有,是否有任何解决方法可以避免在我以图形模式(使用 @tf.function
)执行时引发错误?
**这是因为我正在使用 RNN 进行随机解码,该 RNN 是多批可变长度序列,一个 seq 到 seq 任务。
解决方法
如果您可以屏蔽 log_prob,您也可以将 probs 屏蔽为 1 / n。 请注意,使用 Categorical 的 logits 参数化并删除(大概)上游 softmax 激活在数值上更稳定。