AssertionError:不要使用tf.reset_default_graph清除嵌套图如果需要清除图形

问题描述

这几乎把我杀了。

我试图在for循环中加载几个训练有素的模型,然后将它们放到内存中以节省加载时间。 然后,我进行了编码,将一个训练有素的模型从列表中排除出来进行预测,并在下面得到了此消息。 有没有人帮助我? 谢谢您的帮助。


  File "C:\Users\SONSANGWOO\Desktop\paperwork\ANN\standard_scaler_97\Testing_standard_scaler_97_test.py",line 242,in <module>
    K.clear_session()

  File "C:\Users\SONSANGWOO\anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py",line 414,in clear_session
    tf_keras_backend.clear_session()

  File "C:\Users\SONSANGWOO\anaconda3\lib\site-packages\tensorflow\python\keras\backend.py",line 339,in clear_session
    ops.reset_default_graph()

  File "C:\Users\SONSANGWOO\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py",line 5588,in reset_default_graph
    raise AssertionError("Do not use tf.reset_default_graph() to clear "

AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph,exit the nesting and create a new graph.

这是我的代码

start3 = time.time() #set a starting timestart = time.time() #set a starting time
print("Prediction is going on!")
# prediction (load model -> prediction -> saveing the list of prediction)
listofresult = list()
listofloading_time = list()
listofmodel = list()
start4 = time.time() #set a starting timestart = time.time() #set a starting time
for n in range(n_divisions):
    K.clear_session()
    model_p = load_model('./model_candidates_97/best_{}.h5'.format('model_' + str(n + 1)),custom_objects ={'<lambda>': tf.nn.leaky_relu},compile=False)
    listofmodel.append(model_p)
listofmodel = listofmodel
loading_time = (time.time() - start4)

start5 = time.time() #set a starting timestart = time.time() #set a starting time

for n in range(n_divisions): #n_divisions`enter code here`
    
    with graph.as_default():

        K.clear_session()
        
        model_p_list = listofmodel[n]

    
        yhat = model_p_list.predict(X__p,verbose=0)
    
        prediction = np.array(yhat)
    
        result = prediction.reshape(prediction.shape[0],prediction.shape[1])    
    
        listofresult.extend(result)
    
        listofloading_time.append(loading_time)

        update_progress(n/(n_divisions))

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)