问题描述
我正在尝试绘制 3D 图形,使用重新存在的函数生成 Z 值。但是,这会产生错误“具有多个元素的数组的真值不明确”。这看起来很奇怪,因为我能够使用相同的函数和 y,x 值生成 Z 值列表,但是一旦我包含 3D 图形代码,就会发生错误。
我的绘图代码是:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib import cm
def f(tau,tau_b): #re-use society welfare function of tau & tau_b,using corr=0.6
Z = society_welfare2 (0.6,tau,tau_b)
return Z
xgrid=np.linspace(1e-5,1-1e-5,100) #tau grid
ygrid=np.linspace(1e-5,100) #tau_b grid
tau,tau_b=np.meshgrid(xgrid,ygrid)
fig=plt.figure(figsize=(8,6))
ax=fig.add_subplot(111,projection='3d')
ax.plot_surface(tau,tau_b,f(tau,tau_b),rstride=2,cstride=2,cmap=cm.jet,alpha=0.7,linewidth=0.25)
ax.set_zlim(-0.5,1.0)
plt.show()
def society_welfare2 (corr,tau_b):
cov = [[1,corr],[corr,1]] #covariance
epsilon_start,b_start = np.random.multivariate_normal(mean,cov,sample_N).T
epsilon = np.exp(epsilon_start) #to ensure epsilon positive
b = np.exp(b_start) #to ensure b positive
indv_welfares = []
def GBC (t_o):
taxes_paid = []
for i in range(sample_N): #loop over all agents to find their C1,C2,L
def consumption_functions(Lguess,epsilon=epsilon,b=b):
C2 = (((1-tau)*epsilon[i]*w*Lguess) +(1-tau_b)*b[i] + ((t_o)/(1+r)))/((1/((beta**(1/gamma))*((1+r)**(1/gamma)))) + (1/(1+r)))
C1 = C2 /((beta**(1/gamma))*(1+r)**(1/gamma))
return -Utility(C1,Lguess)
result = minimize_scalar(consumption_functions,bounds=(0,1),method='bounded',args=(epsilon,b))
opt_L = result.x
opt_C1=(((1-tau)*(epsilon[i])*w)/(opt_L**sigma))**(1/gamma)
opt_C2=(opt_C1)*((beta**(1/gamma))*(1+r)**(1/gamma))
income_tax = tau*(epsilon[i])*w*opt_L
bequest_tax = tau_b*(b[i])
taxes_paid.append(income_tax)
taxes_paid.append(bequest_tax)
welfare_func = opt_C1**(1-gamma)/(1-gamma)-opt_L**(1+sigma)/(1+sigma) + beta*(opt_C2**(1-gamma)/(1-gamma))
indv_welfares.append(welfare_func)
total_tax_revenue = sum(taxes_paid)
return total_tax_revenue - (10000*t_o)
result1 = minimize_scalar(GBC,bounds=(1e-5,100000),method='bounded')
opt_t_o = result1.x
total_welfare = sum(indv_welfares)
return total_welfare
ValueError Traceback (most recent call last)
<ipython-input-19-3633f4a9db76> in <module>
18 ax.plot_surface(tau,19 tau_b,---> 20 f(tau,21 rstride=2,22 cmap=cm.jet,<ipython-input-19-3633f4a9db76> in f(tau,tau_b)
7
8 def f(tau,using corr=0.6
----> 9 Z = society_welfare2 (0.6,tau_b)
10 return Z
11
<ipython-input-17-321a709b9684> in society_welfare2(corr,tau_b)
61 return total_tax_revenue - (10000*t_o)
62
---> 63 result1 = minimize_scalar(GBC,method='bounded')
64
65 opt_t_o = result1.x
/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/_minimize.py in minimize_scalar(fun,bracket,bounds,args,method,tol,options)
798 if isinstance(disp,bool):
799 options['disp'] = 2 * int(disp)
--> 800 return _minimize_scalar_bounded(fun,**options)
801 elif meth == 'golden':
802 return _minimize_scalar_golden(fun,**options)
/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/optimize.py in _minimize_scalar_bounded(func,xatol,maxiter,disp,**unkNown_options)
1956 rat = e = 0.0
1957 x = xf
-> 1958 fx = func(x,*args)
1959 num = 1
1960 fmin_data = (1,xf,fx)
<ipython-input-17-321a709b9684> in GBC(t_o)
41 return -Utility(C1,Lguess)
42
---> 43 result = minimize_scalar(consumption_functions,b))
44
45 opt_L = result.x
/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/_minimize.py in minimize_scalar(fun,**unkNown_options)
2015 print("%5.0f %12.6g %12.6g %s" % (fmin_data + (step,)))
2016
-> 2017 if fu <= fx:
2018 if x >= xf:
2019 a = xf
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
解决方法
回溯的最低点
if fu <= fx:
这是比较 if
中的两个变量。如果 fu
和 fx
是标量或单值数组,这将起作用。但如果其中一个是多值数组,则将引发此错误。
此时我们的任务是将这些变量追溯到您的代码。我怀疑您正在为某些参数提供数组,其中它们应该是标量。
看着顶部。当您要求绘图时会发生这种情况,但参数是函数调用:
f(tau,tau_b)
并通过函数调用 minimize
函数上的 GBC
。我认为 GBC
是 func
中的:
fx = func(x,*args)
这就提出了一个问题,GBC
到底返回什么?它在 _minimize_scalar
中使用,因此它应该只返回一个值。
它的返回表达式是什么?
return total_tax_revenue - (10000*t_o)
您认为您可以从那里进行分析吗?
现在您明白为什么我们坚持要看到 traceback
。错误出在您的代码中,但到达那里的序列很长,并且仅通过阅读代码并不明显。
编辑
糟糕,我看到另一个级别的 minimize
,一个使用了
consumption_functions
它有几个参数,epsilon
和 b
。我想我们可以推断出那些是什么。但是什么是
Utility
fu <= fx
似乎是针对绑定的 fx
测试 fu
返回值。假设边界是标量,那么值 fx
必须是一个数组。是吗???