使用 vmap (jax) 按元素求和矩阵?

问题描述

我正在尝试了解 vmap 中的 in_axes 和 out_axes 选项。 例如,我想对两个矩阵求和并得到形状相同的输出

X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
    return x + y
vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)

我想我分别为 X 和 Y 映射了轴 0 和 1。输出将具有与 X、Y 相同的形状。 但我得到了错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-403-103694166574> in <module>
      3 def sum2(x,y):
      4     return x + y
----> 5 vmap(sum2,Y)

    [... skipping hidden 2 frame]

~/anaconda3/lib/python3.8/site-packages/jax/api_util.py in flatten_axes(name,treedef,axis_tree,kws)
    276       assert treedef_is_leaf(leaf)
    277       axis_tree,_ = axis_tree
--> 278     raise ValueError(f"{name} specification must be a tree prefix of the "
    279                      f"corresponding value,got specification {axis_tree} "
    280                      f"for value tree {treedef}.") from None

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value,got specification ((0,1)) for value tree PyTreeDef((*,*)).

解决方法

首先,按元素求和的最简单方法是使用内置的二元运算广播,并直接调用 sum2(X,Y)

也就是说,如果您想了解 vmap:问题是 vmap 一次只能映射一个轴。如果要映射多个轴,可以嵌套多个vmap。我相信你的意图可以这样表达:

from jax import vmap
import jax.numpy as np

X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)

def sum2(x,y):
    assert x.ndim == y.ndim == 0
    return x + y

vmap(vmap(sum
  vmap(sum2,in_axes=(0,0),out_axes=0),in_axes=(1,1),out_axes=1
)(X,Y)

注意:我添加了关于维数的断言,以证明正在对标量值调用映射函数。

另外,请注意当映射的轴匹配时,例如in_axes=(0,0) 可以等效地写为 in_axes=0,但我将其保留为元组,因为它更接近您尝试的语法。

事实上,使用嵌套的 vmap 进行相同计算的更简洁的方法是使用默认参数:vmap(vmap(sum2))(X,Y) 将执行相同的元素求和。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...