Jax、jit 和动态形状:来自 Tensorflow 的回归?

问题描述

documentation for JAX 说,

并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。

现在我有点惊讶,因为 tensorflow 有 tf.boolean_mask 之类的操作,可以完成 JAX 在编译时似乎无法做到的事情。

  1. 为什么 Tensorflow 会出现这种回归?我假设底层 XLA 表示在两个框架之间共享,但我可能错了。我不记得 Tensorflow 曾在动态形状方面遇到过麻烦,而且 tf.boolean_mask 之类的函数一直存在。
  2. 我们能否期待这种差距在未来缩小?如果没有,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)支持功能

编辑

梯度通过tf.boolean_mask(显然不是在mask值上,它们是离散的);这里使用 TF1 样式的图形,其中值未知,因此 TF 不能依赖它们:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x1 = tf.placeholder(tf.float32,(3,))
x2 = tf.placeholder(tf.float32,))
y = tf.boolean_mask(x1,x2 > 0)
print(y.shape)  # prints "(?,)"
dydx1,dydx2 = tf.gradients(y,[x1,x2])
assert dydx1 is not None and dydx2 is None

解决方法

我认为 JAX 并没有比 TensorFlow 更无能做到这一点。没有禁止您在 JAX 中执行此操作:

new_array = my_array[mask]

但是,mask 应该是索引(整数)而不是布尔值。这样,JAX 就知道 new_array 的形状(与 mask 相同)。从这个意义上说,我很确定 tf.boolean_mask 是不可微的,即如果您尝试在某个时刻计算其梯度,它会引发错误。

更一般地,如果您需要屏蔽一个数组,无论您使用的是什么库,都有两种方法:

  1. 如果您事先知道需要选择哪些索引并且您需要提供这些索引,以便库可以在编译前计算形状;
  2. 如果您无法定义这些索引,无论出于何种原因,您都需要设计代码以避免填充影响您的结果。

每种情况的示例

  1. 假设您正在 JAX 中编写一个简单的嵌入层。 input 是与几个句子对应的一批标记索引。为了获得与这些索引对应的词嵌入,我将简单地写成 word_embeddings = embeddings[input]。由于我事先不知道句子的长度,所以我需要事先将所有标记序列填充到相同的长度,这样 input 的形状为 (number_of_sentences,sentence_max_length)。现在,每次此形状更改时,JAX 都会编译屏蔽操作。为了尽量减少编译次数,您可以提供相同数量的句子(也称为批量大小),并且可以将 sentence_max_length 设置为整个语料库中的最大句子长度。这样,在训练期间将只有一个编译。当然,您需要在 word_embeddings 中保留与 pad 索引对应的一行。但是,掩蔽仍然有效。

  2. 在模型的后面,假设您想将每个句子的每个单词表达为句子中所有其他单词的加权平均值(如自注意力机制)。对整个批次并行计算权重,并存储在维度 A 的矩阵 (number_of_sentences,sentence_max_length,sentence_max_length) 中。加权平均值使用公式 A @ word_embeddings 计算。现在,您需要确保 pad 标记不会影响之前的公式。为此,您可以将对应于填充索引的 A 条目归零,以消除它们对平均的影响。如果 pad 标记索引为 0,你会这样做:

    mask = jnp.array(input > 0,dtype=jnp.float32)
    A = A * mask[:,jnp.newaxis,:]
    weighted_mean = A @ word_embeddings 

所以这里我们使用了一个布尔掩码,但掩码在某种程度上是可微的,因为我们将掩码与另一个矩阵相乘,而不是将其用作索引。请注意,我们应该继续以相同的方式删除 weighted_mean 中也对应于 pad 标记的行。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...