带有tf.data.dateset的tf.compat.v1.disable_eager_execution

问题描述

我正在使用tensorflow 2.2。我有两个传递给tf.data.dataset.from_tensor_slices()的numpy数组(功能标签):

train_dataset = tf.data.Dataset.from_tensors(feature_train_slice,label_train_slice).shuffle(buffer_size).reapeat()

test_dataset = tf.data.Dataset.from_tensors(feature_test_slice,label_test_slice).shuffle(buffer_size).repeat()

我正在尝试将此数据传递给我的model.fit():

history = self.model.fit(ds_train,steps_per_epoch=int(train_steps / (batch_size)),verbose=1,epochs=epochs,callbacks=self.call_back(),use_multiprocessing=True,validation_data = test_dataset,validation_steps = int(validation_steps / (batch_size))
                            )

我用过

tf.compat.v1.disable_eager_execution()

在我的代码开头。如果我将其注释掉,则培训没有问题,但是我意识到的培训速度较慢(在2080TI上,每个步骤需要2秒钟)。如果我离开它,每个步骤大约需要1.2秒。但是,该程序永远不会越过该行

train_dataset = tf.data.Dataset.from_tensors(feature_train_slice,label_train_slice).shuffle().reapeat()

我离开程序超过30分钟,虽然消耗了大约60GB(我的ram是64GB),但该程序似乎无能为力。有没有人见过这个?欢迎任何帮助。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)