Tensorflow Keras 模型结果不可重现

问题描述

这就是我的训练脚本的样子,只有重要的部分。

import tensorflow as tf,numpy,random,os
SEED: int = 43
# To ensure reproducibility of the results after persisting into disk
tf.keras.backend.manual_variable_initialization(value=True)
os.environ['PYTHONHASHSEED']: str = '0'
numpy.random.seed(seed=SEED)
tf.random.set_seed(seed=SEED)
random.seed(a=SEED)

# Training Data acquisition 

model: tf.keras.Model = tf.keras.Sequential([
    tf.keras.layers.Input(
        shape=X_train_transformed.shape[1:],name='input_layer'),tf.keras.layers.BatchNormalization(),tf.keras.layers.Dense(units=UNITS,activation='relu'),tf.keras.layers.Dense(units=CLASSES)])
model.save(filepath=MODEL_DIRECTORY)

这是推理脚本,在不同的 python 会话中。

import tensorflow as tf,os
SEED: int = 43
# To ensure reproducibility of the results after persisting into disk
tf.keras.backend.manual_variable_initialization(value=True)
os.environ['PYTHONHASHSEED']: str = '0'
numpy.random.seed(seed=SEED)
tf.random.set_seed(seed=SEED)
random.seed(a=SEED)

X_test_loaded:pandas.DataFrame=pandas.read_csv(filepath_or_buffer='../test_data/test.csv',index_col='Unnamed: 0')
X_test_transformed: pandas.DataFrame = pandas.DataFrame(
    data=X_test_loaded.apply(func=feature_transformation,axis=1).to_list())

for _ in range(5): # Five iterations necessary to get correct result
    model=tf.keras.models.load_model(filepath=MODEL_DIRECTORY)
    y_pred=model.predict(x=X_test_transformed) 

最后一个预测循环是必要的,因为 model.predict 方法似乎直到五次尝试才在测试集上给出正确的结果(即在训练期间观察到的结果)。模型本身加载的权重与训练期间观察到的权重相同,但似乎 predict 方法输出有一些周期性变化,所以每次使用它时,我必须调用它五次以确保一致性。我不知道为什么会这样。我应该采取什么纠正措施?尤其是,5 在这里似乎是一个完全神奇的数字。

如果很重要,请在 python 3.8、Ubuntu 20.04 上运行它(并打算使用相同的 docker 映像在生产中部署),使用以下库版本。

matplotlib==3.3.4
numpy==1.19.5
pandas==1.1.5
scikit-learn==0.24.2
tensorflow==2.5.0

编辑: 经过一些更多的实验,我意识到周期不是 5,但实际上没有这样的周期。它在加载和预测时不断给出随机结果,有时它确实给出了正确的结果。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...