在TensorFlow中如何避免重复训练和预测中的OOM错误?

问题描述

我在TensorFlow中有一些代码,该代码采用基本模型,使用一些数据对其进行微调(训练),然后使用该模型将其他数据用于predict()。所有这些都封装在模块的main()方法中,并且可以正常工作。

但是,当我在不同的基本模型上循环运行此代码时,例如在7个基本模型之后,我最终得到一个OOM。这是预期的吗?我希望Python在每次main()调用后都会清理。 TensorFlow不这样做吗?我该如何强制呢?

编辑:这是一条MWE,显示的不是OOM崩溃,而是增加了内存消耗:

import gc
import os

import numpy as np
import psutil
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
    (model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
    history = model.fit(
        x=(x := tf.zeros((1,*model.input.shape[1:]))),y=(y := tf.zeros((1,*model.output.shape[1:]))),verbose=0,)
    prediction = model.predict(x)
    _ = gc.collect()
    # tf.keras.backend.clear_session()
    print(f"rss {i}: {process.memory_info().rss >> 20} MB")

在我的计算机(CPU)上打印

rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...

如果没有评论tf.keras.backend.clear_session(),那就更好了,但还不完善:

rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB

切换gc.collect()tf.keras.backend.clear_session()的顺序也无济于事。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...