问题描述
我正在尝试使用 numpy.piecewise
定义一个非连续函数:
import numpy as np
var = np.array([0.2,10])
supp = lambda x: -np.sqrt(1 - 1/x**2)
inf = lambda x: 1j*np.sqrt(1/x**2 - 1)
def q1(eta_B):
return np.piecewise(eta_B,[eta_B > 1,eta_B <= 1],[supp,inf])
使用此代码,这是我获得的:
>>> supp(var)
array([ nan,-0.99498744])
>>> inf(var)
array([ 0.+4.89897949j,nan +nanj])
>>> q1(var)
array([ 0.,-0.99498744])
我不明白这里发生了什么;我希望 q1(var)
返回 array([0.+4.89897949j,-0.99498744])
!
有什么提示吗?
解决方法
当您运行 q1(var)
时,您没有收到以下警告吗?
Warning (from warnings module):
File "C:\Users\Chinchi\AppData\Roaming\Python\Python38\site-packages\numpy\lib\function_base.py",line 616
y[cond] = func(vals,*args,**kw)
ComplexWarning: Casting complex values to real discards the imaginary part
查看 np.piecewise
的文档,我们看到
输出与 x 相同的形状和类型
由于您从未为 var
定义类型并且它的所有值都适合浮点数,所以这就是它获得的类型。
>>> var.dtype
dtype('float64')
>>> var = np.array([0.2+0j,10])
>>> var.dtype
dtype('complex128')
>>> q1(var)
array([ 0. +4.89897949j,-0.99498744-0.j ])
解决此问题的最佳方法是在函数内将输入转换为复杂类型,这样无论您传递的是 int/float/complex 还是不必自己担心转换,它都可以工作。
np.piecewise(eta_B.astype(np.complex128),[eta_B > 1,eta_B <= 1],[supp,inf])