使用@tf.function 进行自定义张量流训练的内存泄漏 tl;博士;详情

问题描述

我正在尝试按照官方 Keras 演练为 TF2/Keras 编写自己的训练循环。 vanilla 版本的效果很好,但是当我尝试将 @tf.function 装饰器添加到我的训练步骤时,一些内存泄漏占用了我所有的内存并且我失去了对我的机器的控制,有谁知道发生了什么?

代码的重要部分如下所示:

@tf.function
def train_step(x,y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x,training=True)
        loss_value = loss_fn(y,logits)
    grads = tape.gradient(loss_value,siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads,siamese_network.trainable_weights))
    train_acc_metric.update_state(y,logits)
    return loss_value

@tf.function
def test_step(x,y):
    val_logits = siamese_network(x,training=False)
    val_acc_metric.update_state(y,val_logits)
    val_prec_metric.update_state(y_batch_val,val_logits)
    val_rec_metric.update_state(y_batch_val,val_logits)


for epoch in range(epochs):
        step_time = 0
        epoch_time = time.time()
        print("Start of {} epoch".format(epoch))
        for step,(x_batch_train,y_batch_train) in enumerate(train_ds):
            if step > steps_epoch:
                break
           
            loss_value = train_step(x_batch_train,y_batch_train)
        train_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        
        for val_step,(x_batch_val,y_batch_val) in enumerate(test_ds):
            if val_step>validation_steps:
                break
            test_step(x_batch_val,y_batch_val)
         
        val_acc = val_acc_metric.result()
        val_prec = val_prec_metric.result()
        val_rec = val_rec_metric.result()

        val_acc_metric.reset_states()
        val_prec_metric.reset_states()
        val_rec_metric.reset_states()

如果我对 @tf.function 行进行注释,则不会发生内存泄漏,但步骤时间慢了 3 倍。我的猜测是以某种方式在每个时期内再次创建了该图或类似的东西,但我不知道如何解决它。

这是我正在学习的教程:https://keras.io/guides/writing_a_training_loop_from_scratch/

解决方法

tl;博士;

TensorFlow 可能会为传递给修饰函数的每组唯一参数值生成一个新图。确保将形状一致的 Tensor 对象传递给 test_steptrain_step 而不是 python 对象。

详情

这是暗中刺伤。虽然我从未尝试过 @tf.function,但我确实在 the documentation 中发现了以下警告:

tf.function 还将任何纯 Python 值视为不透明对象,并为其遇到的每组 Python 参数构建一个单独的图。

注意:将 python 标量或列表作为参数传递给 tf.function 将始终构建一个新图。为避免这种情况,请尽可能将数字参数作为张量传递

最后:

函数通过从输入的 args 和 kwargs 计算缓存键来确定是否重用跟踪的 ConcreteFunction。缓存键是根据以下规则(可能会更改)根据函数调用的输入 args 和 kwargs 标识 ConcreteFunction 的键:

  • 生成的键对于 tf.Tensor 是它的形状和数据类型。
  • 为 tf.Variable 生成的键是唯一的变量 id。
  • 为 Python 原语(如 int、 float,str) 是它的值。
  • 为嵌套字典、列表、元组、namedtuples 和 attrs 生成的键是叶键的扁平元组(参见 nest.flatten)。 (由于这种扁平化,调用具有与跟踪期间使用的嵌套结构不同的嵌套结构的具体函数将导致 TypeError)。
  • 对于所有其他 Python 类型,键对于对象是唯一的。这样,一个函数或方法会针对每个调用它的实例独立跟踪。

我从这一切中得到的是,如果您没有将大小一致的 Tensor 对象传递给 @tf.function 化的函数(也许您使用 Python 集合或原语),您很可能正在使用您传入的每个不同参数值创建函数的新图形版本。我猜这可能会导致您看到的内存爆炸行为。我不知道您的 test_dstrain_ds 对象是如何创建的,但您可能希望确保它们的创建方式使得 enumerate(blah_ds) 像教程中那样返回张量,或者在在传递给 test_steptrain_step 函数之前,至少将值转换为张量。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...