在循环中创建模型使 Keras 越来越慢

问题描述

当我多次训练一个模型时,即使所有相关数量都是在 for 循环内创建的,训练迭代也会变慢(因此每次都应该被覆盖,这应该足以避免创建增长计算图或填满内存或其他什么)。

这是我使用的代码

import numpy as np
import tensorflow as tf
import time

n_samples = 300000
n_features = 100
n_targets = 5
batch_size = 100
x = np.array(range(n_samples * n_features),dtype=np.float64).reshape((n_samples,n_features))
y = np.array(range(n_samples * n_targets),n_targets))
for t_idx in range(10):
    dataset = [x,y]
    dataset = tf.data.Dataset.from_tensor_slices(tuple(dataset)).shuffle(n_samples).repeat().batch(batch_size=batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    data_iterator = iter(dataset)

    inputs = tf.keras.Input(shape=(n_features,),name='input')
    outputs = tf.keras.layers.Dense(n_features,name='dense_1',activation=tf.keras.activations.relu)(inputs)
    outputs = tf.keras.layers.Dense(n_features,name='dense_2',activation=tf.keras.activations.relu)(outputs)
    outputs = tf.keras.layers.Dense(n_features,name='dense_3',name='dense_4',activation=tf.keras.activations.relu)(outputs)
    outputs = tf.keras.layers.Dense(n_targets,name='output',activation=tf.keras.activations.linear)(outputs)
    model = tf.keras.Model(inputs=inputs,outputs=outputs)

    trainable_variables = list(model.trainable_variables)

    adam_opt = tf.optimizers.Adam(learning_rate=0.001)


    @tf.function
    def loss(batch):
        x_,y_ = batch
        y_pred_ = model(x_)
        return tf.keras.losses.MSE(y_pred_,y_)


    @tf.function
    def optimization_step():
        batch = next(data_iterator)
        def f(): return loss(batch)
        adam_opt.minimize(f,var_list=trainable_variables)

    iterations = 50000
    loop_start = time.time()
    optimization_times = []
    for idx in range(iterations):
        optimization_step()

    loop_end = time.time()
    print(f'Elapsed: {loop_end - loop_start}')

我知道我可以使用 model.fit() 来训练模型,但这对我来说不是一个选项(上面的代码是通过剥离更复杂的代码获得的,我唯一的选择是直接调用优化器)

输出

Elapsed: 49.798316955566406
Elapsed: 55.18571472167969
Elapsed: 58.57510209083557
Elapsed: 64.41855955123901
Elapsed: 66.76858448982239
Elapsed: 68.3305652141571
Elapsed: 67.73438382148743
Elapsed: 69.73751258850098
Elapsed: 73.59102845191956
Elapsed: 73.14124798774719

我预计每次训练所用的时间大致相同,但有明显的上升趋势。 我做错了什么?

系统信息

  • 操作系统:Windows 10
  • python 版本:3.7.9
  • 张量流版本:2.3.1

编辑

根据@Nicolas Gervais 和@M.Innat 的回复,我一直在不同机器上试验代码的两个版本(我的和@M.Innat 的)。这是我发现的:

  • 在我测试过的两台不同的 Windows 机器上(相同的 Windows 版本、相同的软件包版本、相同的 python 版本,在干净的 virtualenv 上),结果有很大的不同;两者中最好的仍然显示出训练时间的上升趋势,但比我最初测试的要小;两者中最好的结果如下:
    WITH MY CODE:
    Elapsed: 49.35429096221924                                                                                                                                                                                                                   
    Elapsed: 52.551310777664185                                                                                                                                                                                                                  
    Elapsed: 54.324320554733276                                                                                                                                                                                                                  
    Elapsed: 56.53233051300049                                                                                                                                                                                                                   
    Elapsed: 56.81632399559021                                                                                                                                                                                                                   
    Elapsed: 58.70533752441406                                                                                                                                                                                                                   
    Elapsed: 59.68834161758423                                                                                                                                                                                                                   
    Elapsed: 61.419353008270264                                                                                                                                                                                                                  
    Elapsed: 60.33834195137024                                                                                                                                                                                                                   
    Elapsed: 62.536344051361084

    WITH @M.Innat CODE:
    Elapsed: 50.51127886772156                                                                                                                                                                                                                   
    Elapsed: 53.94429612159729                                                                                                                                                                                                                   
    Elapsed: 52.38828897476196                                                                                                                                                                                                                   
    Elapsed: 54.5512957572937                                                                                                                                                                                                                    
    Elapsed: 58.1543083190918                                                                                                                                                                                                                    
    Elapsed: 61.21232509613037                                                                                                                                                                                                                   
    Elapsed: 60.11531925201416                                                                                                                                                                                                                   
    Elapsed: 59.95942974090576                                                                                                                                                                                                                   
    Elapsed: 60.48531889915466                                                                                                                                                                                                                   
    Elapsed: 59.37330341339111
  • 如您所见,时间看起来比我最初在第二台 Windows 机器上发布的要稳定一些(但请注意,当我在用于初始测试的机器上进行测试时,它们基本相同);尽管如此,趋势仍然清晰可见;这很奇怪,因为这在@M.Innit 发布的结果中没有发生;
  • 这个问题在 Linux 或 Mac OS 上都不会出现。

我开始觉得这是 Tensorflow 中的一个错误,不知何故(即使具有相同确切配置的不同机器产生如此不同的结果很奇怪),但我真的不知道将这种行为归因于什么

解决方法

您可能想尝试 tf.keras.backend.clear_session(),如 documentation 中所述,它解释了在循环中创建模型时内存消耗会发生什么:

如果您在循环中创建多个模型,此全局状态将随着时间的推移消耗越来越多的内存,您可能需要清除它。调用 clear_session() 释放全局状态:这有助于避免旧模型和层的混乱,尤其是在内存有限的情况下。

如果没有 clear_session(),这个循环的每次迭代都会略微增加 Keras 管理的全局状态的大小

for _ in range(100):
  model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])

