如何计算复数中的 digamma 函数以便在 Tensorflow 中使用此函数接受输入作为张量?

问题描述

N= [-7.12843079e+02,-1.39668296e+02,-6.01626070e+01,-3.51688015e+01]
jax.scipy.special.digamma(N)
TypeError: digamma does not accept dtype complex64. Accepted dtypes are subtypes of floating.

我正在尝试使用 jax.scipy.special.digamma 计算复数中的 digamma,但是,即使这个包的文档说它可能很复杂,它仍然给我这个错误 这是文档所说的:

参数: z (array_like) – 实数或复数参数。

知道如何解决这个问题吗?或者有没有其他替代方法,例如其他库或其他包,可以让我使用复数来计算 digamma 函数!?

解决方法

我遇到了类似的问题。我就是这样解决的。漫长的道路。我知道有人会给出一个简洁的方法。

从定义 psi function

我使用了 here 中的 Gamma 函数。确保输出在 JAX 中,否则您将无法使用 GRAD;

import jax.numpy as jnp
from jax import grad

def gamma_func_numeric(z):
    g = 7
    z -= 1
    x = lanczos_coef[0]

    for i in range(1,g+2):
        x +=   lanczos_coef[i]/(z+i)    
        t = z + g  + 0.5
   return jnp.sqrt(2*jnp.pi)*jnp.power(t,(z+0.5))*jnp.exp(-t)*x

eta is a complex number

def psi_numeric(eta):
    gamma_prime = grad(gamma_func_numeric,holomorphic=True)(eta)
    gamma = gamma_func_numeric(eta)
    return gamma_prime/gamma

让我们比较一下结果:

scipy.special.psi(1.+2j)=(0.7145915153739777+1.3208072826422304j)

psi_numeric(1.+2j) = DeviceArray(0.7145916+1.3208078j,dtype=complex64)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...