问题描述
我使用 weka.jar 版本 3.6.10 训练并创建了一个 MultilayerPerceptron 模型。我将模型文件保存到我的计算机上,现在我想用它来对 Java 代码中的单个实例进行分类。我想获得对属性“class”的预测。我找到了答案 here 并且我将值更改为我需要的值。我做的是以下内容:
import weka.classifiers.Classifier;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.SerializationHelper;
public class JavaApplication {
public static void main(String[] args) {
JavaApplication q = new JavaApplication();
double result = q.classify(-1.18,12.76,1.7297841);
System.out.println(result);
}
private Instance inst_co;
public double classify(double x,double y,double z) {
// Create attributes to be used with classifiers
// Test the model
double result = -1;
try {
FastVector attributeList = new FastVector();
Attribute x_acc= new Attribute("x_acc");
Attribute y_acc= new Attribute("y_acc");
Attribute z_acc= new Attribute("z_acc");
FastVector classVal = new FastVector();
classVal.addElement("Walking");
classVal.addElement("Jogging");
classVal.addElement("Downstairs");
classVal.addElement("Sitting");
classVal.addElement("Upstairs");
attributeList.addElement(x_acc);
attributeList.addElement(y_acc);
attributeList.addElement(z_acc);
attributeList.addElement(new Attribute("@@class@@",classVal));
Instances data = new Instances("TestInstances",attributeList,0);
// Create instances for each pollutant with attribute values latitude,// longitude and pollutant itself
inst_co = new SparseInstance(data.numAttributes());
data.add(inst_co);
// Set instance's values for the attributes "latitude","longitude",and
// "pollutant concentration"
inst_co.setValue(x_acc,x);
inst_co.setValue(y_acc,y);
inst_co.setValue(z_acc,z);
// inst_co.setMissing(cluster);
// load classifier from file
Classifier cls_co = (MultilayerPerceptron) SerializationHelper
.read("/Users/ALL-TECH/Desktop/Sensors application/FewDataGenerated/model.model");
result = cls_co.classifyInstance(inst_co);
} catch (Exception e) {
// Todo Auto-generated catch block
e.printstacktrace();
}
return result;
}
}
我的 arff 文件如下所示:
@RELATION fewOfDataCsv
@ATTRIBUTE x_acc NUMERIC
@ATTRIBUTE y_acc NUMERIC
@ATTRIBUTE z_acc NUMERIC
@ATTRIBUTE class {Upstairs,Downstairs,Walking,Jogging,Sitting}
@DATA
-1.18,1.7297841,Upstairs
0.93,10.99,0.08172209,Upstairs
0.08,11.35,0.46309182,Upstairs
1.88,9.47,3.405087,Walking
0.89,9.38,3.3778462,Walking
1.38,11.54,3.336985,Walking
2.83,3.68,-3.255263,Jogging
-1.8,2.45,7.082581,Jogging
16.63,9.89,-1.56634,Jogging
12.53,1.88,-6.3198414,Jogging
7.46,2.3,6.4,Sitting
7.5,6.44,Sitting
7.46,6.47,Sitting
-1.23,8.28,0.040861044,Downstairs
-1.92,6.28,1.1441092,Downstairs
-1.73,5.75,2.152015,Downstairs
结果(我真的不知道那个数字来自哪里):
run:
3.0
BUILD SUCCESSFUL (total time: 1 second)
我的代码中缺少什么?如果有人能帮忙,我将不胜感激。
解决方法
在您的 main
方法中,您正在调用 JavaApplication.classify
方法,该方法从分类器(双精度)返回分类。反过来,您通过 System.out.println
输出到 stdout。为什么你对数字来自哪里感到困惑?
Classifier.classifyInstance 方法返回一个双精度值。在回归算法的情况下,这是预测的数值,在分类算法的情况下,这是预测类别标签的基于 0 的索引。在您的情况下,数字 3.0
对应于 Jogging
。
这里重写了您的代码以解决以下缺点:
- 每次进行预测时都要重新创建数据集结构
- 每次进行预测时都加载模型
- 使用
weka.core.SparseInstance
没有意义,因为您只有三个属性,而且您似乎总是提供它们
从 Weka Explorer 或通过命令行保存模型时,Weka 不仅将 weka.classifiers.Classifier
对象存储在该文件中,还存储训练数据的标头(也称为结构)。您可以简单地将该标头用于您正在构建的 weka.core.Instance
对象的结构。
以下类希望获取模型的路径作为第一个参数,然后在构造函数中从该文件加载模型和标头。
我还为标题添加了一个 get 方法,以允许将分类标签索引(Classifier.classifyInstance
在名义类属性的情况下返回)转换为实际的类标签字符串。
public class JavaApplication {
private Instances header;
private Classifier model;
public JavaApplication(String modelPath) throws Exception {
Object[] objects = SerializationHelper.readAll(modelPath);
model = (Classifier) objects[0];
header = (Instances) objects[1];
}
public Instances getHeader() {
return header;
}
public int classify(double x,double y,double z) {
int result = -1;
try {
double[] values = new double[header.numAttributes()];
values[0] = x;
values[1] = y;
values[2] = z;
Instance inst = new DenseInstance(1.0,values);
inst.setDataset(header);
result = (int) model.classifyInstance(inst);
}
catch (Exception e) {
System.err.println("Failed to classify instance!");
e.printStackTrace();
}
return result;
}
public static void main(String[] args) throws Exception {
JavaApplication q = new JavaApplication(args[0]);
int labelIndex = q.classify(-1.18,12.76,1.7297841);
System.out.println(labelIndex);
System.out.println(q.getHeader().classAttribute().value(labelIndex));
}
}