Numba 是否支持内置的 python 函数,例如`setitem`

问题描述

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:

尝试使用以下命令 SG_Z 将矩阵 SG_Z[SG_Z < Threshold] = 0 中小于给定阈值的所有元素设置为零时遇到此错误

此命令正在使用 Numba @jit 并行化的函数中使用。由于这一行的存在,函数没有运行。

解决方法

您对 SG_Z 说得不多,但我怀疑它是 2d(或更高)。 numba 的多维索引能力有限(与 numpy 相比)

In [133]: arr = np.random.rand(3,4)
In [134]: arr
Out[134]: 
array([[0.8466427,0.37340328,0.07712635,0.34466743],[0.86591184,0.32048868,0.1260246,0.9811717 ],[0.28948191,0.32099879,0.54819722,0.78863841]])
In [135]: arr<.5
Out[135]: 
array([[False,True,True],[False,False],[ True,False,False]])
In [136]: arr[arr<.5]
Out[136]: 
array([0.37340328,0.34466743,0.28948191,0.32099879])

numba

In [137]: import numba
In [138]: @numba.njit
     ...: def foo(arr,thresh):
     ...:     arr[arr<.5]=0
     ...:     return arr
     ...: 
In [139]: foo(arr,.5)
Traceback (most recent call last):
  File "<ipython-input-139-33ea2fda1ea2>",line 1,in <module>
    foo(arr,.5)
  File "/usr/local/lib/python3.8/dist-packages/numba/core/dispatcher.py",line 420,in _compile_for_args
    error_rewrite(e,'typing')
  File "/usr/local/lib/python3.8/dist-packages/numba/core/dispatcher.py",line 361,in error_rewrite
    raise e.with_traceback(None)
TypingError: No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(float64,2d,C),array(bool,Literal[int](0))
 
There are 16 candidate implementations:
  - Of which 14 did not match due to:
  Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64,int64)':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'SetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 171.
    With argument(s): '(array(float64,int64)':
   Rejected as the implementation raised a specific error:
     TypeError: unsupported array index type array(bool,C) in [array(bool,C)]
  raised from /usr/local/lib/python3.8/dist-packages/numba/core/typing/arraydecl.py:68

During: typing of setitem at <ipython-input-138-6861f217f595> (3)

通常情况下,setitem 没有丢失; numba 一直这样做。对于这种特定的参数组合,它是 setitem

如果我首先拆散数组,它确实有效。

In [140]: foo(arr.ravel(),.5)
Out[140]: 
array([0.8466427,0.,0.86591184,0.9811717,0.78863841])

但是在 numba 中,我们不需要害怕迭代,因此对于 2d 输入,我们可以对行进行迭代:

In [148]: @numba.njit
     ...: def foo(arr,thresh):
     ...:     for i in arr:
     ...:         i[i<thresh] = 0
     ...:     return arr
     ...: 
In [149]: foo(arr,.5)
Out[149]: 
array([[0.8466427,0.        ],[0.,0.78863841]])

可能有更通用的编写方法和提供签名的方法,但这应该可以提供一些有关如何解决此问题的想法。