Tensorflow v1:如何保存数据集预测并在以后访问?

问题描述

我是硕士学位学生,而且是Tensorflow的新手。 对于我的论文项目,我需要在框架中更改并整合推荐系统算法。

原始代码的作用

原始代码是这样的:https://github.com/CRIPAC-DIG/A-PGNN与Tensorflow v1.15 代码以这种方式工作:

  1. 运行record.py,给定一个csv文件,此代码将为火车和测试集创建一些tfrecord。
  2. 运行train_last.py,此代码将加载数据,并对测试集进行训练和评估。

要做的集成将这些代码封装到一个框架中。该框架需要两个方法:fit和predict_next。为了配合,我做了原始代码在record.py和火车中的工作。数据集由框架给出。我还创建了测试集的tfrecord。

我现在需要做的是方法predict_next。 我需要创建与原始代码相同的东西,但只需要少量的数据。

原始代码中的测试集流程是这样的:

  • 给定csv文件,在record.py:184和record.py:97中,代码生成测试集的tfrecord。

writer = tf.python_io.TFRecordWriter(orgin_path + str(count)+'.tfrecord')

  • 在train_last.py:94中加载此测试集,并使用具有DatasetV1Adapter的model_last.py:333中的函数解析此文件
test_dataset = tf.data.TFRecordDataset(test_filenames)
test_dataset = test_dataset.map(parse_function_(opt.max_session))
  • 创建一个批处理和一个迭代器(train_last.py:105)
test_batch_padding_dataset = test_dataset.padded_batch(opt.batchSize,padded_shapes=padded_shape,drop_remainder=True)
test_iterator = test_batch_padding_dataset.make_initializable_iterator()
test_data = test_iterator.get_next()
  • 创建(我认为)张量模型(train_last.py:127)
with tf.variable_scope('model',reuse=True):
    test_loss,test_index = model.forward(test_data['A_in'],test_data['A_out'],test_data['all_node'],test_data['seq_alias'],test_data['seq_mask'],test_data['session_alias'],test_data['session_len'],test_data['session_mask'],test_data['tar'],test_data['user'],train=False)
  • 训练后,用
  • 进行评估(train_last.py:163,model_last.py:415)
index,test_loss_,tar,seq_length,sess_length = session.run([test_index,test_loss,test_data['session_len']])

我做了什么和我需要什么:

  • 综合培训 我需要:
  • 进行预测,例如,如果我有10个用户,每个会话有5个项目,则我需要为每个用户预测所有下一个项目并存储在某个地方
  • 创建一个方法predict_next,在这里我可以简单地返回上一点的预测。

问题:

  • **如何在TensorFlow中修改原始测试代码以具有所有预测?
  • 我在哪里可以定义session.run返回的内容

非常感谢,如果有不清楚的地方,随时告诉我。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...