内部错误:在 tf.GradientTape 内运行 TF HUB BERT 时,试图在没有处理数据的情况下获取变量的梯度或类似

问题描述

我正在尝试在 Tensorflow 2.4 中的持久梯度磁带中训练 bert_en_uncased_L-12_H-768_A-12 TF HUB 模型。以下是我的代码的简化版本。

import tensorflow as tf
import tensorflow_hub as hub

input_mask = tf.keras.layers.Input(shape=4,dtype=tf.int32)
input_word_ids = tf.keras.layers.Input(shape=4,dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=4,dtype=tf.int32)

bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",trainable=True)({"input_mask": input_mask,'input_word_ids': input_type_ids,"input_type_ids": input_type_ids})

dense = tf.keras.layers.Dense(units=1)(bert['pooled_output'])

encode = tf.keras.models.Model([input_mask,input_word_ids,input_type_ids],dense)
import numpy as np

data = np.zeros((1,4))


@tf.function
def run():
    with tf.GradientTape(persistent=True,watch_accessed_variables=False) as tape:
        tape.watch(encode.trainable_weights)
        encode([data,data,data],training=True)


run()

错误

  raise ValueError("Internal error: Tried to take gradients (or similar) "

    ValueError: Internal error: Tried to take gradients (or similar) of a variable without handle data:
    Tensor("StatefulPartitionedCall:1079",dtype=resource)

错误仅在

  • 使用了 TF HUB trainable=True 选项
  • 使用了持久梯度胶带。 这是 TensorFlow 中的错误还是我尝试了不受支持内容

解决方法

我认为没有必要使用 persistent=True,您应该将其设为 False。通常,当我们需要计算 True 范围内的损失时,它被设置为 tape,以便我们可以在范围 src 之外计算它们的梯度。在你上面的代码示例中,我认为你不需要这个。

您的代码中可能需要修复的另一个错字。它有一个错误的输入映射。

bert = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3",trainable=True)(
        {
           "input_mask": input_mask,'input_word_ids': input_word_ids,# < ---------
           "input_type_ids": input_type_ids
         }
     )

运行具有这些更改的代码

import numpy as np

data = np.zeros((1,4))

@tf.function
def run():
    with tf.GradientTape( watch_accessed_variables=False) as tape:
        tape.watch(encode.trainable_weights)
        y = encode([data,data,data],training=True)
    tf.print(y)

run()

# [[-0.799545228]]
,

在保存和加载 SavedModel 时,为 SavedModel 使用持久性 GradientTapes 需要 TensorFlow 2.5+。请继续关注 https://github.com/tensorflow/hub/issues/622 以获取有关 TF2.5 发布的更新以及为 BERT 等更新的 SavedModels。

M. Innat 的回答解释了如何通过使用标准的非持久性 GradientTapes 来避免该问题。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...