问题描述
现在我正在尝试将 SavedModel 转换为 TFLite,以便在树莓派上使用。该模型是在自定义数据集上训练的 MobileNet Object Detection。 SavedModel 完美运行,并保留了 (1,150,3)
的相同形状。但是,当我使用以下代码将其转换为 TFLite 模型时:
import tensorflow as tf
saved_model_dir = input("Model dir: ")
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('model.tflite','wb') as f:
f.write(tflite_model)
并运行此代码以运行解释器:
import numpy as np
import tensorflow as tf
from PIL import Image
from os import listdir
from os.path import isfile,join
from random import choice,random
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
print(f"required input shape: {input_shape}")
我得到 [1 1 1 3]
的输入形状,因此我无法使用 150x150 图像作为输入。
我在 Python 3.7.10 和 Windows 10 上使用 Tensorflow 2.4。
我该如何解决这个问题?
解决方法
您可以依靠 TFLite 转换器 V1 API 来设置输入形状。请查看 https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter 中的 input_shapes 参数。