使用Jax的偏导数?

问题描述

我对 Jax 文档感到困惑,这是我想要做的:

def line(m,x,b):
  return m*x + b

grad(line)(1,2,3)

错误

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-48-d14b17620b30> in <module>()
      3 
----> 4 grad(line)(1,3)

FilteredStackTrace: TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),but got int32. If you want to use integer-valued inputs,use vjp or set allow_int to True.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred,unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
6 frames
/usr/local/lib/python3.7/dist-packages/jax/api.py in _check_input_dtype_revderiv(name,holomorphic,allow_int,x)
    844   elif not allow_int and not (dtypes.issubdtype(aval.dtype,np.floating) or
    845                               dtypes.issubdtype(aval.dtype,np.complexfloating)):
--> 846     raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype that "
    847                     "is a sub-dtype of np.floating or np.complexfloating),"
    848                     f"but got {aval.dtype.name}. If you want to use integer-valued "

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),use vjp or set allow_int to True.

我引用了官方的 tutorial 代码

import jax.numpy as jnp
from jax import grad,jit,vmap
from jax import random

key = random.PRNGKey(0)

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W,b,inputs):
    return sigmoid(jnp.dot(inputs,W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52,1.12,0.77],[0.88,-1.08,0.15],[0.52,0.06,-1.30],[0.74,-2.49,1.39]])
targets = jnp.array([True,True,False,True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W,b):
    preds = predict(W,inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key,W_key,b_key = random.split(key,3)
W = random.normal(W_key,(3,))
b = random.normal(b_key,())

W_grad = grad(loss,argnums=0)(W,b)
print('W_grad',W_grad)

结果:

W_grad [-0.16965576 -0.8774648  -1.4901345 ]

在这里做错了什么?我认为 key 正在以某种重要的方式被使用,但我不知道为什么/如何需要它。要回答这个问题,请根据需要调整第一个块中的代码以消除错误

解决方法

Jax 告​​诉你它不喜欢整数。 grad(line)(1.,2.,3.)(使用浮动)修复了问题。

,

我认为这里的错误很明显:

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating),but got int32. If you want to use integer-valued inputs,use vjp or set allow_int to True.

要将 grad(line)(1,2,3)Int32 一起使用,请将其更改为 grad(line,allow_int=True)(1,3)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...