Why does my tflite model predict well with the Python interpreter but poorly when deployed in Android Studio?

Problem description

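I built a model with MobileNet for a specific training dataset. When I evaluated it on the test set, the Keras model (model.h5) reached about 92% accuracy. I then converted the model to tflite with the following code: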

import tensorflow as tf

model = tf.keras.models.load_model('modelos TensorflowLite/MobileNet.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("MobileNet.tflite", "wb") as f:
    f.write(tflite_model)

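When I run the tflite model with the tflite interpreter in Python on the same test set, the accuracy is very close to that of the Keras model, about 92%. The code used for a single inference with the interpreter: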

import cv2
import numpy as np
import tensorflow as tf
from pathlib import Path

interpreter = tf.lite.Interpreter(model_path="MobileNet.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.allocate_tensors()

# Read and preprocess the image
file = Path("image.jpg")
img = cv2.imread(str(file))  # OpenCV returns pixels in BGR order
new_img = cv2.resize(img, (300, 300))
new_img = new_img.astype(np.float32)
new_img /= 255.

# input_details[0]['index'] = the index of the tensor that accepts the input
# np.expand_dims adds the batch dimension: (300, 300, 3) -> (1, 300, 300, 3)
interpreter.set_tensor(input_details[0]['index'], np.expand_dims(new_img, axis=0))

# Run the interpreter's prediction
interpreter.invoke()

# output_details[0]['index'] = the index of the tensor that provides the output
output_data = interpreter.get_tensor(output_details[0]['index'])

print("For file {}, the output is {}".format(file.stem, output_data))
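The Python path reads pixels with `cv2.imread`, which returns them in BGR order, whereas a `TensorImage` loaded from an Android `Bitmap` holds RGB. The question does not say which order the training pipeline used, but both inference paths need to agree with it. A minimal NumPy sketch of the Python-side preprocessing with an optional BGR-to-RGB swap; `preprocess_like_python` and its `to_rgb` flag are hypothetical names, and the image is assumed to be already resized:

```python
import numpy as np

def preprocess_like_python(img_bgr: np.ndarray, to_rgb: bool = False) -> np.ndarray:
    """Mimic the Python inference code: float32, scaled to [0, 1], batch dim added.

    img_bgr is assumed to be an already-resized HxWx3 uint8 array as returned
    by cv2.imread (channel order B, G, R). to_rgb is a hypothetical switch for
    the case where the model was trained on RGB images.
    """
    img = img_bgr[..., ::-1] if to_rgb else img_bgr  # reverse the channel axis
    x = img.astype(np.float32) / 255.0
    return np.expand_dims(x, axis=0)  # (H, W, 3) -> (1, H, W, 3)

if __name__ == "__main__":
    # One pure-blue pixel in BGR order: B=255, G=0, R=0
    pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)
    print(preprocess_like_python(pixel)[0, 0, 0])               # [1. 0. 0.]
    print(preprocess_like_python(pixel, to_rgb=True)[0, 0, 0])  # [0. 0. 1.]
```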

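The problem appears when I evaluate the test set in Android Studio. Using the same model converted to tflite, the accuracy on the same test set is 39%. Note that the model is not quantized. I compared the results for a single image from each of the 3 classes. For this image, the class is classified correctly by both the Keras model and the tflite model in Python, but not on Android: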

Metric | Keras model (.h5) | tflite (Python interpreter) | tflite (Android)
Probability of the correct class | 9.6e-01 | 9.6e-01 | 3.2e-6

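My problem is not a loss of accuracy when converting the .h5 model to .tflite. My problem is that the tflite model works fine in the Python interpreter but performs very poorly when I deploy it in Android Studio.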
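Since the same `.tflite` file scores about 92% in Python, the model itself is probably fine and the gap is more likely in what the app feeds it. One way to check is to save the Android input buffer for one test image (exporting `inputimageBuffer.getBuffer()` from the app is not shown here) and compare it numerically with the Python input tensor for the same image. A minimal NumPy sketch of that comparison; `compare_inputs` and its `atol` tolerance are hypothetical:

```python
import numpy as np

def compare_inputs(python_input, android_input, atol=1e-3):
    """Summarize how far two preprocessed input tensors for the same image differ.

    android_input is assumed to be a dump of the app's input buffer as float32
    values in the same (row-major, HWC) order as the Python tensor.
    """
    a = np.asarray(python_input, dtype=np.float32).ravel()
    b = np.asarray(android_input, dtype=np.float32).ravel()
    if a.shape != b.shape:
        return "shape mismatch: {} vs {}".format(a.shape, b.shape)
    diff = float(np.max(np.abs(a - b)))
    verdict = "match" if diff <= atol else "MISMATCH"
    return "max abs diff {:.4f} ({})".format(diff, verdict)

print(compare_inputs([0.1, 0.5, 0.9], [0.1, 0.5, 0.9]))      # max abs diff 0.0000 (match)
print(compare_inputs(np.zeros((1, 4, 4, 3)), np.zeros(40)))  # shape mismatch: (48,) vs (40,)
```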

Code for loading the image

private TensorImage loadImage(Bitmap bitmap,int sensorOrientation) {
    // Loads bitmap into a TensorImage.
    inputimageBuffer.load(bitmap);

    int noOfRotations = sensorOrientation / 90;
    int cropSize = Math.min(bitmap.getWidth(),bitmap.getHeight());

    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeWithCropOrPadOp(cropSize,cropSize))
            .add(new ResizeOp(imageResizeY,imageResizeX,ResizeOp.ResizeMethod.BILINEAR))
            .add(new Rot90Op(noOfRotations))
            .add(new NormalizeOp(IMAGE_MEAN,IMAGE_STD))
            .build();
    return imageProcessor.process(inputimageBuffer);
}
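Compared with the Python code, `loadImage` changes the pixels in three ways: `ResizeWithCropOrPadOp(cropSize,cropSize)` center-crops the bitmap to a square before resizing (Python stretches the whole image with `cv2.resize`), `Rot90Op` rotates by `sensorOrientation / 90` quarter turns counter-clockwise (Python never rotates), and `NormalizeOp(IMAGE_MEAN,IMAGE_STD)` computes `(x - mean) / std`, which with the class's `127.5f` constants maps [0, 255] to [-1, 1] instead of the [0, 1] produced by `/= 255.`. The NumPy sketch below mimics those steps so the two pipelines can be compared offline. The function names are hypothetical, the resize step is omitted, and the crop offset rounding may differ slightly from the library's; mean 0 and std 255 is an assumed setting that would reproduce the Python scaling.

```python
import numpy as np

def center_crop_square(img):
    """Crop an HxWxC array to its central min(H, W) square, like ResizeWithCropOrPadOp."""
    h, w = img.shape[:2]
    size = min(h, w)
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]

def android_like_preprocess(img, sensor_orientation=0, mean=127.5, std=127.5):
    """Approximate loadImage(): crop, rotate, normalize (the ResizeOp step is skipped)."""
    x = center_crop_square(img)
    x = np.rot90(x, k=sensor_orientation // 90)  # counter-clockwise quarter turns
    return (x.astype(np.float32) - mean) / std   # NormalizeOp formula

def python_like_preprocess(img):
    """The Python path: no crop, no rotation, scale to [0, 1]."""
    return img.astype(np.float32) / 255.0

# Hypothetical 300x400 (height x width) grayscale image with a horizontal gradient.
img = np.tile(np.linspace(0, 255, 400), (300, 1))[..., None].astype(np.uint8)

print(android_like_preprocess(img).shape)  # (300, 300, 1): 50 columns cut on each side
print(python_like_preprocess(img).shape)   # (300, 400, 1): the whole image is kept

# With mean 0 and std 255, the value scaling matches the Python path.
same_scale = android_like_preprocess(img, mean=0.0, std=255.0)
print(np.allclose(same_scale, python_like_preprocess(center_crop_square(img))))  # True
```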

Code that runs the prediction

inputimageBuffer = loadImage(bitmap,sensorOrientation);
tensorClassifier.run(inputimageBuffer.getBuffer(),probabilityImageBuffer.getBuffer().rewind());

Full classifier code (ImageClassifier.java):

import android.app.Activity;
import android.graphics.Bitmap;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class ImageClassifier {

    // Non-Quantized
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 1.0f;

    private static final float IMAGE_STD = 127.5f;
    private static final float IMAGE_MEAN = 127.5f;

    private static final int MAX_SIZE =3;

    /**
     * Image size along the x axis.
     */
    private final int imageResizeX;
    /**
     * Image size along the y axis.
     */
    private final int imageResizeY;

    /**
     * Labels corresponding to the output of the vision model.
     */
    private final List<String> labels;

    /**
     * An instance of the driver class to run model inference with Tensorflow Lite.
     */
    private final Interpreter tensorClassifier;
    /**
     * Input image TensorBuffer.
     */
    private TensorImage inputimageBuffer;
    /**
     * Output probability TensorBuffer.
     */
    private final TensorBuffer probabilityImageBuffer;
    /**
     * Processor that applies post-processing to the output probability.
     */
    private final TensorProcessor probabilityProcessor;

    /**
     * Creates a classifier
     *
     * @param activity the current activity
     * @throws IOException if the model or label file cannot be loaded
     */
    public ImageClassifier(Activity activity) throws IOException {
        /*
         * The loaded TensorFlow Lite model.
         */
        MappedByteBuffer classifierModel = FileUtil.loadMappedFile(activity,"MobileNet.tflite");
        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity,"labels_mobilenet.txt");

        tensorClassifier = new Interpreter(classifierModel,null);

        // Read the type and shape of the input and output tensors, respectively.
        int imageTensorIndex = 0; // input
        int probabilityTensorIndex = 0;// output

        int[] inputimageShape = tensorClassifier.getInputTensor(imageTensorIndex).shape();
        DataType inputDataType = tensorClassifier.getInputTensor(imageTensorIndex).dataType();

        int[] outputimageShape = tensorClassifier.getOutputTensor(probabilityTensorIndex).shape();
        DataType outputDataType = tensorClassifier.getOutputTensor(probabilityTensorIndex).dataType();

        imageResizeX = inputimageShape[2];
        imageResizeY = inputimageShape[1];


        // Creates the input tensor.
        inputimageBuffer = new TensorImage(inputDataType);

        // Creates the output tensor and its processor.
        probabilityImageBuffer = TensorBuffer.createFixedSize(outputimageShape,outputDataType);

        // Creates the post processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder().add(new NormalizeOp(PROBABILITY_MEAN,PROBABILITY_STD))
                .build();
    }

    /**
     * Runs the inference and returns the classification results.
     *
     * @param bitmap            the bitmap of the image
     * @param sensorOrientation orientation of the camera
     * @return classification results
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap,final int sensorOrientation) {
        // List of labels and probabilities for each class
        List<Recognition> recognitions = new ArrayList<>();

        inputimageBuffer = loadImage(bitmap,sensorOrientation);
        tensorClassifier.run(inputimageBuffer.getBuffer(),probabilityImageBuffer.getBuffer().rewind());

        // Gets the map of label and probability.
        Map<String,Float> labelledProbability = new TensorLabel(labels,probabilityProcessor.process(probabilityImageBuffer)).getMapWithFloatValue();

        int idLabel = 0;
        for (Map.Entry<String,Float> entry : labelledProbability.entrySet()) {
            recognitions.add(new Recognition(String.valueOf(idLabel),entry.getValue()));
            idLabel++;
        }

        // List of probabilities for each class
        List<Float> probabilidades = new ArrayList<>();
        for (Map.Entry<String,Float> entry : labelledProbability.entrySet()) {
            probabilidades.add(entry.getValue());
        }

        Collections.sort(recognitions);

        return recognitions.subList(0,Math.min(MAX_SIZE,recognitions.size()));
    }

    /**
     * Loads the image into the input tensor buffer and applies the preprocessing steps.
     *
     * @param bitmap            the bitmap to be loaded
     * @param sensorOrientation the sensor orientation
     * @return the image loaded tensor input buffer
     */
    private TensorImage loadImage(Bitmap bitmap,int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputimageBuffer.load(bitmap);

        int noOfRotations = sensorOrientation / 90;
        int cropSize = Math.min(bitmap.getWidth(),bitmap.getHeight());

        // Preprocessing steps are applied here
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize,cropSize))
                .add(new ResizeOp(imageResizeY,imageResizeX,ResizeOp.ResizeMethod.BILINEAR))
                .add(new Rot90Op(noOfRotations))
                .add(new NormalizeOp(IMAGE_MEAN,IMAGE_STD))
                .build();
        return imageProcessor.process(inputimageBuffer);
    }

    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition implements Comparable {
        /**
         * display name for the recognition.
         */
        private String name;
        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private float confidence;

        public Recognition() {
        }

        public Recognition(String name,float confidence) {
            this.name = name;
            this.confidence = confidence;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public float getConfidence() {
            return confidence;
        }

        public void setConfidence(float confidence) {
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return "Recognition{" +
                    "name='" + name + '\'' +
                    ",confidence=" + confidence +
                    '}';
        }

        @Override
        public int compareTo(Object o) {
            return Float.compare(((Recognition) o).confidence,this.confidence);
        }
    }


}
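`recognizeImage` pairs each score with its position, sorts with `Recognition.compareTo` (which compares the other object's confidence against this one's, giving descending order), and returns the first `MAX_SIZE` entries. Because each result is named by its index, it is only meaningful if `labels_mobilenet.txt` lists the classes in the same order as the model's outputs; if training used Keras' `flow_from_directory`, that order follows the alphanumerically sorted class folder names. The same ranking logic in Python, with hypothetical labels and scores:

```python
def top_k(scores, labels, k=3):
    """Return the k (label, score) pairs with the highest scores, best first.

    Mirrors recognizeImage(): sort by descending confidence, then take
    subList(0, MAX_SIZE). Python slicing also tolerates fewer than k entries.
    """
    if len(scores) != len(labels):
        raise ValueError("one label per output score is required")
    pairs = sorted(zip(labels, scores), key=lambda p: p[1], reverse=True)
    return pairs[:k]

# Hypothetical 3-class output; the labels must follow the model's output order.
print(top_k([0.1, 0.7, 0.2], ["class_a", "class_b", "class_c"]))
# [('class_b', 0.7), ('class_c', 0.2), ('class_a', 0.1)]
```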
