从 JAX 中的多元正态分布采样给出类型错误

问题描述

我正在尝试使用 JAX 从多元正态分布中生成样本:

import jax
import jax.numpy as jnp
import numpy as np

key = random.PRNGKey(0)
cov = np.array([[1.2,0.4],[0.4,1.0]])
mean = np.array([3,-1])
x1,x2 = jax.random.multivariate_normal(key,mean,cov,5000).T

但是,当我运行代码时,出现以下错误

TypeError                                 Traceback (most recent call last)
<ipython-input-25-1397bf923fa4> in <module>()
      2 cov = np.array([[1.2,1.0]])
      3 mean = np.array([3,-1])
----> 4 x1,5000).T

1 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in canonicalize_shape(shape)
   1159          "got {}.")
   1160   if any(isinstance(x,Tracer) and isinstance(get_aval(x),ShapedArray)
-> 1161          and not isinstance(get_aval(x),ConcreteArray) for x in shape):
   1162     msg += ("\nIf using `jit`,try using `static_argnums` or applying `jit` to "
   1163             "smaller subfunctions.")

TypeError: 'int' object is not iterable

我不确定问题是什么,因为相同的语法适用于 Numpy 中的等效函数

解决方法

jax.random 模块中,大多数形状必须明确是元组。因此,不要使用形状 5000,而是使用 (5000,)

x1,x2 = jax.random.multivariate_normal(key,mean,cov,(5000,)).T

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...