问题描述
在对 pymc3
多元正态分布取消居中后,我在 beta
中的跟随模型遇到了一个奇怪的错误。
# specify model and sample
with pm.Model() as hp_model:
# hyper priors
chol,corr,stds = pm.LKJCholeskyCov("Omega",n=p,eta=1.,sd_dist=pm.Exponential.dist(2.0),compute_corr=True)
cov = pm.Deterministic("cov",chol.dot(chol.T))
mu_alpha = pm.normal("mu_alpha",mu=12.,sigma=1.)
tau_alpha = pm.HalfStudentT("tau_alpha",sigma=0.5,nu=1.)
# priors
alpha_uc = pm.normal("alpha_uc",mu=0.,sigma=1.)
alpha = pm.Deterministic("alpha",alpha_uc * tau_alpha + mu_alpha)
beta_uc = pm.Mvnormal("beta_uc",mu=zp_vec,cov=Ip_mat,shape=(p,))
beta = pm.Deterministic("beta",cov.dot(beta_uc))
sigma = pm.HalfStudentT("sigma",nu=1.)
# likelihood
Ey_x = alpha + X_train.dot(beta) # expectation of Y|X
y_obs = pm.normal("y_obs",mu=Ey_x,sigma=sigma,observed=y_train)
此模型一直有效,直到我取消了 beta 变量的中心位置。现在行 Ey_x = alpha + X_train.dot(beta)
抛出以下错误:
Traceback (most recent call last):
File "house_prices_pipeline.py",line 70,in <module>
y_obs = pm.normal("y_obs",observed=y_train)
File "/home/samvoisin/anaconda3/envs/house_price_pymc3/lib/python3.8/site-packages/pymc3/distributions/distribution.py",line 121,in __new__
dist = cls.dist(*args,**kwargs)
File "/home/samvoisin/anaconda3/envs/house_price_pymc3/lib/python3.8/site-packages/pymc3/distributions/distribution.py",line 130,in dist
dist.__init__(*args,**kwargs)
File "/home/samvoisin/anaconda3/envs/house_price_pymc3/lib/python3.8/site-packages/pymc3/distributions/continuous.py",line 487,in __init__
self.mean = self.median = self.mode = self.mu = mu = tt.as_tensor_variable(floatX(mu))
File "/home/samvoisin/anaconda3/envs/house_price_pymc3/lib/python3.8/site-packages/pymc3/theanof.py",line 83,in floatX
return X.astype(theano.config.floatX)
ValueError: setting an array element with a sequence.
我相信错误是由于 Ey_x
是张量而不是标量,但是,我很难在 Theano 的文档中找到解决方案。我试过了
Ey_x = T.cast(alpha + X_train.dot(beta),"float32").eval()
但收到类型错误:TypeError: Unsupported dtype for TensorType: object
是否有更好的方法使用 Theano 计算 Ey_x
?
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)