问题描述
N= [-7.12843079e+02,-1.39668296e+02,-6.01626070e+01,-3.51688015e+01]
jax.scipy.special.digamma(N)
TypeError: digamma does not accept dtype complex64. Accepted dtypes are subtypes of floating.
我正在尝试使用 jax.scipy.special.digamma 计算复数中的 digamma,但是,即使这个包的文档说它可能很复杂,它仍然给我这个错误 这是文档所说的:
参数: z (array_like) – 实数或复数参数。
知道如何解决这个问题吗?或者有没有其他替代方法,例如其他库或其他包,可以让我使用复数来计算 digamma 函数!?
解决方法
我遇到了类似的问题。我就是这样解决的。漫长的道路。我知道有人会给出一个简洁的方法。
从定义 psi function
我使用了 here 中的 Gamma 函数。确保输出在 JAX 中,否则您将无法使用 GRAD;
import jax.numpy as jnp
from jax import grad
def gamma_func_numeric(z):
g = 7
z -= 1
x = lanczos_coef[0]
for i in range(1,g+2):
x += lanczos_coef[i]/(z+i)
t = z + g + 0.5
return jnp.sqrt(2*jnp.pi)*jnp.power(t,(z+0.5))*jnp.exp(-t)*x
eta is a complex number
def psi_numeric(eta):
gamma_prime = grad(gamma_func_numeric,holomorphic=True)(eta)
gamma = gamma_func_numeric(eta)
return gamma_prime/gamma
让我们比较一下结果:
scipy.special.psi(1.+2j)=(0.7145915153739777+1.3208072826422304j)
psi_numeric(1.+2j) = DeviceArray(0.7145916+1.3208078j,dtype=complex64)