pytorch闪电模型的输出预测

问题描述

这可能是一个非常简单的问题。我刚开始使用 PyTorch 闪电,无法弄清楚如何在训练后接收模型的输出

我对 y_train 和 y_test 作为某种数组(稍后步骤中的 PyTorch 张量或 NumPy 数组)的预测感兴趣,以使用不同的脚本在标签旁边绘图。

dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset,**train_params)
val_generator = torch.utils.data.DataLoader(val_dataset,**val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0,max_epochs=max_epochs,logger=logger,progress_bar_refresh_rate=20,callbacks=[early_stop_callback],num_sanity_val_steps=0)
trainer.fit(mynet)

在我的闪电模块中,我有以下功能

def __init__(self,random_inputs):

def forward(self,x):

def train_DataLoader(self):
    
def val_DataLoader(self):

def training_step(self,batch,batch_nb):

def training_epoch_end(self,outputs):

def validation_step(self,batch_nb):

def validation_epoch_end(self,outputs):

def configure_optimizers(self):

我是否需要特定的预测功能,或者是否有任何我没有看到的已实现方式?

解决方法

训练器有一个 test 函数。您可能想查看来自 pytorch-lightning 的原始文档以了解更多详细信息:https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#testing

,

我不同意这些答案:OP 的问题似乎集中在他应该如何使用在闪电中训练的模型来获得预测一般,而不是针对训练管道中的特定步骤。在这种情况下,用户不需要去任何靠近 Trainer 对象的地方——那些不打算用于一般推理,因此上面的答案鼓励反模式(每次我们都随身携带一个训练器对象)想对将来阅读这些答案的任何人进行一些推断。

不使用 trainer,我们可以直接从已定义的 Lightning 模块获得预测:如果我有我的(经过训练的)Lightning 模块实例 model = Net(...) 然后使用该模型获得预测在输入 x 上只需调用 model(x) 即可实现(只要 forward 方法已在 Lightning 模块上实现/覆盖 - 这是必需的)。

相比之下,Trainer.predict() 通常不是使用经过训练的模型运行推理的预期方法。作为训练管道的一部分,Trainer API 提供了 tunefittest LightningModule 的方法,在我看来,predict 方法是为广告提供的作为不太“标准”的训练步骤的一部分,对单独的数据加载器进行 hoc 预测。

OP 的问题(我是否需要特定的预测功能,或者是否有任何我看不到的已实现方式?)暗示他们不熟悉forward() 方法在 PyTorch 中的工作方式,但询问是否已经存在他们看不到的推理方法。因此,完整的答案需要进一步解释 forward() 方法适合推理过程的位置:

model(x) 工作的原因是因为 Lightning 模块是 torch.nn.Module 的子类,它们实现了一个名为 __call__() 的魔术方法,这意味着我们可以像调用函数一样调用类实例. __call__() 依次调用 forward(),这就是我们需要在 Lightning 模块中覆盖该方法的原因。

注意。因为 forward 只是我们使用 model(x) 时调用的逻辑的一部分,所以总是建议使用 model(x) 而不是 model.forward(x) 进行推理,除非您有特定的理由偏离。

,

您也可以使用 predict 方法。这是文档中的示例。 https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

class LitMNISTDreamer(LightningModule):

    def forward(self,z):
        imgs = self.decoder(z)
        return imgs

    def predict_step(self,batch,batch_idx: int,dataloader_idx: int = None):
        return self(batch)


model = LitMNISTDreamer()
trainer.predict(model,datamodule) 

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...