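Problem description
I want to build a text classification model with TF Hub and export it as a TFLite model, but
I get an error when converting a TensorFlow model that includes a TF Hub layer. Please help me fix it.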
import tensorflow as tf
import tensorflow_hub as hub

# Keras model that takes raw strings and embeds them with a TF Hub text module
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(dtype=tf.string, input_shape=()))
model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))

# Convert the Keras model to TFLite; convert() is the call that fails
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
I tried both the TFLite Python API and the command-line API, but I get an InvalidArgumentError.
InvalidArgumentError Traceback (most recent call last)
<ipython-input-15-5a8dbd778645> in <module>()
5 model.add(hub.KerasLayer("https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1"))
6 converter = tf.lite.TFLiteConverter.from_keras_model(model)
----> 7 tflite_model = converter.convert()
6 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py in convert(self)
850 frozen_func,graph_def = (
851 _convert_to_constants.convert_variables_to_constants_v2_as_graph(
--> 852 self._funcs[0],lower_control_flow=False))
853
854 input_tensors = [
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in convert_variables_to_constants_v2_as_graph(func,lower_control_flow,aggressive_inlining)
1103 func=func,
1104 lower_control_flow=lower_control_flow,
-> 1105 aggressive_inlining=aggressive_inlining)
1106
1107 output_graph_def,converted_input_indices = _replace_variables_by_constants(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in __init__(self,func,aggressive_inlining,variable_names_allowlist,variable_names_denylist)
804 variable_names_allowlist=variable_names_allowlist,
805 variable_names_denylist=variable_names_denylist)
--> 806 self._build_tensor_data()
807
808 def _build_tensor_data(self):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/convert_to_constants.py in _build_tensor_data(self)
823 data = map_index_to_variable[idx].numpy()
824 else:
--> 825 data = val_tensor.numpy()
826 self._tensor_data[tensor_name] = _TensorData(
827 numpy=data,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in numpy(self)
1069 """
1070 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
-> 1071 maybe_arr = self._numpy() # pylint: disable=protected-access
1072 return maybe_arr.copy() if isinstance(maybe_arr,np.ndarray) else maybe_arr
1073
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
1037 return self._numpy_internal()
1038 except core._NotOkStatusException as e: # pylint: disable=protected-access
-> 1039 six.raise_from(core._status_to_exception(e.code,e.message),None) # pylint: disable=protected-access
1040
1041 @property
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value,from_value)
InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.
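Solution
Last time I checked, TFLite did not support lookup tables, which are the main source of resource tensors in TF Hub models (apart from variables, but those definitely work). That fits the traceback: while freezing the graph to constants, the converter calls .numpy() on a captured tensor of dtype resource that is not a variable, and a resource handle cannot be turned into a NumPy array.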
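If the lookup table is the blocker, one possible workaround is to do the token-to-ID lookup outside the model and feed integer IDs to a model that contains no table. The sketch below shows only that lookup step in plain Python. Everything in it is an assumption for illustration: the helper names build_vocab_table and text_to_ids are hypothetical, and the whitespace tokenization, the two-word vocabulary, and the CRC32 out-of-vocabulary bucketing do not reproduce what the nnlm-en-dim50 module does internally.

```python
import zlib


def build_vocab_table(vocab_tokens):
    """Map each vocabulary token to its position in the list (hypothetical helper)."""
    return {token: index for index, token in enumerate(vocab_tokens)}


def text_to_ids(text, vocab_table, num_oov_buckets=1):
    """Turn a string into integer IDs without any in-graph lookup table.

    Known tokens get their vocabulary index. Unknown tokens are hashed
    deterministically into one of num_oov_buckets IDs placed after the
    vocabulary. Both choices are assumptions made for this sketch.
    """
    vocab_size = len(vocab_table)
    ids = []
    for token in text.split():  # assumption: plain whitespace tokenization
        if token in vocab_table:
            ids.append(vocab_table[token])
        else:
            bucket = zlib.crc32(token.encode("utf-8")) % num_oov_buckets
            ids.append(vocab_size + bucket)
    return ids


vocab_table = build_vocab_table(["hello", "world"])
print(text_to_ids("hello there world", vocab_table))  # prints [0, 2, 1]
```

The resulting IDs would then go to a model whose input is an integer tensor rather than tf.string. Building that model needs access to the embedding weights and vocabulary, which this sketch does not cover.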