Error when converting code from TF1 to TF2

Problem description

Here are the values:

rnn_size: 512
batch_size: 128


rnn_inputs: Tensor("embedding_lookup/Identity_1:0",shape=(?,?,128),dtype=float32)
sequence_length: Tensor("inputs_length:0",shape=(?,),dtype=int32)
cell_fw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb6d0>
cell_bw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb910>

I obtain enc_state with:

enc_output,enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,rnn_inputs,sequence_length,dtype=tf.float32)

The resulting enc_state value is:

enc_state: LSTMStateTuple(c=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?,512) dtype=float32>,h=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?,512) dtype=float32>)

TF1 code:

initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(enc_state,_zero_state_tensors(rnn_size,batch_size,tf.float32))

which I converted to TF2 as:
initial_state = tfa.seq2seq.AttentionWrapper(enc_state,tf.float32)

and got this error:


TypeError                                 Traceback (most recent call last)
<ipython-input-54-d87646b9df5d> in <module>()
      8                                                     threshold) 
      9             model = build_graph(keep_probability,rnn_size,num_layers,
---> 10                                 learning_rate,embedding_size,direction)
     11             train(model,epochs,log_string)

6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname,value,expected_type,memo)
    596                 raise TypeError(
    597                     'type of {} must be {}; got {} instead'.
--> 598                     format(argname,qualified_name(expected_type),qualified_name(value)))
    599     elif isinstance(expected_type,TypeVar):
    600         # Only happens on < 3.6

TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead

Could someone also explain the last line of the error, i.e.:

    TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
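To illustrate what the last line is saying, here is a minimal pure-Python sketch under a simplified model: tfa.seq2seq.AttentionWrapper takes the RNN cell as its first positional argument (named `cell`) and the attention mechanism as its second, and typeguard checks that `cell` is a Keras Layer before the constructor body runs. The classes `Layer` and `LSTMStateTuple` and the function `attention_wrapper` below are hypothetical stand-ins, not the real TensorFlow or TensorFlow Addons code.

```python
from collections import namedtuple


# Hypothetical stand-in for tf.keras.layers.Layer; real RNN cells subclass it.
class Layer:
    pass


# Hypothetical stand-in for the legacy LSTMStateTuple: just a (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def attention_wrapper(cell, attention_mechanism):
    """Simplified mimic of the typeguard check on AttentionWrapper.__init__.

    Whatever is passed first is bound to `cell` and must be a Layer.
    """
    if not isinstance(cell, Layer):
        raise TypeError(
            'type of argument "cell" must be Layer; got {} instead'.format(
                type(cell).__name__
            )
        )
    return cell, attention_mechanism


enc_state = LSTMStateTuple(c=[0.0] * 4, h=[0.0] * 4)

# Mirrors the failing call AttentionWrapper(enc_state, tf.float32):
# enc_state is bound to `cell`, "float32" to `attention_mechanism`.
try:
    attention_wrapper(enc_state, "float32")
except TypeError as err:
    print(err)  # type of argument "cell" must be Layer; got LSTMStateTuple instead
```

In other words, the TF2 call hands the encoder state to the parameter that expects a cell (and tf.float32 to the parameter that expects an attention mechanism). Under that reading, enc_state probably belongs in the wrapper's initial state rather than in its constructor, for example via attn_cell.get_initial_state(batch_size=..., dtype=tf.float32).clone(cell_state=enc_state) on a wrapper built from a real cell and attention mechanism; treat that as an untested suggestion.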
