问题描述
我在TensorFlow中有一些代码,该代码采用基本模型,使用一些数据对其进行微调(训练),然后使用该模型将其他数据用于predict()
。所有这些都封装在模块的main()
方法中,并且可以正常工作。
但是,当我在不同的基本模型上循环运行此代码时,例如在7个基本模型之后,我最终得到一个OOM。这是预期的吗?我希望Python在每次main()
调用后都会清理。 TensorFlow不这样做吗?我该如何强制呢?
编辑:这是一条MWE,显示的不是OOM崩溃,而是增加了内存消耗:
import gc
import os
import numpy as np
import psutil
import tensorflow as tf
tf.get_logger().setLevel("ERROR") # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
(model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
history = model.fit(
x=(x := tf.zeros((1,*model.input.shape[1:]))),y=(y := tf.zeros((1,*model.output.shape[1:]))),verbose=0,)
prediction = model.predict(x)
_ = gc.collect()
# tf.keras.backend.clear_session()
print(f"RSS {i}: {process.memory_info().RSS >> 20} MB")
在我的计算机(cpu)上打印
RSS 0: 374 MB
RSS 1: 438 MB
RSS 2: 478 MB
RSS 3: 517 MB
RSS 4: 554 MB
RSS 5: 588 MB
RSS 6: 634 MB
RSS 7: 669 MB
RSS 8: 686 MB
RSS 9: 726 MB
...
RSS 30: 1386 MB
RSS 31: 1413 MB
RSS 32: 1445 MB
RSS 33: 1476 MB
RSS 34: 1506 MB
RSS 35: 1536 MB
RSS 36: 1568 MB
RSS 37: 1597 MB
RSS 38: 1630 MB
RSS 39: 1662 MB
...
如果没有评论tf.keras.backend.clear_session()
,那就更好了,但还不完善:
RSS 0: 374 MB
RSS 1: 420 MB
RSS 2: 418 MB
RSS 3: 450 MB
RSS 4: 447 MB
RSS 5: 469 MB
RSS 6: 469 MB
RSS 7: 475 MB
RSS 8: 487 MB
RSS 9: 494 MB
...
RSS 40: 519 MB
RSS 41: 516 MB
RSS 42: 517 MB
RSS 43: 520 MB
RSS 44: 519 MB
RSS 45: 519 MB
RSS 46: 521 MB
RSS 47: 517 MB
RSS 48: 521 MB
RSS 49: 521 MB
...
RSS 90: 531 MB
RSS 91: 531 MB
RSS 92: 531 MB
RSS 93: 531 MB
RSS 94: 532 MB
RSS 95: 532 MB
RSS 96: 533 MB
RSS 97: 534 MB
RSS 98: 533 MB
RSS 99: 533 MB
切换gc.collect()
和tf.keras.backend.clear_session()
的顺序也无济于事。
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)