问题描述
问题
我有以下函数(基于 scipy.integrate.quad
):
def simple_quad(func: Any,a: float,b: float,args: tuple = ()) -> float:
def strips(n: int):
for i in range(n):
x = a + (b - a) * i / n
yield func(x,*args) * 1 / n
return sum(strips(1000))
... 它基本上在一个值范围内评估 func
并使用固定宽度的条带来计算图形下的面积。或者,可以通过 func
元组将参数传递给 args
。
如您所见,我已经做了一些初始类型提示(实际上这是 scipy 的 .pyi 存根),但是我对 func 和 args 的类型如此松散不满意。我希望 mypy 保护我免受两件事的伤害:
-
func
是一个可调用的,它必须有一个float
的第一个位置参数,并返回一个float
,并且可以有可选的*args
- 即至少
f(x:float,...) -> float
- 我猜它也可以有 **kwargs(虽然不能有 required name-only params 或 x 以外的必需位置参数)
- 即至少
- 可选的位置
*args
到func
必须匹配 splatted args 元组的内容
示例
def cubic(x: float,c: float,d: float) -> float:
"Function to evaluate a cubic polynomial with coefficients"
return a + b * x + c * x ** 2 + d * x ** 3
simple_quad(cubic,a=1,b=10,args=(1,2,3,4)) # fine,passes args to a,b,c,d and int is compatible with float signature
simple_quad(cubic,b=10) # error from mypy as *args to `cubic` don't have default args so are all required
simple_quad(cubic,args=("a","b","c","d")) # arg length correct but type mismatch since cubic expects floats
x_squared: Callable[[float],float] = lambda x: x * x
simple_quad(x_squared,args=()) # should be fine as x_squared doesn't take any positional args other than x
def problematic(x: float,*,y: float) -> float: ... # can't pass kwargs y via simple_quad,so can't be integrated
我的尝试
对于 func
,我尝试使用协议和泛型:
class OneDimensionalFunction(Protocol,Generic[T]): #double inheritance,although maybe I can get by with a Metaclass for Generic
def __call__(self,x: float,*args: T) -> float: ...
...希望我能写
def simple_quad(func: OneDimensionalFunction[T],args: tuple[T] = ()) -> float:
simple_quad(cubic,1,10,(1,4)) # infer type requirement of args based on signature of func
# or
simple_quad[float,float,float](cubic,...) #pass in additional type info above and beyond the expected Callable[[x:float,...],float]
...我知道这有很多问题,如果例如我想将 lambda 作为 func 传入,则协议也不能很好地与 Callable 配合使用。
我将此 python 标记为 3.10,因为我认为新的 Parameter Specification Variables 可能会有所帮助,但我只看到了装饰器中使用的那些,所以我不确定如何在此处应用它们。让我知道你的想法
解决方法
对协议使用重载。这不是很好,但它使您尽可能接近真实的验证。
from typing import Protocol,Tuple,Any,overload,Union,Optional
class C1(Protocol):
def __call__(self,a: float,) -> float: ...
class C2(Protocol):
def __call__(self,b: float,) -> float: ...
class C3(Protocol):
def __call__(self,c: float) -> float: ...
ET = Tuple[()]
T1 = Tuple[float]
T2 = Tuple[float,float]
T3 = Tuple[float,float,float]
@overload
def quad(func: C1,args: Union[ET,T1] = ()) -> float: ...
@overload
def quad(func: C2,T2] = ()) -> float: ...
@overload
def quad(func: C3,T3] = ()) -> float: ...
def quad(func: Any,args: tuple = ()) -> float:
return 0
def cubic(a: float,c: float,*,s: Optional[str] = None) -> float:
return 0
quad(cubic,1)