在开始调用 clear_session() 的情况下,Keras 在每次迭代时都以空白状态开始,并且内存消耗随着时间的推移保持不变。

for _ in range(100):
  tf.keras.backend.clear_session()
  model = tf.keras.Sequential([tf.keras.layers.Dense(10) for _ in range(10)])
,

我没有确切的原因,它的主要原因是什么。看起来它们可能是某种内存泄漏。在你的代码中,我做了一些修改来运行它。试试这个:

import numpy as np
import tensorflow as tf
import time,gc

tf.config.run_functions_eagerly(False)

n_samples = 300000
n_features = 100
n_targets = 5
batch_size = 100
x = np.array(range(n_samples * n_features),dtype=np.float64).reshape((n_samples,n_features))
y = np.array(range(n_samples * n_targets),n_targets))

def get_model():
    inputs = tf.keras.Input(shape=(n_features,),name='input')
    outputs = tf.keras.layers.Dense(n_features,name='dense_1',activation=tf.keras.activations.relu)(inputs)
    outputs = tf.keras.layers.Dense(n_features,name='dense_2',activation=tf.keras.activations.relu)(outputs)
    outputs = tf.keras.layers.Dense(n_features,name='dense_3',name='dense_4',activation=tf.keras.activations.relu)(outputs)
    outputs = tf.keras.layers.Dense(n_targets,name='output',activation=tf.keras.activations.linear)(outputs)
    model = tf.keras.Model(inputs=inputs,outputs=outputs)
    return model

    
for t_idx in range(7):
    tf.keras.backend.clear_session()
    gc.collect()
    
    dataset = [x,y]
    dataset = tf.data.Dataset.from_tensor_slices(tuple(dataset)).shuffle(n_samples).repeat().batch(batch_size=batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    data_iterator = iter(dataset)
    
    model = get_model()
    trainable_variables = list(model.trainable_variables)
    adam_opt = tf.optimizers.Adam(learning_rate=0.001)
    
    @tf.function
    def loss(batch):
        x_,y_ = batch
        y_pred_ = model(x_)
        return tf.keras.losses.MSE(y_pred_,y_)

    @tf.function
    def optimization_step():
        batch = next(data_iterator)
        def f(): return loss(batch)
        adam_opt.minimize(f,var_list=trainable_variables)

    iterations = 50000
    loop_start = time.time()
    
    for idx in range(iterations):
        optimization_step()

    print(f'Elapsed: {time.time() - loop_start}')
    
    del model,dataset

当使用tf.config.run_functions_eagerly(False)时,它给出如下

Elapsed: 123.39986753463745
Elapsed: 121.13122296333313
Elapsed: 118.44610977172852
Elapsed: 116.83040761947632
Elapsed: 118.46350479125977
Elapsed: 118.59502696990967
Elapsed: 120.34098505973816

没有它给出如下

Elapsed: 114.46430349349976
Elapsed: 124.23725700378418
Elapsed: 126.14825916290283
Elapsed: 127.5985496044159
Elapsed: 126.79593586921692
Elapsed: 124.1206603050232
Elapsed: 125.85739850997925

虽然我认为这不是主要问题。可能还有更多的东西需要深入研究。

[我的操作系统:Windows 10,TF: 2.4.1,RTX:2070。]

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...