Why doesn't my neural network converge with JAX?

Problem description

I am learning JAX, but I have run into a strange problem. If I use the following code,

import numpy as np
import jax.numpy as jnp
from jax import grad,value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental import optimizers # optimizers for training (adam)
from jax.experimental.stax import Conv,Dense,MaxPool,Relu,Flatten,LogSoftmax # neural network layers
import matplotlib.pyplot as plt # visualization

net_init,net_apply = stax.serial(
    Dense(40),Relu,Dense(40),Relu,Dense(1)
)
rng = random.PRNGKey(0)
in_shape = (-1,1,)
out_shape,params = net_init(rng,in_shape)

def loss(params,X,Y):
    predictions = net_apply(params,X)
    return jnp.mean((Y - predictions)**2)

@jit
def step(i,opt_state,x1,y1):
    p = get_params(opt_state)
    val,g = value_and_grad(loss)(p,x1,y1)
    return val,opt_update(i,g,opt_state)

opt_init,opt_update,get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100,1)) # (k,1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val,opt_state = step(i,opt_state,xrange_inputs,targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply,params))(xrange_inputs)
losses = vmap(partial(loss,params))(xrange_inputs,targets) # per-input loss

plt.plot(xrange_inputs,predictions,label='prediction')
plt.plot(xrange_inputs,losses,label='loss')
plt.plot(xrange_inputs,targets,label='target')
plt.legend()

the neural network approximates the function cos(x) well.

But if I rewrite the neural-network part myself, as follows,

import numpy as np
import jax.numpy as jnp
from jax import grad,value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random
import matplotlib.pyplot as plt # visualization
from jax.experimental import optimizers
from jax.tree_util import tree_multimap

def initialize_NN(layers,key):        
    params = []
    num_layers = len(layers)
    keys = random.split(key,len(layers))
    a = jnp.sqrt(0.1)
    #params.append(a)
    for l in range(0,num_layers-1):
        W = xavier_init((layers[l],layers[l+1]),keys[l])
        b = jnp.zeros((layers[l+1],),dtype=np.float32)
        params.append((W,b))
    return params

def xavier_init(size,key):
    in_dim = size[0]
    out_dim = size[1]      
    xavier_stddev = jnp.sqrt(2/(in_dim + out_dim))
    return random.truncated_normal(key,-2,2,shape=(out_dim,in_dim),dtype=np.float32)*xavier_stddev
    
def net_apply(params,X):
    num_layers = len(params)
    #a = params[0]
    for l in range(0,num_layers-1):
        W,b = params[l]
        X = jnp.maximum(0,jnp.add(jnp.dot(X,W.T),b))
    W,b = params[-1]
    Y = jnp.dot(X,W.T)+ b
    Y = jnp.squeeze(Y)
    return Y
    
def loss(params,X,Y):
    predictions = net_apply(params,X)
    return jnp.mean((Y - predictions)**2)

key = random.PRNGKey(1)
layers = [1,40,1]
params = initialize_NN(layers,key)

@jit
def step(i,opt_state,x1,y1):
    p = get_params(opt_state)
    val,g = value_and_grad(loss)(p,x1,y1)
    return val,opt_update(i,g,opt_state)

opt_init,opt_update,get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5,5,100).reshape((100,1)) # (k,1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val,opt_state = step(i,opt_state,xrange_inputs,targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply,params))(xrange_inputs)

plt.plot(xrange_inputs,predictions,label='prediction')
plt.plot(xrange_inputs,targets,label='target')
plt.legend()

my neural network always converges to a constant; it seems to get stuck in a local minimum. Yet the same network works fine in the first version, and I am really confused.

The only differences should be the initialization, the hand-written network code, and how the parameters params are set up. I have tried different initializations and it makes no difference. I don't know whether my handling of params for the optimizer is wrong and that is what prevents convergence.
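
One thing I have not ruled out yet is a shape mismatch inside the loss: the hand-written net_apply squeezes its output, while targets keeps shape (100, 1). Below is a quick diagnostic sketch, assuming the variables from my second code block are already defined:

preds = net_apply(params,xrange_inputs)   # hand-written forward pass; squeeze gives shape (100,)
print(preds.shape, targets.shape)         # (100,) vs (100, 1)
print(((targets - preds)**2).shape)       # broadcasting produces (100, 100) instead of (100,)

If the squared error really broadcasts to a (100, 100) matrix, the mean is no longer a per-point error, which might explain a fit that collapses toward a constant.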

Solution

No effective solution to this problem has been found yet.

