问题描述
我正在尝试使用 JAX 库及其小型神经网络子模块“Stax”来实现和训练神经网络。由于这个库没有实现二元交叉熵,我自己写的:
def binary_cross_entropy(y_hat,y):
bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
return jnp.mean(-bce)
我实现了一个简单的神经网络并在 MNIST 上对其进行了训练,然后开始对我得到的一些结果产生怀疑。所以我在 Keras 中实现了相同的设置,我立即得到了非常不同的结果!在相同数据上以相同方式训练的相同模型在 Keras 中获得了 90% 的训练准确率,而不是在 JAX 中的大约 50%。最终,我将问题的一部分追溯到我对交叉熵的幼稚实现,据说它在数值上是不稳定的。根据我找到的 this 帖子和 this 代码,我编写了以下新版本:
def binary_cross_entropy_stable(y_hat,y):
y_hat = jnp.clip(y_hat,0.000001,0.9999999)
logits = jnp.log(y_hat/(1 - y_hat))
max_logit = jnp.clip(logits,None)
bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
return jnp.mean(bces)
这样效果会好一些。现在我的 JAX 实现达到了 80% 的训练准确率,但这仍然比 Keras 的 90% 低很多。我想知道的是发生了什么?为什么我的两个实现方式不同?
下面,我将我的两个实现压缩为一个脚本。在这个脚本中,我在 JAX 和 Keras 中实现了相同的模型。我用相同的权重初始化它们,并使用全批次梯度下降对来自 MNIST 的 1000 个数据点进行 10 步训练,每个模型的相同数据。 JAX 以 80% 的训练准确率完成,而 Keras 以 90% 完成。具体来说,我得到这个输出:
Initial Keras accuracy: 0.4350000023841858
Initial JAX accuracy: 0.435
Final JAX accuracy: 0.792
Final Keras accuracy: 0.9089999794960022
JAX accuracy (Keras weights): 0.909
Keras accuracy (JAX weights): 0.7919999957084656
实际上,当我稍微改变条件时(使用不同的随机初始权重或不同的训练集),有时我会得到 50% 的 JAX 准确度和 90% 的 Keras 准确度。
我在最后交换了权重以验证从训练中获得的权重确实是问题所在,而不是与网络预测的实际计算或我计算准确性的方式有关。
代码:
import numpy as np
import jax
from jax import jit,grad
from jax.experimental import stax,optimizers
import jax.numpy as jnp
import keras
import keras.datasets.mnist
def binary_cross_entropy(y_hat,y):
bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
return jnp.mean(-bce)
def binary_cross_entropy_stable(y_hat,None)
bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
return jnp.mean(bces)
def binary_accuracy(y_hat,y):
return jnp.mean((y_hat >= 1/2) == (y >= 1/2))
########################################
# #
# Create dataset #
# #
########################################
input_dimension = 784
(x_train,y_train),(x_test,y_test) = keras.datasets.mnist.load_data(path="mnist.npz")
xs = np.concatenate([x_train,x_test])
xs = xs.reshape((70000,784))
ys = np.concatenate([y_train,y_test])
ys = (ys >= 5).astype(np.float32)
ys = ys.reshape((70000,1))
train_xs = xs[:1000]
train_ys = ys[:1000]
########################################
# #
# Create JAX model #
# #
########################################
jax_initializer,jax_model = stax.serial(
stax.Dense(1000),stax.Relu,stax.Dense(1),stax.Sigmoid
)
rng_key = jax.random.PRNGKey(0)
_,initial_jax_weights = jax_initializer(rng_key,(1,input_dimension))
########################################
# #
# Create Keras model #
# #
########################################
initial_keras_weights = [*initial_jax_weights[0],*initial_jax_weights[2]]
keras_model = keras.Sequential([
keras.layers.Dense(1000,activation="relu"),keras.layers.Dense(1,activation="sigmoid")
])
keras_model.compile(
optimizer=keras.optimizers.SGD(learning_rate=0.01),loss=keras.losses.binary_crossentropy,metrics=["accuracy"]
)
keras_model.build(input_shape=(1,input_dimension))
keras_model.set_weights(initial_keras_weights)
if __name__ == "__main__":
########################################
# #
# Compare untrained models #
# #
########################################
initial_keras_predictions = keras_model.predict(train_xs,verbose=0)
initial_jax_predictions = jax_model(initial_jax_weights,train_xs)
_,keras_initial_accuracy = keras_model.evaluate(train_xs,train_ys,verbose=0)
jax_initial_accuracy = binary_accuracy(jax_model(initial_jax_weights,train_xs),train_ys)
print("Initial Keras accuracy:",keras_initial_accuracy)
print("Initial JAX accuracy:",jax_initial_accuracy)
########################################
# #
# Train JAX model #
# #
########################################
L = jit(binary_cross_entropy_stable)
gradL = jit(grad(lambda w,x,y: L(jax_model(w,x),y)))
opt_init,opt_apply,get_params = optimizers.sgd(0.01)
network_state = opt_init(initial_jax_weights)
for _ in range(10):
wT = get_params(network_state)
gradient = gradL(wT,train_xs,train_ys)
network_state = opt_apply(
0,gradient,network_state
)
final_jax_weights = get_params(network_state)
final_jax_training_predictions = jax_model(final_jax_weights,train_xs)
final_jax_accuracy = binary_accuracy(final_jax_training_predictions,train_ys)
print("Final JAX accuracy:",final_jax_accuracy)
########################################
# #
# Train Keras model #
# #
########################################
for _ in range(10):
keras_model.fit(
train_xs,epochs=1,batch_size=1000,verbose=0
)
final_keras_loss,final_keras_accuracy = keras_model.evaluate(train_xs,verbose=0)
print("Final Keras accuracy:",final_keras_accuracy)
########################################
# #
# Swap weights #
# #
########################################
final_keras_weights = keras_model.get_weights()
final_keras_weights_in_jax_format = [
(final_keras_weights[0],final_keras_weights[1]),tuple(),(final_keras_weights[2],final_keras_weights[3]),tuple()
]
jax_accuracy_with_keras_weights = binary_accuracy(
jax_model(final_keras_weights_in_jax_format,train_ys
)
print("JAX accuracy (Keras weights):",jax_accuracy_with_keras_weights)
final_jax_weights_in_keras_format = [*final_jax_weights[0],*final_jax_weights[2]]
keras_model.set_weights(final_jax_weights_in_keras_format)
_,keras_accuracy_with_jax_weights = keras_model.evaluate(train_xs,verbose=0)
print("Keras accuracy (JAX weights):",keras_accuracy_with_jax_weights)
尝试将第 57 行的 PRNG 种子更改为 0
以外的值,以使用不同的初始权重运行实验。
解决方法
您的 binary_cross_entropy_stable
函数与 keras.binary_crossentropy
的输出不匹配;例如:
x = np.random.rand(10)
y = np.random.rand(10)
print(keras.losses.binary_crossentropy(x,y))
# tf.Tensor(0.8134677734043875,shape=(),dtype=float64)
print(binary_cross_entropy_stable(x,y))
# 0.9781515
如果你想完全复制模型,我会从这里开始。
您可以在此处查看 keras 损失函数的来源:keras/losses.py#L1765-L1810,其中实现的主要部分在此处:keras/backend.py#L4972-L5017
一个细节:似乎通过 sigmoid 激活函数,Keras 重新使用一些缓存的 logits 来计算二元交叉熵,同时避免有问题的值:keras/backend.py#L4988-L4997。我不确定如何使用 JAX 和 stax 轻松复制该行为。