问题描述
我创建了一个类,用于在一个线程内部的循环中创建、训练和评估预测器。我用一个 while 循环运行该线程,并在循环的每一轮中创建该类的实例,先训练再预测;每一轮结束时删除这个预测器实例。我的预期是:执行 `del predictor` 并随后调用 `gc.collect()` 之后,该实例中的 TF 图也会被删除,其占用的内存也会被释放。但实际上,每经过一轮循环,RAM 使用量都在不断增加。
我使用的是 TensorFlow 1.14.0 和 Keras 2.2.4。
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input,Dense,LSTM,Reshape
from keras.callbacks import ModelCheckpoint,EarlyStopping,ReduceLROnPlateau,TensorBoard
import numpy as np
from multiprocessing import Event
from threading import Thread
import gc
class nn_model():
    """
    Data-driven model for predicting energy, isolated in its own TF graph.

    Every instance creates a private ``tf.Graph`` plus a ``tf.Session`` bound
    to it, and each method enters both as defaults before touching Keras, so
    instances do not build into the global Keras graph.

    Releasing memory: ``del instance`` followed by ``gc.collect()`` is not
    enough. The session holds native resources until ``Session.close()`` is
    called. In addition, the Keras 2.2.x TensorFlow backend caches per-graph
    state in module-level dicts (confirm for other Keras versions), which
    keeps every graph ever used alive. Call :meth:`close`, or use the
    instance as a context manager, once you are done with it.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the private graph and the session bound to it.

        Keyword Args:
            loss: loss passed to ``Model.compile`` in :meth:`design_net`
                (default ``'mse'``).

        Any other positional or keyword arguments are accepted and ignored.
        """
        # design_net() compiles with self.loss, so it must always be set.
        self.loss = kwargs.get('loss', 'mse')
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)

    def design_net(self):
        """
        Build and compile the network inside this instance's graph/session.

        Layers: Input(batch_shape=(None, 1, 3)) -> Dense(64, relu)
        -> LSTM(32, linear, return_sequences) -> LSTM(1) -> Reshape((1, 1)).
        The compiled model is stored on ``self.model``.
        """
        with self.graph.as_default():  # pylint: disable=not-context-manager
            with self.session.as_default():  # pylint: disable=not-context-manager
                input_layer = Input(batch_shape=(None, 1, 3))
                layers = Dense(64, activation='relu')(input_layer)
                layers = LSTM(32, activation='linear', return_sequences=True)(layers)
                output = LSTM(1, return_sequences=False)(layers)
                output = Reshape((1, 1))(output)
                self.model = Model(inputs=input_layer, outputs=output)
                self.model.compile(loss=self.loss, optimizer='adam')

    def fit(self, **kwargs):
        """
        Fit the model to the data.

        Keyword Args:
            X_train, y_train: training inputs and targets (required).
            X_val, y_val: validation inputs and targets (required); the
                callbacks monitor ``val_loss`` computed on them.
            epochs: maximum number of epochs (default 50); EarlyStopping
                may end training sooner.

        The Keras ``History`` object is stored on ``self.history``.
        """
        with self.graph.as_default():  # pylint: disable=not-context-manager
            with self.session.as_default():  # pylint: disable=not-context-manager
                self.history = self.model.fit(
                    kwargs['X_train'],
                    kwargs['y_train'],
                    validation_data=(kwargs['X_val'], kwargs['y_val']),
                    epochs=kwargs.get('epochs', 50),
                    callbacks=self.callbacks(),
                    verbose=0,
                )

    def callbacks(self, **kwargs):
        """
        Create the training callbacks and return them as a list.

        - ModelCheckpoint: saves to the relative path ``'best_model'`` when
          ``val_loss`` improves, checked every 2 epochs.
        - EarlyStopping: stops after 5 epochs without ``val_loss`` improvement.
        - ReduceLROnPlateau: lowers the learning rate after 2 stagnant epochs,
          with a cooldown of 3 epochs.

        The callbacks are also kept on the instance (``modelchkpt``,
        ``earlystopping``, ``reduclronplateau``, ``cb_list``).
        """
        self.modelchkpt = ModelCheckpoint('best_model', monitor='val_loss', save_best_only=True, period=2)
        self.earlystopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=False)
        self.reduclronplateau = ReduceLROnPlateau(monitor='val_loss', patience=2, cooldown=3)
        self.cb_list = [self.modelchkpt, self.earlystopping, self.reduclronplateau]
        return self.cb_list

    def predict(self, **kwargs):
        """
        Run inference with the trained model.

        Keyword Args:
            X_test: input array to predict on (required; ``KeyError`` if
                missing).

        Returns:
            Whatever ``Model.predict`` returns for ``X_test``.
        """
        with self.graph.as_default():  # pylint: disable=not-context-manager
            with self.session.as_default():  # pylint: disable=not-context-manager
                predictions = self.model.predict(kwargs['X_test'])
        return predictions

    def close(self, clear_keras_state=True):
        """
        Release the session, graph and model held by this instance.

        Safe to call more than once. After closing, the instance can no
        longer train or predict.

        Args:
            clear_keras_state: also call ``K.clear_session()`` to drop the
                Keras backend's process-wide per-graph caches (default True).
                That call also resets the calling thread's default TF graph
                and affects every Keras model in the process, so pass False
                if other models are still in use. Do not call it while a
                graph is entered with ``as_default()``.
        """
        session = getattr(self, 'session', None)
        if session is not None:
            session.close()
        # Drop our own references so the graph can be garbage collected.
        self.session = None
        self.graph = None
        self.model = None
        if clear_keras_state:
            K.clear_session()

    def __enter__(self):
        """Return self so the model can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the model (including Keras global state); never suppress errors."""
        self.close()
        return False
