问题描述
我为神经网络编写了代码,但是当我训练我的网络时,它不会产生所需的输出(网络未学习,有时训练时为 NaN 值)。我的反向传播算法有什么问题?下面附上我是如何分别推导出权重和偏置梯度的公式的。完整代码可以在 here 中找到。
public double[][] predict(double[][] input) {
if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
throw new IllegalArgumentException("Prediction Error!");
}
this.activations.set(0,input);
for(int i = 1; i < this.activations.size(); i++) {
this.activations.set(i,this.sigmoid(this.add(this.multiply(this.weights.get(i-1),this.activations.get(i-1)),this.biases.get(i-1))));
}
return this.activations.get(this.n-1);
}
public void train(double[][] input,double[][] target) {
//calculate activations
this.predict(input);
//calculate weight gradients
for(int l = 0; l < this.weightGradients.size(); L++) {
for(int i = 0; i < this.weightGradients.get(l).length; i++) {
for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l,i,j,target);
}
}
}
//calculated bias gradients
for(int l = 0; l < this.biasGradients.size(); L++) {
for(int i = 0; i < this.biasGradients.get(l).length; i++) {
for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
this.biasGradients.get(l)[i][j] = this.gradientOfBias(l,target);
}
}
}
//apply gradient
for(int i = 0; i < this.weights.size(); i++) {
this.weights.set(i,this.subtract(this.weights.get(i),this.weightGradients.get(i)));
}
for(int i = 0; i < this.biases.size(); i++) {
this.biases.set(i,this.subtract(this.biases.get(i),this.biasGradients.get(i)));
}
}
private double gradientOfWeight(int l,int i,int j,double[][] t) { //when referring to A,use l+1 because A[0] is input vector,n-1 because n starts at 1
double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
if((l + 1) < (this.n - 1)) {
double sum = 0.0;
for(int k = 0; k < this.weights.get(l + 1).length; k++) {
sum += this.gradientOfWeight(l + 1,k,t)*this.weights.get(l + 1)[k][i];
}
return ((z * sum) / this.activations.get(l + 1)[i][0]);
} else if((l + 1) == (this.n - 1)) {
return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
}
throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}
解决方法
这个问题所涉及的数学量加上缺乏数据/代码复制,几乎不可能回答“我的 NaN 在哪里”的原始问题。
相反,我建议您将这个问题重新考虑为一个更简单的问题,“我如何知道像 NaN 这样的值在我的代码中来自哪里”。
如果您可以在 IDE 中运行您的代码,它们中的大多数将支持条件断点。即,只要变量达到某个值就会暂停代码的断点。在您的情况下,我建议您在首选 IDE 中运行您的代码,并使用条件断点检测值是否为 NaN。
您可以在这篇 SO post 中阅读更多关于如何设置它的信息,其中 NaN 双重检查的主题在此线程中很好地提到: Eclipse Debugger doesn't stop at conditional breakpoint
另一个后续考虑是考虑您需要将这些断点放在哪里。简短的回答是将它们放在计算双精度值的任何地方,因为这些计算中的任何一个都可能引入 NaN。
为此,我提出以下两个建议:
首先,在您当前计算双精度数的位置放置一个断点,以查看 NaN 是否来自这些计算。那将是这两个变量:
double z = ...
double sum = ...
其次,重构您对 gradientOfWeight 的调用以返回到一个临时变量,然后在这些临时计算上放置一个类似的断点。
所以代替
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l,i,j,target);
你会:
double interrimComputationToListenForNaNon = this.gradientOfWeight(l,target);
this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;
拥有这些中间变量更方便,可以为您提供一种简单的方法来监视计算,而无需以任何显着的方式更改调用。可能有一种更聪明的方法来做到这一点,而无需中间变量,但这种方法似乎最容易监控和解释。
,您看到的 NaN 是由于下溢,您需要使用 BigDecimal 类而不是 double 以获得更高的精度。请参阅这些以更好地理解 bigdecimal class java sample use,BigDecimal API Reference