问题描述
我的问题很简单:
>>> isinstance(x,jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray,numpy.ndarray)
True
>>> isinstance(x,numpy.ndarray)
False
?
现在我将漫步,以便SE接受我的合理问题。
解决方法
之所以如此,是因为jax.numpy.ndarray
用元类覆盖了实例检查:
class _ArrayMeta(type(np.ndarray)): # type: ignore
"""Metaclass for overriding ndarray isinstance checks."""
def __instancecheck__(self,instance):
try:
return isinstance(instance.aval,_arraylike_types)
except AttributeError:
return isinstance(instance,_arraylike_types)
class ndarray(np.ndarray,metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int,...]
size: int
def __init__(shape,dtype=None,buffer=None,offset=0,strides=None,order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array,or jax.numpy.zeros instead.")
您的代码返回其结果的原因是因为您有一个x
值,它不是numpy.ndarray
的实例,但是此__instancecheck__
方法为此返回true。
为什么在JAX中使用这种替代方法?好吧,出于JIT编译,自动分化和其他转换的目的,JAX使用称为 tracers 的替代对象,这些替代对象实际上看起来像是一个数组,但实际上看起来像是一个数组。实例检查的这种覆盖是JAX进行此类跟踪工作的技巧之一。