可以设置“功能化”功能来接受复杂的输入吗? “输入类型不支持ufunc'wrapper_module_0'”错误

问题描述

我正在学习如何使用sympy进行符号变量操作。现在,我有一个表达式,希望在一个优化方案中对其进行多次评估,因此我希望它可以非常快速地运行。 sympy documentation on Numeric Computation描述了执行此操作的几种方法:subs / evalf; lambdify; lambdify-numpy;充实Theano。

到目前为止,我已经lambdify上班了,但是对于我来说它似乎还不够快。进一步阅读该内容,lambdify似乎以python速度运行,而ufuncify可能会将代码重新编译为C代码,因此我现在正在研究此选项。

到目前为止,我已经能够对表达式进行词缀化,但是每当我将复杂的输入作为参数传递时,它都会引发错误。

我修改了this link中的ufuncify示例,以创建我的MWE。当我运行此命令时:

import sympy as sp
from sympy.utilities.autowrap import ufuncify
import numpy as np

# Create an example expression
a,b,c = sp.symbols('a,c')
expr = a + b*c

# Create a binary (compiled) function that broadcasts it's arguments
func = ufuncify((a,c),expr)
b = func(np.arange(5),2.0+1j,3.0)

print(b)

我收到此错误:TypeError:输入类型不支持ufunc'wrapper_module_0',并且根据强制转换规则“ safe”,不能将输入安全地强制转换为任何受支持的类型

如果我更改代码以删除第二个参数的虚部,那么它将运行良好:

import sympy as sp
from sympy.utilities.autowrap import ufuncify
import numpy as np

# Create an example expression
a,2.0,3.0)

print(b)

返回:[6. 7. 8. 9. 10。]

理想情况下,我想使用cython或f2py作为后端,因为我希望它们是最快的...但是我遇到类似的错误:

func = ufuncify((a,expr,backend='cython')

返回TypeError:参数'b_5226316'具有错误的类型(预期numpy.ndarray,变得复杂)

解决方法

可能还有另一个答案,但是我通常避免在sympy中封装一个复杂的值,因此您可以在arai对中设置每个值(对于实部和虚部)

然后,您还可以定义小表达式(a = ar + sp.I*ai),以简化代码编写过程。还要在要允许复杂值的位置预先选择。

代码将是:

import sympy as sp
from sympy.utilities.autowrap import ufuncify
import numpy as np

# Create an example expression
ar,ai,br,bi,cr,ci = sp.symbols('ar,ci',real=True)
a = ar + sp.I*ai
b = br + sp.I*bi
c = cr + sp.I*ci

expr = a + b*c

# Create a binary (compiled) function that broadcasts it's arguments
func = ufuncify((ar,ci),expr)
b = func(np.arange(5),2.0,1,3.0,0)

print(b)

我没有编译它,因为我没有安装东西,但是至少开始编译了,所以我很确定应该可以。

更新不编译是一个错误。 ufuncify函数目前不允许进行复杂的输入(在CodeGenerator中也被称为里程碑)。 ufuncify使用Codegen模块从给定表达式创建C代码。在上述情况下,情况如下:

double autofunc0(double ar,double ai,double br,double bi,double cr,double ci) {

   double autofunc0_result;
   autofunc0_result = I*ai + ar + (I*bi + br)*(I*ci + cr);
   return autofunc0_result;

}

,它目前不使用任何复杂的扩展名,尽管它们可以在C99中使用。我猜这将是Sympy代码生成器的不错扩展。 在文档中,他们还建议人们可以编写自己的本机代码。

1) If you are really concerned about speed or memory optimizations,you will probably get better results by working directly with the
  wrapper tools and the low level code.  However,the files generated
  by this utility may provide a useful starting point and reference
  code. Temporary files will be left intact if you supply the keyword
  tempdir="path/to/files/".

我想这将是一个替代选择,那就是使用生成的文件作为起点,然后使用标头实现复杂的值。但是,这不是一个令人满意的答案。

,
In [251]: a,b,c = symbols('a,c') 
     ...: expr = a + b*c 
     ...:                                                                                            

In [252]: f = lambdify((a,c),expr)                                                                

In [254]: print(f.__doc__)                                                                           
Created with lambdify. Signature:

func(a,c)

Expression:

a + b*c

Source code:

def _lambdifygenerated(a,c):
    return (a + b*c)


Imported modules:

对于这个简单的expr,经过lambdified的numpy代码看起来不错-充分利用了整个数组的编译运算符:

In [255]: f(np.arange(5),3.0)                                                                  
Out[255]: array([ 6.,7.,8.,9.,10.])

In [256]: timeit f(np.arange(5),3.0)                                                           
5.55 µs ± 16 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)

In [257]: timeit np.arange(5) + 2.0 * 3.0                                                            
5.04 µs ± 16.7 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)

如果我使用numba,速度会更快:

In [258]: import numba                                                                               

In [259]: @numba.njit 
     ...: def foo(a,c): 
     ...:     return a+b*c 
     ...:                                                                                            

In [260]: foo(np.arange(5),3.0)                                                                  
Out[260]: array([ 6.,10.])

In [261]: timeit foo(np.arange(5),3.0)                                                           
2.68 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)
,

有点hacky,可能只适用于f2py,但我认为这是您想要做的:

import sympy as sp
from sympy.utilities.autowrap import ufuncify
import numpy as np

# This seems to be False by default and prevents the f2py codegen from
# creating complex variables in fortran
sp.utilities.codegen.COMPLEX_ALLOWED = True

# Create an example expression,mark them as Complex so that the codegen
# tool doesn't assume they're just floats. May not even be necessary.
a,c = sp.symbols('a,c',complex=True)
expr = a + b*c

# Create a binary (compiled) function that broadcasts its arguments
func = ufuncify((a,expr,backend="f2py")
# The f2py backend will want identically-sized arrays
b = func(np.arange(5),np.ones(5) * (2.0 + 1.0j),np.ones(5) * 3)

print(b)

您可以确认生成的fortran代码需要复杂的变量 如果您在tempdir="."上设置了ufuncify并打开了wrapped_code_0.f90文件。

!******************************************************************************
!*                      Code generated with sympy 1.6.2                       *
!*                                                                            *
!*              See http://www.sympy.org/ for more information.               *
!*                                                                            *
!*                      This file is part of 'autowrap'                       *
!******************************************************************************

subroutine autofunc(y_1038050,a_1038054,b_1038055,c_1038056,&
      m_1038051)
implicit none
INTEGER*4,intent(in) :: m_1038051
COMPLEX*16,intent(out),dimension(1:m_1038051) :: y_1038050
COMPLEX*16,intent(in),dimension(1:m_1038051) :: a_1038054
COMPLEX*16,dimension(1:m_1038051) :: b_1038055
COMPLEX*16,dimension(1:m_1038051) :: c_1038056
INTEGER*4 :: i_1038052

do i_1038052 = 1,m_1038051
   y_1038050(i_1038052) = a_1038054(i_1038052) + b_1038055(i_1038052)* &
         c_1038056(i_1038052)
end do

end subroutine

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...