Problem description
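I built a model for a specific training dataset using MobileNet. When evaluated on the test set, the Keras model (model.h5) reached about 92% accuracy. I then converted the model to TFLite with the following code: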
import tensorflow as tf

model = tf.keras.models.load_model('modelos TensorflowLite/MobileNet.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Write the converted model to disk and close the file afterwards.
with open("MobileNet.tflite", "wb") as f:
    f.write(tflite_model)
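When I run the TFLite model against the same test set using the TFLite interpreter in Python, the accuracy is very similar to that of the Keras model, close to 92%. The code used in the interpreter for a single inference: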
import cv2
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="MobileNet.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Read and preprocess the image (cv2.imread returns channels in BGR order)
image_path = "image.jpg"
img = cv2.imread(image_path)
new_img = cv2.resize(img, (300, 300))
new_img = new_img.astype(np.float32)
new_img /= 255.
# input_details[0]['index'] = the index of the tensor that accepts the input
interpreter.set_tensor(input_details[0]['index'], np.expand_dims(new_img, axis=0))
# Run the interpreter's prediction
interpreter.invoke()
# output_details[0]['index'] = the index of the tensor that provides the output
output_data = interpreter.get_tensor(output_details[0]['index'])
print("For file {}, the output is {}".format(image_path, output_data))
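The problem appears when I run the test set in Android Studio. With the same converted TFLite model and the same test set, the accuracy is 39%. Note that the model is not quantized. I compared the results obtained for each of the 3 classes on a single image. For this image, the class is classified correctly by the Keras model and by the TFLite model in the Python interpreter, but not on Android: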
Probability | Keras model (.h5) | TFLite Python interpreter | TFLite Android |
---|---|---|---|
Probability of the correct class | 9.6e-01 | 9.6e-01 | 3.2e-6 |
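My problem is not low accuracy after converting the .h5 model to .tflite. My problem is that the TFLite model works well in the Python interpreter but performs very poorly when I use it in Android Studio.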
private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
    // Loads bitmap into a TensorImage.
    inputimageBuffer.load(bitmap);
    int noOfRotations = sensorOrientation / 90;
    int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
            .add(new Rot90Op(noOfRotations))
            // NormalizeOp computes (value - IMAGE_MEAN) / IMAGE_STD.
            .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
            .build();
    return imageProcessor.process(inputimageBuffer);
}
The code that runs the prediction:
inputimageBuffer = loadImage(bitmap,sensorOrientation);
tensorClassifier.run(inputimageBuffer.getBuffer(),probabilityImageBuffer.getBuffer().rewind());
import android.app.Activity;
import android.graphics.Bitmap;
import android.widget.Toast;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class ImageClassifier {
    // Non-Quantized
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 1.0f;
    private static final float IMAGE_STD = 127.5f;
    private static final float IMAGE_MEAN = 127.5f;
    private static final int MAX_SIZE = 3;
    /**
     * Image size along the x axis.
     */
    private final int imageResizeX;
    /**
     * Image size along the y axis.
     */
    private final int imageResizeY;
    /**
     * Labels corresponding to the output of the vision model.
     */
    private final List<String> labels;
    /**
     * An instance of the driver class to run model inference with TensorFlow Lite.
     */
    private final Interpreter tensorClassifier;
    /**
     * Input image TensorBuffer.
     */
    private TensorImage inputimageBuffer;
    /**
     * Output probability TensorBuffer.
     */
    private final TensorBuffer probabilityImageBuffer;
    /**
     * Processor to apply post processing of the output probability.
     */
    private final TensorProcessor probabilityProcessor;
    /**
     * Creates a classifier.
     *
     * @param activity the current activity
     * @throws IOException if the model or label file cannot be loaded
     */
    public ImageClassifier(Activity activity) throws IOException {
        // The loaded TensorFlow Lite model.
        MappedByteBuffer classifierModel = FileUtil.loadMappedFile(activity, "MobileNet.tflite");
        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity, "labels_mobilenet.txt");
        tensorClassifier = new Interpreter(classifierModel, null);
        // Reads type and shape of input and output tensors, respectively.
        int imageTensorIndex = 0; // input
        int probabilityTensorIndex = 0; // output
        int[] inputimageShape = tensorClassifier.getInputTensor(imageTensorIndex).shape();
        DataType inputDataType = tensorClassifier.getInputTensor(imageTensorIndex).dataType();
        int[] outputimageShape = tensorClassifier.getOutputTensor(probabilityTensorIndex).shape();
        DataType outputDataType = tensorClassifier.getOutputTensor(probabilityTensorIndex).dataType();
        imageResizeX = inputimageShape[2];
        imageResizeY = inputimageShape[1];
        // Creates the input tensor.
        inputimageBuffer = new TensorImage(inputDataType);
        // Creates the output tensor and its processor.
        probabilityImageBuffer = TensorBuffer.createFixedSize(outputimageShape, outputDataType);
        // Creates the post processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder()
                .add(new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD))
                .build();
    }
    /**
     * Runs the inference and returns the classification results.
     *
     * @param bitmap            the bitmap of the image
     * @param sensorOrientation orientation of the camera
     * @return classification results
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap, final int sensorOrientation) {
        // List with the label and probability of each class
        List<Recognition> recognitions = new ArrayList<>();
        inputimageBuffer = loadImage(bitmap, sensorOrientation);
        tensorClassifier.run(inputimageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());
        // Gets the map of label and probability.
        Map<String, Float> labelledProbability = new TensorLabel(labels,
                probabilityProcessor.process(probabilityImageBuffer)).getMapWithFloatValue();
        int idLabel = 0;
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            recognitions.add(new Recognition(String.valueOf(idLabel), entry.getValue()));
            idLabel++;
        }
        // List with the probability of each class
        List<Float> probabilidades = new ArrayList<>();
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            probabilidades.add(entry.getValue());
        }
        // Sorts by descending confidence (see Recognition.compareTo).
        Collections.sort(recognitions);
        return recognitions.subList(0, MAX_SIZE);
    }
    /**
     * Loads the image into the tensor input buffer and applies the pre processing steps.
     *
     * @param bitmap            the bitmap to be loaded
     * @param sensorOrientation the sensor orientation
     * @return the image loaded tensor input buffer
     */
    private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputimageBuffer.load(bitmap);
        int noOfRotations = sensorOrientation / 90;
        int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
        // pre processing steps are applied here
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
                .add(new Rot90Op(noOfRotations))
                .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
                .build();
        return imageProcessor.process(inputimageBuffer);
    }
    /**
     * A result returned by a Classifier describing what was recognized.
     */
    public class Recognition implements Comparable<Recognition> {
        /**
         * Display name for the recognition.
         */
        private String name;
        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private float confidence;

        public Recognition() {
        }

        public Recognition(String name, float confidence) {
            this.name = name;
            this.confidence = confidence;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public float getConfidence() {
            return confidence;
        }

        public void setConfidence(float confidence) {
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return "Recognition{" +
                    "name='" + name + '\'' +
                    ",confidence=" + confidence +
                    '}';
        }

        @Override
        public int compareTo(Recognition o) {
            // Arguments are swapped so that higher confidence sorts first.
            return Float.compare(o.confidence, this.confidence);
        }
    }
}
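One difference visible in the code above is the input scaling: the Python snippet divides pixel values by 255, while the Android pipeline applies NormalizeOp with IMAGE_MEAN and IMAGE_STD both set to 127.5. The sketch below only illustrates which value ranges those two formulas produce and what reversing the channel order does. It is a minimal illustration, not a confirmed diagnosis of the 39% accuracy. The function names are hypothetical, and the idea that the Android path feeds RGB while cv2.imread feeds BGR is an assumption worth checking against the preprocessing used during training.

```python
# Hypothetical helpers; the constants mirror the values shown in the question.
IMAGE_MEAN = 127.5  # same value as IMAGE_MEAN in the Java class
IMAGE_STD = 127.5   # same value as IMAGE_STD in the Java class


def scale_like_python_script(pixel):
    """Mirror `new_img /= 255.` from the Python snippet: maps 0..255 to 0..1."""
    return pixel / 255.0


def scale_like_normalize_op(pixel, mean=IMAGE_MEAN, std=IMAGE_STD):
    """Mirror the (value - mean) / std formula applied by NormalizeOp."""
    return (pixel - mean) / std


def reverse_channels(pixel_triplet):
    """Reverse channel order, e.g. BGR (as read by cv2.imread) to RGB."""
    return tuple(reversed(pixel_triplet))


if __name__ == "__main__":
    # Compare what the model would receive for the darkest, a middle,
    # and the brightest 8-bit pixel value under each scaling.
    for value in (0, 128, 255):
        print(value, scale_like_python_script(value), scale_like_normalize_op(value))
    # A pure-blue pixel in BGR order and the same pixel expressed in RGB order.
    print(reverse_channels((255, 0, 0)))
```

Under these assumptions, the same pixel reaches the model as a value in [0, 1] in the Python script and in [-1, 1] on Android, so comparing both against the training-time preprocessing may be a useful first check.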