How to convert this Keras model (tf version 1.15, dynamic LSTM) to TFLite? Attempt 2

Problem description

I need to build an LSTM network in Keras that is compatible with TensorFlow 1.15.

This is another post I opened about how to create the network: keras LSTM model - a tf 1.15 equivalent that works with tflite

Versions I'm using: tf version: 1.15.0, tf.keras version: 2.2.4-tf

I followed these examples:
https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/lite/experimental/examples/lstm
https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/lite/experimental/examples/lstm/TensorFlowLite_LSTM_Keras_Tutorial.ipynb

I managed to write code for the network I need.

This is the "source" network (with a keras.LSTM layer):

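from tensorflow import keras
from tensorflow.keras import layers

sequence_length = 1  # matches the time dimension of keras.Input(shape=(1, 52))
inputs = keras.Input(shape=(1, 52))
state_1_h = keras.Input(shape=(200,))
state_1_c = keras.Input(shape=(200,))
# With return_state=True the LSTM returns [sequence output, final h, final c]
x1, state_1_h_out, state_1_c_out = layers.LSTM(200, return_sequences=True, input_shape=(sequence_length, 52), return_state=True)(inputs, initial_state=[state_1_h, state_1_c])
output = layers.Dense(13)(x1)

model = keras.Model([inputs, state_1_h, state_1_c], [output, state_1_h_out, state_1_c_out])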
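The (x, h, c) → (y, h_out, c_out) signature is what lets such a model run one time step per call, with the caller feeding the returned states back in as the next initial state. The stdlib-only sketch below illustrates that pattern under loud assumptions: a single-unit LSTM with made-up weights and plain sigmoid gates (so it will not match Keras numerically), and hypothetical helper names (`WEIGHTS`, `lstm_step`, `run_sequence`, `run_streaming`) that are not TF or Keras APIs. It shows that stepping with sequence length 1 plus state feedback gives the same outputs as processing the whole sequence at once.

```python
import math
from typing import List, Tuple

# Hypothetical weights for a 1-feature, 1-unit LSTM:
# (input weight, recurrent weight, bias) for each gate.
WEIGHTS = {
    "i": (0.5, 0.1, 0.0),   # input gate
    "f": (0.4, 0.2, 1.0),   # forget gate
    "g": (0.3, -0.1, 0.0),  # candidate cell value
    "o": (0.6, 0.05, 0.0),  # output gate
}


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x: float, h: float, c: float) -> Tuple[float, float]:
    """One LSTM time step (no peepholes); returns the new (h, c)."""
    def pre(gate: str) -> float:
        w_x, w_h, b = WEIGHTS[gate]
        return w_x * x + w_h * h + b

    i = sigmoid(pre("i"))
    f = sigmoid(pre("f"))
    g = math.tanh(pre("g"))
    o = sigmoid(pre("o"))
    c_new = f * c + i * g
    h_new = o * math.tanh(c_new)
    return h_new, c_new


def run_sequence(xs: List[float], h: float = 0.0, c: float = 0.0) -> Tuple[List[float], float, float]:
    """Run the cell over a whole sequence starting from (h, c).

    Returns (per-step outputs, final h, final c), like an LSTM layer with
    return_sequences=True and return_state=True."""
    outputs: List[float] = []
    for x in xs:
        h, c = lstm_step(x, h, c)
        outputs.append(h)
    return outputs, h, c


def run_streaming(xs: List[float]) -> Tuple[List[float], float, float]:
    """Call run_sequence once per step (sequence length 1), passing the
    returned states back in as the next call's initial state."""
    h, c = 0.0, 0.0  # zero initial state
    outputs: List[float] = []
    for x in xs:
        step_out, h, c = run_sequence([x], h, c)
        outputs.extend(step_out)
    return outputs, h, c
```

For example, `run_sequence([0.2, -0.5, 1.0, 0.3])` and `run_streaming([0.2, -0.5, 1.0, 0.3])` return the same outputs and final states, because the streaming version performs exactly the same steps in the same order.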


And this is how I implemented it:

import os
os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'  # read when TensorFlow is imported, so set it first
from tensorflow.keras import Model

import tensorflow as tf
print(f"tf version: {tf.__version__},tf.keras version: {tf.keras.__version__}")
from tensorflow.keras.utils import plot_model

def buildLstmLayer(merged_inputs,num_units):
  inputs = merged_inputs[0]
  state_1_h_keras = merged_inputs[1]
  state_1_c_keras = merged_inputs[2]
  initial_state = tf.nn.rnn_cell.LSTMStateTuple(c=state_1_c_keras, h=state_1_h_keras)  # field order is (c, h)
  cell = tf.nn.rnn_cell.BasicLSTMCell(num_units,state_is_tuple=True)

  outputs,out_states = tf.lite.experimental.nn.dynamic_rnn(
      cell,inputs,dtype='float32',time_major=True,initial_state=initial_state)
  state_1_c_out, state_1_h_out = out_states  # LSTMStateTuple unpacks as (c, h)
  state_1_h_out_keras = tf.keras.Input(tensor=(state_1_h_out),name='state_1_h_out')
  state_1_c_out_keras = tf.keras.Input(tensor=(state_1_c_out),name='state_1_c_out')
  return outputs,state_1_h_out_keras,state_1_c_out_keras

tf.reset_default_graph()

inputs = tf.keras.layers.Input(shape=(1, 52), name='input')
batch_size = tf.shape(inputs)[1]
cell = tf.nn.rnn_cell.BasicLSTMCell(200,state_is_tuple=True)
initial_state = cell.zero_state(batch_size,tf.float32)
state_1_c, state_1_h = initial_state  # zero_state returns LSTMStateTuple(c, h)
state_1_h_keras = tf.keras.Input(tensor=(state_1_h),name='state_1_h')
state_1_c_keras = tf.keras.Input(tensor=(state_1_c),name='state_1_c')
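# buildLstmLayer returns three tensors: outputs, h state, c state
x1, state_1_h_out_keras, state_1_c_out_keras = tf.keras.layers.Lambda(buildLstmLayer, arguments={'num_units': 200})([inputs, state_1_h_keras, state_1_c_keras])
output = tf.keras.layers.Dense(13, activation=tf.nn.softmax, name='output')(x1)
model = Model([inputs, state_1_h_keras, state_1_c_keras], [output, state_1_h_out_keras, state_1_c_out_keras])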
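In TF 1.x, `tf.nn.rnn_cell.LSTMStateTuple` is a namedtuple subclass whose fields are ordered `(c, h)`, and `BasicLSTMCell.zero_state` and the final state of `dynamic_rnn` for a `state_is_tuple=True` cell use that same order. Positional construction or unpacking therefore binds the first element to `c`, not `h`; keyword arguments make the intent explicit. The stdlib-only sketch below mirrors just the field order (strings stand in for tensors; it is not the TF class itself).

```python
import collections

# Simplified mirror of TF 1.x: LSTMStateTuple's fields are ("c", "h"), in that order.
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

# Placeholder strings standing in for the cell-state and hidden-state tensors.
cell_state, hidden_state = "c_tensor", "h_tensor"

# Positional construction: the first argument becomes .c, even if it was meant as h.
swapped = LSTMStateTuple(hidden_state, cell_state)

# Keyword construction avoids the ambiguity.
explicit = LSTMStateTuple(c=cell_state, h=hidden_state)

# Positional unpacking also yields (c, h), in that order.
c_out, h_out = explicit
```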

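With `time_major=True`, `dynamic_rnn` (the `tf.lite.experimental.nn` version follows `tf.nn.dynamic_rnn`'s convention) reads its input as `[max_time, batch_size, depth]`, whereas `tf.keras.layers.Input(shape=(1, 52))` produces a `[batch, 1, 52]` tensor. Read as time-major, axis 0 of that tensor is treated as time and axis 1, which `batch_size = tf.shape(inputs)[1]` picks up, is always 1. Converting batch-major to time-major is `tf.transpose(x, perm=[1, 0, 2])`. The stdlib-only sketch below shows what that permutation does on nested lists; `to_time_major` and `shape` are hypothetical helpers, not TF functions.

```python
from typing import List, Tuple

Batch = List[List[List[float]]]


def to_time_major(batch_major: Batch) -> Batch:
    """Swap the first two axes: [batch][time][features] -> [time][batch][features].

    Pure-Python equivalent of tf.transpose(x, perm=[1, 0, 2]) for a
    rectangular nested list."""
    if not batch_major:
        return []
    num_batch = len(batch_major)
    num_time = len(batch_major[0])
    return [[batch_major[b][t] for b in range(num_batch)] for t in range(num_time)]


def shape(x: Batch) -> Tuple[int, int, int]:
    """(dim0, dim1, dim2) of a non-empty rectangular 3-level nested list."""
    return (len(x), len(x[0]), len(x[0][0]))
```

A batch of two single-step samples with shape `(2, 1, 3)` becomes `(1, 2, 3)`: one time step with a batch of two. Without the transpose, the same data would be read as two time steps with a batch of one.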

Now I need to convert it to TFLite.

First, when I try to save and load it, it doesn't work:

from tensorflow.keras.models import load_model

model.save("model_working.h5")
loaded_model = load_model("model_working.h5")

In this notebook, they show how to convert it to TFLite correctly:

# Step 3: Convert the Keras model to TensorFlow Lite model.
sess = tf.keras.backend.get_session()
input_tensor = sess.graph.get_tensor_by_name('input:0')
output_tensor = sess.graph.get_tensor_by_name('output/Softmax:0')
converter = tf.lite.TFLiteConverter.from_session(
    sess,[input_tensor],[output_tensor])
tflite = converter.convert()
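`get_tensor_by_name` expects TF 1.x graph tensor names of the form `<op_name>:<output_index>`; for example, `'input:0'` is output 0 of the op named `input`. Converter errors, such as the TOCO message in attempt 2 below, report array names like `lambda/rnn/while/Identity_4` without that suffix. The small stdlib helper below is a hypothetical convenience (not a TF API) for mapping between the two forms when matching `model.inputs` / `model.outputs` names against such messages.

```python
from typing import Tuple


def split_tensor_name(tensor_name: str) -> Tuple[str, int]:
    """Split '<op_name>:<output_index>' into (op_name, output_index).

    A bare op name without ':' is treated as output 0."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        return tensor_name, 0
    return op_name, int(index)
```

For example, `split_tensor_name('input:0')` gives `('input', 0)`, and a bare array name from an error message maps to output 0 of the op with that name.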

I tried to follow the same pattern in my code:

sess = tf.keras.backend.get_session()
inputs_tensors = [sess.graph.get_tensor_by_name(tensor_name) for tensor_name in [x.name for x in model.inputs]]
outputs_tensors = [sess.graph.get_tensor_by_name(tensor_name) for tensor_name in [x.name for x in model.outputs]]

converter = tf.lite.TFLiteConverter.from_session(
    sess,inputs_tensors,outputs_tensors)

tflite = converter.convert()

But I get:


2021-02-06 21:24:44.536897: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your cpu supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2021-02-06 21:24:44.550694: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f944418c2f0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-02-06 21:24:44.550705: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host,Default Version
WARNING:tensorflow:From /Users/yonatab/PycharmProjects/LSTM_in_android/model_working_for_issue_convert.py:50: The name tf.keras.backend.get_session is deprecated. Please use tf.compat.v1.keras.backend.get_session instead.

2021-02-06 21:24:44.678534: I tensorflow/core/grappler/devices.cc:60] Number of eligible GPUs (core count >= 8,compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA support)
2021-02-06 21:24:44.678584: I tensorflow/core/grappler/clusters/single_machine.cc:356] Starting new session
2021-02-06 21:24:44.702930: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:786] Optimization results for grappler item: graph_to_optimize
2021-02-06 21:24:44.702943: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: Graph size after: 514 nodes (0),626 edges (0),time = 5.445ms.
2021-02-06 21:24:44.702946: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: Graph size after: 514 nodes (0),time = 6.09ms.
2021-02-06 21:24:44.702948: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:786] Optimization results for grappler item: lambda_rnn_while_cond_49
2021-02-06 21:24:44.702951: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: function_optimizer did nothing. time = 0.001ms.
2021-02-06 21:24:44.702953: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: function_optimizer did nothing. time = 0ms.
2021-02-06 21:24:44.702955: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:786] Optimization results for grappler item: lambda_rnn_while_body_50
2021-02-06 21:24:44.702958: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: function_optimizer did nothing. time = 0.001ms.
2021-02-06 21:24:44.702960: I tensorflow/core/grappler/optimizers/Meta_optimizer.cc:788]   function_optimizer: function_optimizer did nothing. time = 0ms.
WARNING:tensorflow:From /Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/util.py:208: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
WARNING:tensorflow:From /Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
Traceback (most recent call last):
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/model_working_for_issue_convert.py",line 55,in <module>
    sess,outputs_tensors)
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/lite.py",line 628,in from_session
    graph_def = _freeze_graph(sess,input_tensors,output_tensors)
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/util.py",line 244,in freeze_graph
    hinted_outputs_nodes)
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/util.py",line 209,in _convert_op_hints_if_present
    graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/op_hint.py",line 1288,in convert_op_hints_to_stubs
    return _convert_op_hints_to_stubs_helper(graph_def,write_callback)
  File "/Users/yonatab/PycharmProjects/LSTM_in_android/venv_tf_115/lib/python3.7/site-packages/tensorflow_core/lite/python/op_hint.py",line 1179,in _convert_op_hints_to_stubs_helper
    assert (len(children_hints) > 0)  #  pylint: disable=g-explicit-length-test
AssertionError

Attempt 2

If I use

cell = tf.lite.experimental.nn.TFLiteLSTMCell(num_units, state_is_tuple=True)

instead of

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units, state_is_tuple=True)

I get:

tensorflow.lite.python.convert.ConverterError: See console for info.
2021-02-08 15:41:35.471411: F tensorflow/lite/toco/tooling_util.cc:935] Check Failed: GetopWithOutput(model,output_array) Specified output array "lambda/rnn/while/Identity_4" is not produced by any op in this graph. Is it a typo? This should not happen. If you trigger this error please send a bug report (with code to reporduce this error),to the TensorFlow Lite team.
Fatal Python error: Aborted

What am I missing?

Could it be related to the explanation in this documentation?

Thanks
