Pytorch模型移植Android精确度严重下降

开发指纹分类预测的App时遇到了pytorch模型移植安卓端精确度严重下降的问题,大约从91%下降到不足5%。

解决bug过程中我尝试了多种方式,记录一下:

1.通道问题

由于训练时用的单通道的指纹灰度图,但是Android中

org.pytorch.torchvision.TensorImageUtils.bitmapToFloat32Tensor

这个函数默认返回3通道,所以尝试在这里调整。

一种方式是训练之前用脚本将所有训练的图片变为3通道,再训练。如此不需要改变Android端的代码。

import PIL.Image as Image
import os

path = r"E:\imageset"
save_path = r"D:\imgset"
for i in os.listdir(path):
    img = Image.open(os.path.join(path, i)).convert('RGB')
    img.save(os.path.join(save_path, i))

另一种方法是使用大佬重写的返回单通道的bitmapToFloat32Tensor

https://github.com/Unity05/pytorch/blob/master/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java但是我在使用时多次在scores数组中返回全空值。

final Tensor outputTensor = Objects.requireNonNull(module).forward(IValue.from(inputTensor)).toTensor();

final float[] scores = outputTensor.getDataAsFloatArray();

这两种方式我都实际操作了一遍,无效。

2.预处理问题

Android-pytorch官方提供的demo里对inputTensor的定义是这样的:

final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

需要注意的是其中

TORCHVISION_NORM_MEAN_RGB
TORCHVISION_NORM_STD_RGB

分别定义为

TORCHVISION_NORM_MEAN_RGB=[0.485, 0.456, 0.406], 
TORCHVISION_NORM_STD_RGB=[0.229, 0.224, 0.225]

如果在训练时参数和默认不一样,需要自定义参数,自定义的方式如下:

float[] image_mean = new float[]{"自定义"};
float[] image_std = new float[]{"自定义"};

 其余预处理同样需要同步。

这种方式尝试似乎将准确率提升了一点点。。大概提升了1%,基本没用。

3.尝试使用轻量模型MobileNetV3

这是另一个我怀疑的原因,但是实际测试预测准确率基本没有提升。

正解:

首先我们看一下Android-pytorch官网的模型移植的代码:

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")

这里显然是引导我们将pth格式文件转化为ptl格式

再看一下安卓端代码注释

 上面一行的引导是pt格式文件,下一行是ptl格式文件,因此我尝试了将pth文件转化为pt格式。

复制到安卓assets目录,调用预测,准确率正常。问题解决!

相关文章

学习编程是顺着互联网的发展潮流,是一件好事。新手如何学习...
IT行业是什么工作做什么?IT行业的工作有:产品策划类、页面...
女生学Java好就业吗?女生适合学Java编程吗?目前有不少女生...
Can’t connect to local MySQL server through socket \'/v...
oracle基本命令 一、登录操作 1.管理员登录 # 管理员登录 ...
一、背景 因为项目中需要通北京网络,所以需要连vpn,但是服...