Gradient descent for linear regression does not work on the cpu.arff dataset

Problem description

My gradient descent implementation for linear regression is shown below:

private static void byGradientDescent(double[][] datax,double[][] dataY){
    double alpha = 0.05;        
    int m = datax[0].length;    // number of features (variables)
    int n = datax.length;       // number of samples
    System.out.println(n+"\t"+m);
    double[] thetas = new double[m + 1];        //thetas[m] is the intercept
    for(int i = 0;i<thetas.length;i++)          //initialize
        thetas[i] = 0.5;
    double[] derivatives = new double[m + 1];    
    
    boolean flag = false;
    double lastRSS = 0;
    do {
        for(int i = 0;i<derivatives.length;i++)
            derivatives[i] = 0;         
        double RSS = 0;
        for (int i = 0; i < n; i++) {   //calculate derivatives
            double diff =  thetas[m] ;  //difference
            for(int j = 0;j<m;j++)
                diff +=  thetas[j] * datax[i][j];
            diff = diff - dataY[i][0];
            RSS += diff*diff;
            
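            derivatives[m] += diff / n;         // intercept term: average residual
            for(int j = 0;j<m;j++){
                // accumulate over all samples with +=; a plain "=" keeps only the last sample's term
                derivatives[j] += diff * datax[i][j] / n;
            }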
        }           
        for(int i = 0;i<thetas.length;i++)  // update thetas
            thetas[i] = thetas[i] - (alpha * derivatives[i]);           
        System.out.println(lastRSS - RSS);
        lastRSS = RSS;
        System.out.println(Arrays.toString(thetas));
        flag = false;
        for(double derivative : derivatives)    // termination condition
            flag = flag || (Math.abs(derivative)>0.01);
    } while (flag);
}
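For reference, the loop appears to be batch gradient descent on the mean-squared-error cost below (reconstructed from the code, writing $x_{ij}$ for `datax[i][j]`, $y_i$ for `dataY[i][0]`, and $\theta_m$ for the intercept `thetas[m]`; the printed `RSS` equals $2nJ$). Every partial derivative is a sum over all $n$ samples, so `derivatives[j]` has to be accumulated across the `i` loop. The last line is the standard step-size condition for gradient descent on a quadratic cost, where $\tilde X$ denotes `datax` with a column of ones appended:

$$J(\theta) = \frac{1}{2n}\sum_{i=0}^{n-1} r_i^2, \qquad r_i = \theta_m + \sum_{j=0}^{m-1} \theta_j x_{ij} - y_i$$

$$\frac{\partial J}{\partial \theta_j} = \frac{1}{n}\sum_{i=0}^{n-1} r_i\, x_{ij} \quad (0 \le j < m), \qquad \frac{\partial J}{\partial \theta_m} = \frac{1}{n}\sum_{i=0}^{n-1} r_i$$

$$\theta \leftarrow \theta - \alpha\,\nabla J(\theta)\ \text{converges when}\ 0 < \alpha < \frac{2}{\lambda_{\max}(H)}, \qquad H = \frac{1}{n}\tilde X^{\top}\tilde X, \qquad \lambda_{\max}(H) \ge \frac{1}{n}\sum_{i=0}^{n-1} x_{ij}^2\ \text{for every } j$$

For $\alpha > 2/\lambda_{\max}(H)$ the iterates diverge from a generic starting point. Since $\lambda_{\max}(H)$ is at least the mean squared value of any single column, one feature with values in the thousands or more forces $\alpha$ to be tiny, and a tiny $\alpha$ in turn makes progress along the small-scale directions extremely slow.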

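However, it only works on very small examples and fails on Weka's cpu.arff dataset. I have tried different values of alpha, but it still does not work. For example, with alpha set to 0.00000005 the loop never terminates, and with alpha set to 0.0005 it returns [Infinity,Infinity,NaN,NaN,NaN,NaN,Infinity,Infinity]. I am not sure whether my implementation is wrong or I am using it incorrectly.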
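A common remedy for badly scaled features is to standardize each column to zero mean and unit variance before running gradient descent, then convert the coefficients back to the original units. The sketch below assumes that approach (z-score scaling is a swapped-in technique, not something taken from the question); the class `StandardizedGradientDescent`, its methods `columnStats` and `fit`, the learning rate, tolerance, iteration cap, and the synthetic data in `main` are all hypothetical. It accumulates each partial derivative over every sample and caps the number of iterations so that a poor `alpha` cannot loop forever.

```java
import java.util.Arrays;

public class StandardizedGradientDescent {

    /** Column means and population standard deviations; a constant column gets std 1 to avoid dividing by zero. */
    static double[][] columnStats(double[][] x) {
        int n = x.length, m = x[0].length;
        double[] mean = new double[m];
        double[] std = new double[m];
        for (double[] row : x)
            for (int j = 0; j < m; j++) mean[j] += row[j] / n;
        for (double[] row : x)
            for (int j = 0; j < m; j++) std[j] += (row[j] - mean[j]) * (row[j] - mean[j]) / n;
        for (int j = 0; j < m; j++) {
            std[j] = Math.sqrt(std[j]);
            if (std[j] == 0) std[j] = 1.0;
        }
        return new double[][] {mean, std};
    }

    /**
     * Batch gradient descent on z-scored features (hypothetical helper).
     * Returns coefficients in the ORIGINAL feature units, intercept stored last (like thetas[m]).
     */
    static double[] fit(double[][] x, double[] y, double alpha, int maxIter, double tol) {
        int n = x.length, m = x[0].length;
        double[][] stats = columnStats(x);
        double[] mean = stats[0], std = stats[1];
        double[] theta = new double[m + 1];   // parameters for the standardized features
        double[] grad = new double[m + 1];
        for (int iter = 0; iter < maxIter; iter++) {
            Arrays.fill(grad, 0.0);
            for (int i = 0; i < n; i++) {
                double diff = theta[m];       // residual r_i, starting from the intercept
                for (int j = 0; j < m; j++) diff += theta[j] * (x[i][j] - mean[j]) / std[j];
                diff -= y[i];
                grad[m] += diff / n;
                for (int j = 0; j < m; j++) grad[j] += diff * (x[i][j] - mean[j]) / std[j] / n;
            }
            double maxAbs = 0;
            for (int k = 0; k <= m; k++) {
                theta[k] -= alpha * grad[k];
                maxAbs = Math.max(maxAbs, Math.abs(grad[k]));
            }
            if (maxAbs < tol) break;          // stop once every partial derivative is tiny
        }
        // Undo the scaling: b' + sum t'_j (x_j - mean_j) / std_j = (b' - sum t'_j mean_j / std_j) + sum (t'_j / std_j) x_j
        double[] out = new double[m + 1];
        out[m] = theta[m];
        for (int j = 0; j < m; j++) {
            out[j] = theta[j] / std[j];
            out[m] -= theta[j] * mean[j] / std[j];
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical data: y = 3 + 0.002 * x0 + 4 * x1, where x0 is on a much larger scale than x1
        double[][] x = {{1000, 1}, {20000, 2}, {64000, 0}, {8000, 5}, {32000, 3}};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = 3 + 0.002 * x[i][0] + 4 * x[i][1];
        double[] theta = fit(x, y, 0.1, 100_000, 1e-9);
        System.out.println(Arrays.toString(theta));   // should be close to [0.002, 4.0, 3.0]
    }
}
```

On standardized columns the largest eigenvalue of the cost's Hessian is at most the number of features (a trace bound), so `alpha = 0.1` is stable whenever there are fewer than 20 features. To reuse the question's arrays, `dataY[i][0]` would first have to be copied into a flat `double[]`.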

Below is the code I use to run it on the dataset:

import java.util.Arrays;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class LinearRegressionByLS {
    // byGradientDescent(...) shown above is assumed to be a static method of this class.

    public static void main(String[] args) throws Exception {
        // Small sanity-check data set: 4 samples, 2 features each
        double[][] X = {{5, 10}, {1, 1}, {9, 2}, {3, 12}};
        double[][] Y = {{1}, {2}, {3}, {5}};
        byGradientDescent(X, Y);

        DataSource source = new DataSource("D:\\Program Files\\Weka-3-8-4\\data\\cpu.arff");
        Instances instances = source.getDataSet();
        instances.setClassIndex(instances.numAttributes() - 1);   // last attribute is the target

        int n = instances.numInstances();
        int m = instances.numAttributes() - 1;
        double[][] datax = new double[n][m];
        double[][] dataY = new double[n][1];
        for (int i = 0; i < n; i++) {
            double[] values = instances.instance(i).toDoubleArray();
            for (int j = 0; j < m; j++)
                datax[i][j] = values[j];
            dataY[i][0] = values[m];
        }

        byGradientDescent(datax, dataY);
    }
}
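Because divergent output alone does not show whether the update rule or the step size is at fault, one option is to compare against the closed-form least-squares solution from the normal equations $(\tilde X^{\top}\tilde X)\theta = \tilde X^{\top}y$. The sketch below is a hypothetical helper (`NormalEquationCheck.solve`), not part of the question's code or of Weka; it takes features as `double[][]` and targets as a flat `double[]`, and the data in `main` are made up.

```java
import java.util.Arrays;

public class NormalEquationCheck {

    /**
     * Least squares via the normal equations (A^T A) theta = A^T y, with A = [x | column of ones].
     * Solved by Gaussian elimination with partial pivoting; intercept returned last.
     */
    static double[] solve(double[][] x, double[] y) {
        int n = x.length, m = x[0].length, p = m + 1;
        double[][] a = new double[p][p + 1];          // augmented matrix [A^T A | A^T y]
        for (int i = 0; i < n; i++) {
            double[] row = new double[p];
            for (int j = 0; j < m; j++) row[j] = x[i][j];
            row[m] = 1.0;                              // intercept column
            for (int r = 0; r < p; r++) {
                for (int c = 0; c < p; c++) a[r][c] += row[r] * row[c];
                a[r][p] += row[r] * y[i];
            }
        }
        for (int col = 0; col < p; col++) {            // forward elimination with partial pivoting
            int piv = col;
            for (int r = col + 1; r < p; r++)
                if (Math.abs(a[r][col]) > Math.abs(a[piv][col])) piv = r;
            double[] tmp = a[col]; a[col] = a[piv]; a[piv] = tmp;
            for (int r = col + 1; r < p; r++) {
                double f = a[r][col] / a[col][col];
                for (int c = col; c <= p; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[] theta = new double[p];
        for (int r = p - 1; r >= 0; r--) {             // back substitution
            double s = a[r][p];
            for (int c = r + 1; c < p; c++) s -= a[r][c] * theta[c];
            theta[r] = s / a[r][r];
        }
        return theta;
    }

    public static void main(String[] args) {
        // Hypothetical data generated from y = 1 + 2 * x0 - x1
        double[][] x = {{1, 2}, {2, 1}, {3, 5}, {4, 3}};
        double[] y = {1, 4, 2, 6};
        System.out.println(Arrays.toString(solve(x, y)));   // close to [2.0, -1.0, 1.0]
    }
}
```

The sketch has no singularity check, and forming $\tilde X^{\top}\tilde X$ squares the condition number, so on badly scaled data it is best treated as a rough cross-check (standardizing first, or using a QR-based solver, is more robust).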
