Java 中的神经网络无法反向传播

问题描述

我为神经网络编写了代码,但是当我训练我的网络时,它不会产生所需的输出(网络未学习,有时训练时为 NaN 值)。我的反向传播算法有什么问题?下面附上我是如何分别推导出权重和偏置梯度的公式的。完整代码可以在 here 中找到。

public double[][] predict(double[][] input) {
    if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
        throw new IllegalArgumentException("Prediction Error!");
    }
    this.activations.set(0,input);
    for(int i = 1; i < this.activations.size(); i++) {
        this.activations.set(i,this.sigmoid(this.add(this.multiply(this.weights.get(i-1),this.activations.get(i-1)),this.biases.get(i-1))));
    }
    return this.activations.get(this.n-1);
}

public void train(double[][] input,double[][] target) {
    //calculate activations
    this.predict(input);
    //calculate weight gradients
    for(int l = 0; l < this.weightGradients.size(); L++) {
        for(int i = 0; i < this.weightGradients.get(l).length; i++) {
            for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
                this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l,i,j,target);
            }
        }
    }
    //calculated bias gradients
    for(int l = 0; l < this.biasGradients.size(); L++) {
        for(int i = 0; i < this.biasGradients.get(l).length; i++) {
            for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
                this.biasGradients.get(l)[i][j] = this.gradientOfBias(l,target);
            }
        }
    }
    //apply gradient
    for(int i = 0; i < this.weights.size(); i++) {
        this.weights.set(i,this.subtract(this.weights.get(i),this.weightGradients.get(i)));
    }
    for(int i = 0; i < this.biases.size(); i++) {
        this.biases.set(i,this.subtract(this.biases.get(i),this.biasGradients.get(i)));
    }
}

private double gradientOfWeight(int l,int i,int j,double[][] t) { //when referring to A,use l+1 because A[0] is input vector,n-1 because n starts at 1
    double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
    if((l + 1) < (this.n - 1)) {
        double sum = 0.0;
        for(int k = 0; k < this.weights.get(l + 1).length; k++) {
            sum += this.gradientOfWeight(l + 1,k,t)*this.weights.get(l + 1)[k][i];
        }
        return ((z * sum) / this.activations.get(l + 1)[i][0]);
    } else if((l + 1) == (this.n - 1)) {
        return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
    }
    throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}

Math to calculate gradient

解决方法

这个问题所涉及的数学量加上缺乏数据/代码复制,几乎不可能回答“我的 NaN 在哪里”的原始问题。

相反,我建议您将这个问题重新考虑为一个更简单的问题,“我如何知道像 NaN 这样的值在我的代码中来自哪里”。

如果您可以在 IDE 中运行您的代码,它们中的大多数将支持条件断点。即,只要变量达到某个值就会暂停代码的断点。在您的情况下,我建议您在首选 IDE 中运行您的代码,并使用条件断点检测值是否为 NaN。

您可以在这篇 SO post 中阅读更多关于如何设置它的信息,其中 NaN 双重检查的主题在此线程中很好地提到: Eclipse Debugger doesn't stop at conditional breakpoint

另一个后续考虑是考虑您需要将这些断​​点放在哪里。简短的回答是将它们放在计算双精度值的任何地方,因为这些计算中的任何一个都可能引入 NaN。

为此,我提出以下两个建议:

首先,在您当前计算双精度数的位置放置一个断点,以查看 NaN 是否来自这些计算。那将是这两个变量:

double z = ...

double sum = ...

其次,重构您对 gradientOfWeight 的调用以返回到一个临时变量,然后在这些临时计算上放置一个类似的断点。

所以代替

this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l,i,j,target);

你会:

double interrimComputationToListenForNaNon = this.gradientOfWeight(l,target);
this.weightGradients.get(l)[i][j] = interrimComputationToListenForNaNon;

拥有这些中间变量更方便,可以为您提供一种简单的方法来监视计算,而无需以任何显着的方式更改调用。可能有一种更聪明的方法来做到这一点,而无需中间变量,但这种方法似乎最容易监控和解释。

,

您看到的 NaN 是由于下溢,您需要使用 BigDecimal 类而不是 double 以获得更高的精度。请参阅这些以更好地理解 bigdecimal class java sample use,BigDecimal API Reference