Converting a TensorFlow 1.12 model to TensorFlow Lite (TFLite)

Problem description

I am trying to convert a model created in TensorFlow 1.12 to TensorFlow Lite.

I am using the following code:

import numpy as np
import tensorflow as tf

# Generate tf.keras model.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(2,input_shape=(3,)))
model.add(tf.keras.layers.RepeatVector(3))
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
model.compile(loss=tf.keras.losses.MSE,optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),metrics=[tf.keras.metrics.categorical_accuracy],sample_weight_mode='temporal')

x = np.random.random((1,3))
y = np.random.random((1,3,3))
model.train_on_batch(x,y)
model.predict(x)

# Save tf.keras model in HDF5 format.
keras_file = "keras_model.h5"
tf.keras.models.save_model(model,keras_file)

# Convert to TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
open("converted_model.tflite","wb").write(tflite_model)

I took this code sample from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/python_api.md#pre_tensorflow_1.12. Because I am using TensorFlow 1.12, I changed the line

converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)

to

converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file)

as suggested at the link above. When I run this code, I get the following output:

INFO:tensorflow:Froze 4 variables.
INFO:tensorflow:Converted 4 variables to const ops.

Then I get this error:


------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-81a9e7060f2c> in <module>
     23 # Convert to TensorFlow Lite model.
     24 converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(keras_file)
---> 25 tflite_model = converter.convert()
     26 open("converted_model.tflite","wb").write(tflite_model)

~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\lite.py in convert(self)
    451           input_tensors=self._input_tensors,
    452           output_tensors=self._output_tensors,
--> 453           **converter_kwargs)
    454     else:
    455       # Graphs without valid tensors cannot be loaded into tf.Session since they

~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\convert.py in toco_convert_impl(input_data,input_tensors,output_tensors,*args,**kwargs)
    340   data = toco_convert_protos(model_flags.SerializeToString(),
    341                              toco_flags.SerializeToString(),
--> 342                              input_data.SerializeToString())
    343   return data
    344 

~\Anaconda3\envs\tensorflow1.12\lib\site-packages\tensorflow\contrib\lite\python\convert.py in toco_convert_protos(model_flags_str,toco_flags_str,input_data_str)
    133     else:
    134       raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
--> 135                          (stdout,stderr))
    136 
    137 

RuntimeError: TOCO failed see console for info.
b'C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:523: FutureWarning: Passing (type,1) or \'1type\' as a synonym of type is deprecated; in a future version of numpy,it will be understood as (type,(1,)) / \'(1,)type\'.\r\n  _np_qint8 = np.dtype([("qint8",np.int8,1)])\r\nC:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:524: FutureWarning: Passing (type,)type\'.\r\n  _np_quint8 = np.dtype([("quint8",np.uint8,1)])\r\nC:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:525: FutureWarning: Passing (type,)type\'.\r\n  _np_qint16 = np.dtype([("qint16",np.int16,1)])\r\nC:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:526: FutureWarning: Passing (type,)type\'.\r\n  _np_quint16 = np.dtype([("quint16",np.uint16,1)])\r\nC:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:527: FutureWarning: Passing (type,)type\'.\r\n  _np_qint32 = np.dtype([("qint32",np.int32,1)])\r\nC:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:532: FutureWarning: Passing (type,)type\'.\r\n  np_resource = np.dtype([("resource",np.ubyte,1)])\r\nTraceback (most recent call last):\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py",line 18,in swig_import_helper\r\n    fp,pathname,description = imp.find_module(\'_tensorflow_wrap_toco\',[dirname(__file__)])\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\imp.py",line 297,in find_module\r\n    raise ImportError(_ERR_MSG.format(name),name=name)\r\nImportError: No module named \'_tensorflow_wrap_toco\'\r\n\r\nDuring handling of the above exception,another exception occurred:\r\n\r\nTraceback (most recent call 
last):\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\Scripts\\toco_from_protos-script.py",line 6,in <module>\r\n    from tensorflow.contrib.lite.toco.python.toco_from_protos import main\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\toco_from_protos.py",line 22,in <module>\r\n    from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py",line 28,in <module>\r\n    _tensorflow_wrap_toco = swig_import_helper()\r\n  File "C:\\Users\\.\\Anaconda3\\envs\\tensorflow1.12\\lib\\site-packages\\tensorflow\\contrib\\lite\\toco\\python\\tensorflow_wrap_toco.py",line 20,in swig_import_helper\r\n    import _tensorflow_wrap_toco\r\nModuleNotFoundError: No module named \'_tensorflow_wrap_toco\'\r\n'
None

Could someone help me solve this?

Solution

I suggest using a newer TensorFlow release together with its new converter (which is based on MLIR rather than TOCO).
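The traceback in the question ends with `No module named '_tensorflow_wrap_toco'`, which suggests that the compiled TOCO wrapper is missing from that install. Before upgrading, one way to confirm this is to look for the wrapper file on disk. The following is a minimal sketch: the helper name `find_toco_wrapper` is hypothetical, and the subdirectory is taken from the paths shown in the traceback (TensorFlow 1.12 layout), so it will not match other versions.

```python
import importlib.util
import os


def find_toco_wrapper(package="tensorflow"):
    """Locate compiled TOCO wrapper files without importing the package.

    Returns (package_dir, wrapper_files). package_dir is None when the
    package is not installed or is not a package with a directory.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None, []

    package_dir = list(spec.submodule_search_locations)[0]
    # Directory layout assumed from the traceback in the question (TF 1.12).
    wrapper_dir = os.path.join(package_dir, "contrib", "lite", "toco", "python")
    if not os.path.isdir(wrapper_dir):
        return package_dir, []

    # The compiled extension has a leading underscore in its file name;
    # tensorflow_wrap_toco.py (no underscore) is only the SWIG shim.
    wrapper_files = sorted(
        name for name in os.listdir(wrapper_dir)
        if name.startswith("_tensorflow_wrap_toco")
    )
    return package_dir, wrapper_files


if __name__ == "__main__":
    location, files = find_toco_wrapper()
    print("TensorFlow directory:", location)
    print("Compiled TOCO wrapper files:", files or "none found")
```

An empty file list for a TensorFlow 1.12 install would be consistent with the ImportError shown in the question.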

I tried the code with 2.4.0 (it should also work with 2.2.x), with one line slightly modified:

converter = tf.lite.TFLiteConverter.from_keras_model(model)

and obtained a *.tflite model.
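Put together, the suggested change might look like the sketch below. It assumes TensorFlow 2.x is installed, rebuilds the model from the question (without the training step, which is not needed for conversion), and passes the in-memory model to `from_keras_model`. The final check with `tf.lite.Interpreter` is an addition for illustration, not part of the original answer.

```python
import numpy as np
import tensorflow as tf

# Same architecture as in the question.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
model.add(tf.keras.layers.RepeatVector(3))
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))

# TF 2.x: convert the in-memory Keras model directly (no HDF5 round trip).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

# Sanity check: run one input through the converted model.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]

x = np.random.random((1, 3)).astype(np.float32)
interpreter.set_tensor(input_detail["index"], x)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_detail["index"])
print(tflite_output.shape)
```

Note that `from_keras_model` takes the model object itself, whereas the TF 1.x `from_keras_model_file` used in the question takes a path to the saved HDF5 file.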


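In my experiments, TF 2.x works much better with Keras and lets you quantize the model cleanly. With TF 1.x, however, you should switch to QAT (quantization-aware training) or export a frozen GraphDef to get quantization to work. For QAT, you can check here.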
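As an illustration of the TF 2.x path, the sketch below applies post-training dynamic-range quantization by setting `converter.optimizations`. This is not QAT and not the frozen-GraphDef route suggested for TF 1.x; it is a simpler technique swapped in here because it needs no training. It assumes TensorFlow 2.x and uses a hypothetical stand-in model, and for a model this small the file size may barely change.

```python
import tensorflow as tf

# Stand-in model (hypothetical); substitute your own trained Keras model.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(32,)),
    tf.keras.layers.Dense(10),
])

# Baseline float32 conversion for comparison.
float_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# Post-training dynamic-range quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()

print("float32 size (bytes):", len(float_model))
print("quantized size (bytes):", len(quantized_model))
```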