def learn_on_loop(end_looping, X_train, y_train, X_val, y_val):
    """
    Repeatedly build, train and evaluate a fresh ``nn_model`` until stopped.

    Each iteration saves its predictions on ``X_val`` to ``y_pred.npy``,
    overwriting the previous iteration's file.

    Args:
        end_looping: an ``Event``. The loop exits once it is set; the flag is
            checked at the start of each iteration, so a running iteration
            always finishes first.
        X_train, y_train: training data.
        X_val, y_val: validation data, also used as the prediction input.
    """
    # An Event object is always truthy, so `not end_looping` would be False
    # and the body would never execute; query the flag explicitly.
    while not end_looping.is_set():
        predictor = nn_model()
        try:
            predictor.design_net()
            predictor.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val)
            # nn_model.predict() reads the 'X_test' keyword.
            np.save('y_pred.npy', predictor.predict(X_test=X_val))
        finally:
            # del + gc.collect() alone does not free the graph: the session
            # must be closed explicitly, and the Keras backend keeps global
            # per-graph caches that only K.clear_session() drops. Without
            # this, RAM grows on every iteration.
            predictor.session.close()
            del predictor
            K.clear_session()
            gc.collect()
if __name__ == "__main__":
    end_looping = Event()
    # Four arrays must be unpacked into four names.
    X_train, y_train, X_val, y_val = (
        np.load('X_train.npy'),
        np.load('y_train.npy'),
        np.load('X_val.npy'),
        np.load('y_val.npy'),
    )
    # learn_on_loop() requires every one of these keyword arguments.
    predictor_th = Thread(target=learn_on_loop, daemon=False, kwargs={
        'end_looping': end_looping,
        'X_train': X_train,
        'y_train': y_train,
        'X_val': X_val,
        'y_val': y_val,
    })
    predictor_th.start()
    try:
        # Join with a timeout so Ctrl-C reaches the main thread promptly
        # (an untimed join can block KeyboardInterrupt on some platforms).
        while predictor_th.is_alive():
            predictor_th.join(timeout=1.0)
    except KeyboardInterrupt:
        # Ask the worker to stop after its current iteration, then wait for
        # it so its session/graph cleanup runs before the process exits.
        end_looping.set()
        predictor_th.join()
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)