Weka 中的单实例分类MultilayerPerceptron

问题描述

我使用 weka.jar 版本 3.6.10 训练并创建了一个 MultilayerPerceptron 模型。我将模型文件保存到我的计算机上,现在我想用它来对 Java 代码中的单个实例进行分类。我想获得对属性“class”的预测。我找到了答案 here 并且我将值更改为我需要的值。我做的是以下内容

import weka.classifiers.Classifier;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.SerializationHelper;
public class JavaApplication {

    public static void main(String[] args) {
        JavaApplication q = new JavaApplication();
        double result = q.classify(-1.18,12.76,1.7297841);
        System.out.println(result);
    }

    private Instance inst_co;

    public double classify(double x,double y,double z)  {

        // Create attributes to be used with classifiers
        // Test the model
        double result = -1;
        try {

            FastVector attributeList = new FastVector();

            Attribute x_acc= new Attribute("x_acc");
            Attribute y_acc= new Attribute("y_acc");
            Attribute z_acc= new Attribute("z_acc");

            FastVector classVal = new FastVector();
            classVal.addElement("Walking");
            classVal.addElement("Jogging");
            classVal.addElement("Downstairs");
            classVal.addElement("Sitting");
            classVal.addElement("Upstairs");


            attributeList.addElement(x_acc);
            attributeList.addElement(y_acc);
            attributeList.addElement(z_acc);
            attributeList.addElement(new Attribute("@@class@@",classVal));

            Instances data = new Instances("TestInstances",attributeList,0);


            // Create instances for each pollutant with attribute values latitude,// longitude and pollutant itself
            inst_co = new SparseInstance(data.numAttributes());
            data.add(inst_co);

            // Set instance's values for the attributes "latitude","longitude",and
            // "pollutant concentration"
            inst_co.setValue(x_acc,x);
            inst_co.setValue(y_acc,y);
            inst_co.setValue(z_acc,z);
            // inst_co.setMissing(cluster);
            

            // load classifier from file
           Classifier cls_co = (MultilayerPerceptron) SerializationHelper
                   .read("/Users/ALL-TECH/Desktop/Sensors application/FewDataGenerated/model.model");

            result = cls_co.classifyInstance(inst_co);
        } catch (Exception e) {
            // Todo Auto-generated catch block
            e.printstacktrace();
        }
        return result;
    }
}

我的 arff 文件如下所示:

@RELATION fewOfDataCsv

@ATTRIBUTE x_acc  NUMERIC
@ATTRIBUTE y_acc  NUMERIC
@ATTRIBUTE z_acc  NUMERIC
@ATTRIBUTE class  {Upstairs,Downstairs,Walking,Jogging,Sitting}

@DATA
-1.18,1.7297841,Upstairs
0.93,10.99,0.08172209,Upstairs
0.08,11.35,0.46309182,Upstairs
1.88,9.47,3.405087,Walking
0.89,9.38,3.3778462,Walking
1.38,11.54,3.336985,Walking
2.83,3.68,-3.255263,Jogging
-1.8,2.45,7.082581,Jogging
16.63,9.89,-1.56634,Jogging
12.53,1.88,-6.3198414,Jogging
7.46,2.3,6.4,Sitting
7.5,6.44,Sitting
7.46,6.47,Sitting
-1.23,8.28,0.040861044,Downstairs
-1.92,6.28,1.1441092,Downstairs
-1.73,5.75,2.152015,Downstairs

结果(我真的不知道那个数字来自哪里):

run:
3.0
BUILD SUCCESSFUL (total time: 1 second)

我的代码中缺少什么?如果有人能帮忙,我将不胜感激。

解决方法

在您的 main 方法中,您正在调用 JavaApplication.classify 方法,该方法从分类器(双精度)返回分类。反过来,您通过 System.out.println 输出到 stdout。为什么你对数字来自哪里感到困惑?

Classifier.classifyInstance 方法返回一个双精度值。在回归算法的情况下,这是预测的数值,在分类算法的情况下,这是预测类别标签的基于 0 的索引。在您的情况下,数字 3.0 对应于 Jogging

,

这里重写了您的代码以解决以下缺点:

  • 每次进行预测时都要重新创建数据集结构
  • 每次进行预测时都加载模型
  • 使用 weka.core.SparseInstance 没有意义,因为您只有三个属性,而且您似乎总是提供它们

从 Weka Explorer 或通过命令行保存模型时,Weka 不仅将 weka.classifiers.Classifier 对象存储在该文件中,还存储训练数据的标头(也称为结构)。您可以简单地将该标头用于您正在构建的 weka.core.Instance 对象的结构。

以下类希望获取模型的路径作为第一个参数,然后在构造函数中从该文件加载模型和标头。

我还为标题添加了一个 get 方法,以允许将分类标签索引(Classifier.classifyInstance 在名义类属性的情况下返回)转换为实际的类标签字符串。

public class JavaApplication {

  private Instances header;

  private Classifier model;

  public JavaApplication(String modelPath) throws Exception {
    Object[] objects = SerializationHelper.readAll(modelPath);
    model = (Classifier) objects[0];
    header = (Instances) objects[1];
  }

  public Instances getHeader() {
    return header;
  }

  public int classify(double x,double y,double z)  {
    int result = -1;
    try {
      double[] values = new double[header.numAttributes()];
      values[0] = x;
      values[1] = y;
      values[2] = z;
      Instance inst = new DenseInstance(1.0,values);
      inst.setDataset(header);
      result = (int) model.classifyInstance(inst);
    }
    catch (Exception e) {
      System.err.println("Failed to classify instance!");
      e.printStackTrace();
    }
    return result;
  }

  public static void main(String[] args) throws Exception {
    JavaApplication q = new JavaApplication(args[0]);
    int labelIndex = q.classify(-1.18,12.76,1.7297841);
    System.out.println(labelIndex);
    System.out.println(q.getHeader().classAttribute().value(labelIndex));
  }
}