问题描述
运行 y = multivariate_normal(np.zeros(d),np.eye(d)).rvs()
时,我们获得维度为 (d,)
的样本。然而,当 d=1
获得一个标量时,这是有道理的,因为它是一维的。不幸的是,我有一些代码必须适用于任意数量的维度,包括 d=1
,并且基本上采用 d
维向量 x
与 y
的点积.这在 d=1
处中断。我该如何解决?
import numpy as np
from scipy.stats import multivariate_normal as MVN
def mwe_function(d,x):
"""Minimal Working Example"""
y = MVN(np.zeros(d),np.eye(d)).rvs()
return x @ y
mwe_function(2,np.ones(2)) # This works
mwe_function(1,np.ones(1)) # This doesn't
重要提示:我想避免使用 if 语句。在这种情况下可以简单地使用 scipy.stats.norm
,但我想避免使用 if 语句,因为它们会减慢代码速度。
解决方法
您可以使用 np.reshape
来固定样品的形状。通过使用 -1
指定第一维的长度,您将始终得到一个一维数组,没有标量。
import numpy as np
from scipy.stats import multivariate_normal as MVN
def mwe_function(d,x):
"""Minimal Working Example"""
y = MVN(np.zeros(d),np.eye(d)).rvs().reshape([-1])
return x @ y
v0 = mwe_function(2,np.ones(2)) # This works
print(v0) # -0.5718013906409207
v1 = mwe_function(1,np.ones(1)) # This works as well :-)
print(v1) # -0.20196038784485093
.reshape([-1])
在哪里工作。
就个人而言,我更喜欢重塑而不是使用 np.atleast_1d
,因为效果是直接可见的 - 但最终它是一个品味问题。