3D 图形错误:“具有多个元素的数组的真值不明确”

问题描述

我正在尝试绘制 3D 图形,使用重新存在的函数生成 Z 值。但是,这会产生错误“具有多个元素的数组的真值不明确”。这看起来很奇怪,因为我能够使用相同的函数和 y,x 值生成 Z 值列表,但是一旦我包含 3D 图形代码,就会发生错误

我的绘图代码是:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib import cm
 
def f(tau,tau_b):                             #re-use society welfare function of tau & tau_b,using corr=0.6
    Z = society_welfare2 (0.6,tau,tau_b)
    return Z

xgrid=np.linspace(1e-5,1-1e-5,100)   #tau grid
ygrid=np.linspace(1e-5,100)   #tau_b grid
tau,tau_b=np.meshgrid(xgrid,ygrid)

fig=plt.figure(figsize=(8,6))
ax=fig.add_subplot(111,projection='3d')
ax.plot_surface(tau,tau_b,f(tau,tau_b),rstride=2,cstride=2,cmap=cm.jet,alpha=0.7,linewidth=0.25)
ax.set_zlim(-0.5,1.0)
plt.show()

我的social_welfare2函数代码

def society_welfare2 (corr,tau_b):
    
    cov   = [[1,corr],[corr,1]]   #covariance
    epsilon_start,b_start = np.random.multivariate_normal(mean,cov,sample_N).T 

    epsilon     = np.exp(epsilon_start)  #to ensure epsilon positive
    b     = np.exp(b_start)              #to ensure b positive

    indv_welfares = []
    
    def GBC (t_o):
        taxes_paid = []

        for i in range(sample_N):                     #loop over all agents to find their C1,C2,L

            def consumption_functions(Lguess,epsilon=epsilon,b=b):
                C2 = (((1-tau)*epsilon[i]*w*Lguess) +(1-tau_b)*b[i] + ((t_o)/(1+r)))/((1/((beta**(1/gamma))*((1+r)**(1/gamma)))) + (1/(1+r)))
                C1 = C2 /((beta**(1/gamma))*(1+r)**(1/gamma))
                return -Utility(C1,Lguess)

            result = minimize_scalar(consumption_functions,bounds=(0,1),method='bounded',args=(epsilon,b))

            opt_L = result.x
        
            opt_C1=(((1-tau)*(epsilon[i])*w)/(opt_L**sigma))**(1/gamma)
        
            opt_C2=(opt_C1)*((beta**(1/gamma))*(1+r)**(1/gamma))
        
            income_tax = tau*(epsilon[i])*w*opt_L         
            bequest_tax = tau_b*(b[i])                 
            taxes_paid.append(income_tax)        
            taxes_paid.append(bequest_tax)   
            
            welfare_func = opt_C1**(1-gamma)/(1-gamma)-opt_L**(1+sigma)/(1+sigma) + beta*(opt_C2**(1-gamma)/(1-gamma))
            indv_welfares.append(welfare_func)
    
        total_tax_revenue = sum(taxes_paid)  
    
        return total_tax_revenue - (10000*t_o)

    result1 = minimize_scalar(GBC,bounds=(1e-5,100000),method='bounded')
    
    opt_t_o = result1.x
    
    total_welfare = sum(indv_welfares)
    return total_welfare   

完整的回溯错误代码

ValueError                                Traceback (most recent call last)
<ipython-input-19-3633f4a9db76> in <module>
     18 ax.plot_surface(tau,19                 tau_b,---> 20                 f(tau,21                 rstride=2,22                 cmap=cm.jet,<ipython-input-19-3633f4a9db76> in f(tau,tau_b)
      7 
      8 def f(tau,using corr=0.6
----> 9     Z = society_welfare2 (0.6,tau_b)
     10     return Z
     11 

<ipython-input-17-321a709b9684> in society_welfare2(corr,tau_b)
     61         return total_tax_revenue - (10000*t_o)
     62 
---> 63     result1 = minimize_scalar(GBC,method='bounded')
     64 
     65     opt_t_o = result1.x

/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/_minimize.py in minimize_scalar(fun,bracket,bounds,args,method,tol,options)
    798         if isinstance(disp,bool):
    799             options['disp'] = 2 * int(disp)
--> 800         return _minimize_scalar_bounded(fun,**options)
    801     elif meth == 'golden':
    802         return _minimize_scalar_golden(fun,**options)

/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/optimize.py in _minimize_scalar_bounded(func,xatol,maxiter,disp,**unkNown_options)
   1956     rat = e = 0.0
   1957     x = xf
-> 1958     fx = func(x,*args)
   1959     num = 1
   1960     fmin_data = (1,xf,fx)

<ipython-input-17-321a709b9684> in GBC(t_o)
     41                 return -Utility(C1,Lguess)
     42 
---> 43             result = minimize_scalar(consumption_functions,b))
     44 
     45             opt_L = result.x

/opt/anaconda3/lib/python3.8/site-packages/scipy/optimize/_minimize.py in minimize_scalar(fun,**unkNown_options)
   2015             print("%5.0f   %12.6g %12.6g %s" % (fmin_data + (step,)))
   2016 
-> 2017         if fu <= fx:
   2018             if x >= xf:
   2019                 a = xf

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

解决方法

回溯的最低点

if fu <= fx:

这是比较 if 中的两个变量。如果 fufx 是标量或单值数组,这将起作用。但如果其中一个是多值数组,则将引发此错误。

此时我们的任务是将这些变量追溯到您的代码。我怀疑您正在为某些参数提供数组,其中它们应该是标量。

看着顶部。当您要求绘图时会发生这种情况,但参数是函数调用:

f(tau,tau_b)

并通过函数调用 minimize 函数上的 GBC。我认为 GBCfunc 中的:

fx = func(x,*args)

这就提出了一个问题,GBC 到底返回什么?它在 _minimize_scalar 中使用,因此它应该只返回一个值。

它的返回表达式是什么?

return total_tax_revenue - (10000*t_o)

您认为您可以从那里进行分析吗?

现在您明白为什么我们坚持要看到 traceback。错误出在您的代码中,但到达那里的序列很长,并且仅通过阅读代码并不明显。

编辑

糟糕,我看到另一个级别的 minimize,一个使用了

consumption_functions

它有几个参数,epsilonb。我想我们可以推断出那些是什么。但是什么是

Utility

fu <= fx 似乎是针对绑定的 fx 测试 fu 返回值。假设边界是标量,那么值 fx 必须是一个数组。是吗???