Numpy 分段意外结果

问题描述

我正在尝试使用 numpy.piecewise 定义一个非连续函数

import numpy as np

var = np.array([0.2,10])
supp = lambda x: -np.sqrt(1 - 1/x**2)
inf = lambda x: 1j*np.sqrt(1/x**2 - 1)

def q1(eta_B):
    return np.piecewise(eta_B,[eta_B > 1,eta_B <= 1],[supp,inf])

使用此代码,这是我获得的:

>>> supp(var)
array([        nan,-0.99498744])
>>> inf(var)
array([ 0.+4.89897949j,nan       +nanj])
>>> q1(var)
array([ 0.,-0.99498744])

我不明白这里发生了什么;我希望 q1(var) 返回 array([0.+4.89897949j,-0.99498744])

有什么提示吗?

解决方法

当您运行 q1(var) 时,您没有收到以下警告吗?

Warning (from warnings module):
  File "C:\Users\Chinchi\AppData\Roaming\Python\Python38\site-packages\numpy\lib\function_base.py",line 616
    y[cond] = func(vals,*args,**kw)
ComplexWarning: Casting complex values to real discards the imaginary part

查看 np.piecewise 的文档,我们看到

输出与 x 相同的形状和类型

由于您从未为 var 定义类型并且它的所有值都适合浮点数,所以这就是它获得的类型。

>>> var.dtype
dtype('float64')
>>> var = np.array([0.2+0j,10])
>>> var.dtype
dtype('complex128')
>>> q1(var)
array([ 0.        +4.89897949j,-0.99498744-0.j        ])

解决此问题的最佳方法是在函数内将输入转换为复杂类型,这样无论您传递的是 int/float/complex 还是不必自己担心转换,它都可以工作。

np.piecewise(eta_B.astype(np.complex128),[eta_B > 1,eta_B <= 1],[supp,inf])