问题描述
我正在学习 Jax,但我遇到了一个奇怪的问题。 如果我使用如下代码,
import numpy as np
import jax.numpy as jnp
from jax import grad,value_and_grad
from jax import vmap # for auto-vectorizing functions
from functools import partial # for use with vmap
from jax import jit # for compiling functions for speedup
from jax import random # stax initialization uses jax.random
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv,Dense,MaxPool,Relu,Flatten,Logsoftmax # neural network layers
import matplotlib.pyplot as plt # visualization
# Build a small MLP with stax: two hidden Dense layers and a scalar output.
# `net_init` creates the parameters, `net_apply` runs the forward pass.
net_init, net_apply = stax.serial(
    Dense(40),
    Dense(40),
    Dense(1),
)

# Seed JAX's PRNG and initialise parameters for inputs of shape (batch, 1).
rng = random.PRNGKey(0)
in_shape = (-1, 1)
out_shape, params = net_init(rng, in_shape)
def loss(params, X, Y):
    """Mean-squared error between the network's predictions on X and Y."""
    residual = Y - net_apply(params, X)
    return jnp.mean(residual ** 2)
@jit
def step(i, opt_state, x1, y1):
    """One Adam update step.

    Args:
        i: iteration counter (passed to opt_update for the schedule).
        opt_state: current optimizer state.
        x1: input batch.
        y1: target batch.

    Returns:
        (loss value, updated optimizer state).

    Bug fix: the original called ``value_and_grad(loss)(p, y1)``, dropping
    the input batch ``x1`` entirely, so ``loss`` received the targets as X
    and no Y at all.
    """
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)
# NOTE(review): `optimizers` was never imported in this snippet — import it
# here so the script is runnable as shown.
from jax.experimental import optimizers

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(params)

# The training data must exist before the loop uses it (the original
# defined xrange_inputs/targets only after the loop).
xrange_inputs = jnp.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    # `step` expects (i, opt_state, x, y); the original dropped opt_state.
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
losses = vmap(partial(loss, params))(xrange_inputs, targets)  # per-input loss

plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, losses, label='loss')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
神经网络可以很好地逼近函数 cos(x)。
但是如果我自己重写神经网络部分如下
from functools import partial  # for use with vmap

import numpy as np
import jax.numpy as jnp
from jax import grad, jit, random, value_and_grad, vmap
from jax.experimental import optimizers
from jax.tree_util import tree_multimap
import matplotlib.pyplot as plt  # visualization
def initialize_NN(layers, key):
    """Build a parameter list of (W, b) pairs, one per layer transition.

    `layers` gives the width of each layer, e.g. [1, 40, 1]; weights are
    drawn by `xavier_init`, biases start at zero.
    """
    keys = random.split(key, len(layers))
    params = []
    for fan_in, fan_out, k in zip(layers[:-1], layers[1:], keys):
        W = xavier_init((fan_in, fan_out), k)
        b = jnp.zeros((fan_out,), dtype=np.float32)
        params.append((W, b))
    return params
def xavier_init(size, key):
    """Glorot-style initialisation from a truncated normal.

    `size` is (in_dim, out_dim); NOTE the returned matrix is TRANSPOSED,
    shape (out_dim, in_dim) — `net_apply` multiplies by `W.T` accordingly.
    """
    fan_in, fan_out = size
    scale = jnp.sqrt(2.0 / (fan_in + fan_out))
    sample = random.truncated_normal(
        key, -2, 2, shape=(fan_out, fan_in), dtype=np.float32
    )
    return sample * scale
def net_apply(params, X):
    """Forward pass: ReLU hidden layers, then a linear output layer.

    NOTE(review): the final `squeeze` drops the trailing unit dimension,
    so for X of shape (k, 1) the output is shape (k,), not (k, 1) —
    callers comparing against (k, 1) targets must align shapes first.
    """
    *hidden, (W_out, b_out) = params
    for W, b in hidden:
        X = jnp.maximum(0, jnp.dot(X, W.T) + b)
    return jnp.squeeze(jnp.dot(X, W_out.T) + b_out)
def loss(params, X, Y):
    """Mean-squared error between net_apply(params, X) and targets Y.

    Reconstructed: the original line was garbled (missing colon, missing Y
    parameter, undefined `predictions`).

    Bug fix (the reported "converges to a constant" symptom): net_apply
    squeezes its output to shape (k,), while Y is typically (k, 1).
    `Y - predictions` would then broadcast to a (k, k) matrix, and taking
    the mean of that drives the network toward the mean of Y — a constant.
    Squeezing Y aligns the shapes so the loss is a true per-sample MSE.
    """
    predictions = net_apply(params, X)
    return jnp.mean((jnp.squeeze(Y) - predictions) ** 2)
# Reconstructed training script (the original lines were garbled/truncated);
# mirrors the working stax version above.
key = random.PRNGKey(1)
layers = [1, 40, 1]
params = initialize_NN(layers, key)

# Optimiser setup (originally fused into the truncated `step` definition).
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)

@jit
def step(i, opt_state, x1, y1):
    """One Adam step: loss/grad at current params, then update the state."""
    p = get_params(opt_state)
    val, g = value_and_grad(loss)(p, x1, y1)
    return val, opt_update(i, g, opt_state)

opt_state = opt_init(params)

xrange_inputs = jnp.linspace(-5, 5, 100).reshape((100, 1))  # (k, 1)
targets = jnp.cos(xrange_inputs)

val_his = []
for i in range(1000):
    val, opt_state = step(i, opt_state, xrange_inputs, targets)
    val_his.append(val)
params = get_params(opt_state)
val_his = jnp.array(val_his)

predictions = vmap(partial(net_apply, params))(xrange_inputs)
plt.plot(xrange_inputs, predictions, label='prediction')
plt.plot(xrange_inputs, targets, label='target')
plt.legend()
我的神经网络总是会收敛到一个常数,它似乎被一个局部最小值所困。但是同样的神经网络和第一部分一样工作得很好。我真的很困惑。
唯一的区别应该是初始化、神经网络部分和参数 params
的设置。我尝试了不同的初始化,没有任何区别。不知道是不是因为优化params
的设置不对,导致无法收敛。
解决方法
问题的根源在自定义 net_apply 末尾的 jnp.squeeze:它把网络输出压成形状 (k,),而 targets 的形状是 (k, 1)。于是在 loss 中 (Y - predictions) 会按广播规则扩展成一个 (k, k) 的矩阵,对整个矩阵取均值后,梯度只会把网络推向 Y 的均值——也就是收敛到一个常数。把两者形状对齐(例如在 loss 中对 Y 也做 squeeze,或去掉 net_apply 里的 squeeze 并保持输出为 (k, 1))即可正常收敛。此外,第二段代码还有若干抄录/粘贴时被截断的行(step 定义、linspace、训练循环调用),需按第一段的结构补全。
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)