张量流图形模式下 tfp.distributions.Categorical.log_prob 的解决方法/后备值

问题描述

如果输入的标签超出范围,是否有办法避免tfp.distributions.Categorical.log_prob引发错误

我将一批样本传递给 log_prob 方法,其中一些样本的值为 n_categories + 1,这是当您从全零概率分布中采样时得到的回退值。我的 probs 批次中的一些概率分布全为零**。

dec_output,h_state,c_state = self.decoder(dec_inp,[h_state,c_state])
probs = self.attention(enc_output,dec_output,pointer_mask,len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action

在这些情况下,我不在乎从 log_prob 获得什么价值,因为稍后我将掩盖它而不使用它。不确定是否可以以某种方式实现 fallback 值。如果没有,是否有任何解决方法可以避免在我以图形模式(使用 @tf.function)执行时引发错误

**这是因为我正在使用 RNN 进行随机解码,该 RNN 是多批可变长度序列,一个 seq 到 seq 任务。

解决方法

如果您可以屏蔽 log_prob,您也可以将 probs 屏蔽为 1 / n。 请注意,使用 Categorical 的 logits 参数化并删除(大概)上游 softmax 激活在数值上更稳定。