如何保存张量

问题描述

我有一个包含 1000 个项目的数据集。在针对数据训练模型之前,我将数据标准化。

我现在想使用该模型进行预测。但是,据我所知,我需要将输入到我需要预测的模型中的输入进行标准化。为了执行此操作,我需要在训练时计算平均值和标准差。

虽然我可以将它打印到控制台,但如何“保存”它 - 以备后用?我试图了解此处有关如何保存训练数据标准化时使用的均值和标准差的程序 - 以便我可以在进行预测时再次使用它。

解决方法

我确定我们可以先通过以下方式得到张量的数组表示:

// tensor here is the tensor variable that contains the tensor
const tensorAsArray = tensor.arraySync()

然后,我们将它像任何其他字符串一样保存到一个文件中

fs.writeFile(myFilePath,JSON.stringify(tensorAsArray),'utf-8')

要读回它并将其用作张量,我们会做相反的事情:

const tensorAsArray = JSON.parse(fs.readFile(myFilePath,'utf-8'))
const tensor = tf.tensor(tensorAsArray)

这让我可以保存均值和标准差以备后用。