在 JAX 中计算词向量移动平均值的最佳方法

问题描述

假设我有一个形状为 W 的矩阵 (n_words,model_dim),其中 n_words 是句子中的单词数,model_dim 是单词所在空间的维度表示向量。计算这些向量的移动平均值的最快方法是什么?

例如,如果窗口大小为 2(窗口长度 = 5),我可能会有这样的事情(这会引发错误 TypeError: JAX 'Tracer' objects do not support item assignment):

from jax import random
import jax.numpy as jnp

# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0),shape=(17,32)) 

ws = 2          # window size
N = W.shape[0]  # number of words

new_W = jnp.zeros(W.shape)

for i in range(N):
    window = W[max(0,i-ws):min(N,i+ws+1)]
    n = window.shape[0]
    for j in range(n):
        new_W[i] += W[j] / n

我想 jnp.convolve一个更快的解决方案,但我不熟悉它。

解决方法

这看起来您正在尝试进行卷积,因此 jnp.convolve 或类似方法可能是一种更高效的方法。

也就是说,您的示例有点奇怪,因为 n 永远不会大于 4,因此除了 W 的前四个元素之外,您永远不会访问任何其他元素。此外,您在内循环的每次迭代中都覆盖了前一个值,因此 new_W 的每一行只包含 W 前四行之一的缩放副本。

将您的代码更改为我认为您的意思并使用 index_update 使其与 JAX 的不可变数组兼容:

from jax import random
import jax.numpy as jnp

# Fake word vectors (17 words vectors of dimension 32)
W = random.normal(random.PRNGKey(0),shape=(17,32)) 

ws = 2          # window size
N = W.shape[0]  # number of words

new_W = jnp.zeros(W.shape)

for i in range(N):
    window = W[max(0,i-ws):min(N,i+ws)]
    n = window.shape[0]
    for j in range(n):
      new_W = new_W.at[i].add(window[j] / n)

这里是更有效的卷积方面的等价物:

from jax.scipy.signal import convolve
kernel = jnp.ones((4,1))
new_W_2 = convolve(W,kernel,mode='same') / convolve(jnp.ones_like(W),mode='same')

jnp.allclose(new_W,new_W_2)
# True

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...