使用 scipy.integrate.nquad 的双积分解决方案与integrate.dblquad 不匹配

问题描述

下面代码中的一个函数使用与scipy.integrate.dblquad的双重积分来计算copula密度函数c*np.log(c)的微分熵c,它有一个依赖参数 theta,通常为正值。

enter image description here

下面代码中的第二个函数试图解决与上面相同的问题,但使用了多重积分求解器scipy.integrate.nquad

from scipy import integrate
import numpy as np

def dblquad_(theta):
    "Double integration"
    c = lambda v,u: ((1+theta)*(u*v)**(-1-theta)) * (u**(-theta)+v**(-theta)-1)**(-1/theta-2)
    return -integrate.dblquad(
        lambda u,v: c(v,u)*np.log(c(v,u)),1,lambda u: 0,lambda u: 1
        )[0]

def nquad_(n,theta):
    "Multiple integration"
    c = lambda *us: ((1+theta)*np.prod(us)**(-1-theta)) * (np.sum(np.power(us,-theta))-1)**(-1/theta-2)
    return -integrate.nquad(
        func   = lambda *us : c(*us)*np.log(c(*us)),ranges = [(0,1) for i in range(n)],args   = (theta,) 
        )[0] 

n=2
theta = 1
print(dblquad_(theta))
print(nquad_(n,theta))

基于 dblquad函数给出了 -0.7127 的答案,而 nquad 给出了 -0.5823,并且明显需要更长的时间。为什么即使我已将两者都设置为解决 n=2 维问题,但解决方案却不同?

解决方法

使用您提供的 ntheta 值,您的代码输出为:

-0.1931471805597395
0.17055845832017144,

不是-0.7127-0.5823

第一个值 (-0.1931471805597395) 是正确的(您可以自己检查 here)。

nquad_ 中的问题在于 theta 参数的处理。感谢@mikuszefski 提供解释;为了清楚起见,我将其复制到此处:

nquad 根据需要将 lambda 传递给函数。拉姆达是 以这样一种方式编程,它接受任意数量的 参数,所以它很高兴地接受它并把它放在权力列表中 和总和。因此你没有得到,例如1/u**t+1/v**t -1 但是 1/u**t+1/v**t + 1/t**t -1。函数调用只是不匹配 预期的功能用途。如果您改为写 us[0]**() + us[1]**() - 1,它会起作用。

修改后的代码如下:

from scipy import integrate
import numpy as np

def dblquad_(theta):
    "Double integration"
    c = lambda v,u: ((1+theta)*(u*v)**(-1-theta)) * (u**(-theta)+v**(-theta)-1)**(-1/theta-2)
    return -integrate.dblquad(
        lambda u,v: c(v,u)*np.log(c(v,u)),1,lambda u: 0,lambda u: 1
        )[0]

def nquad_(n,theta):
    "Multiple integration"
    c = lambda *us: ((1+theta)*np.prod((us[0],us[1]))**(-1-theta)) * (np.sum(np.power((us[0],us[1]),-theta))-1)**(-1/theta-2)
    return -integrate.nquad(
        func   = lambda *us : c(*us)*np.log(c(*us)),ranges = [(0,1) for i in range(n)],args=(theta,)
        )[0]

n=2
theta = 1
print(dblquad_(theta))
print(nquad_(n,theta))

输出:

-0.1931471805597395
-0.1931471805597395

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...