非传递子类化与numpy和jax

问题描述

我的问题很简单:

>>> isinstance(x,jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray,numpy.ndarray)
True
>>> isinstance(x,numpy.ndarray)
False

现在我将漫步,以便SE接受我的合理问题。

解决方法

之所以如此,是因为jax.numpy.ndarray用元类覆盖了实例检查:

class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self,instance):
    try:
      return isinstance(instance.aval,_arraylike_types)
    except AttributeError:
      return isinstance(instance,_arraylike_types)

class ndarray(np.ndarray,metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int,...]
  size: int

  def __init__(shape,dtype=None,buffer=None,offset=0,strides=None,order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array,or jax.numpy.zeros instead.")

view source

您的代码返回其结果的原因是因为您有一个x值,它不是numpy.ndarray的实例,但是此__instancecheck__方法为此返回true。

为什么在JAX中使用这种替代方法?好吧,出于JIT编译,自动分化和其他转换的目的,JAX使用称为 tracers 的替代对象,这些替代对象实际上看起来像是一个数组,但实际上看起来像是一个数组。实例检查的这种覆盖是JAX进行此类跟踪工作的技巧之一。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...