如何通过 pytorch-lightning 正确使用 Tensorboard 的 TSNE?

问题描述

我在 MNIST 上运行以下代码

也就是说,我从每个验证时期返回

return {"val_loss": loss,"recon_batch": recon_batch,"label_batch": label_batch,"label_img": orig_batch.view(-1,1,28,28)}

然后使用

    mat = torch.cat([o["recon_batch"] for o in outputs])
    Metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
    label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
    tb.add_embedding(
        mat=mat,Metadata=Metadata,label_img=label_img,global_step=self.current_epoch,)

并期望它能够工作,就像在 the doc 中一样。

似乎只显示一个批次,在验证期间我得到的日志如下

验证:92%|█████████▏| 49/53 [00:01

如何为所有时代获得 recon_batch 的有效 TSNE?


完整代码供参考:

def validation_step(self,batch,batch_idx):
    if self._config.dataset == "toy":
        (orig_batch,noisy_batch),label_batch = batch
        # Todo put in the noise here and not in the dataset?
    elif self._config.dataset == "mnist":
        orig_batch,label_batch = batch
        orig_batch = orig_batch.reshape(-1,28 * 28)
        noisy_batch = orig_batch
    else:
        raise ValueError("invalid dataset")

    noisy_batch = noisy_batch.view(noisy_batch.size(0),-1)

    recon_batch,mu,logvar = self.forward(noisy_batch)

    loss = self._loss_function(
        recon_batch,orig_batch,logvar,reconstruction_function=self._recon_function
    )

    tb = self.logger.experiment
    tb.add_scalars("losses",{"val_loss": loss},global_step=self.current_epoch)
    if batch_idx == len(self.val_DataLoader()) - 2:
        orig_batch -= orig_batch.min()
        orig_batch /= orig_batch.max()
        recon_batch -= recon_batch.min()
        recon_batch /= recon_batch.max()

        orig_grid = torchvision.utils.make_grid(orig_batch.view(-1,28))
        val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1,28))

        tb.add_image("original_val",orig_grid,global_step=self.current_epoch)
        tb.add_image("reconstruction_val",val_recon_grid,global_step=self.current_epoch)
        # f,axarr = plt.subplots(2,1)
        # axarr[0].imshow(orig_grid.permute(1,2,0).cpu())
        # axarr[1].imshow(val_recon_grid.permute(1,0).cpu())
        # plt.show()
        pass

    return {"val_loss": loss,28)}

def validation_epoch_end(self,outputs: List[Any]) -> None:
    first_batch_dict = outputs[-1]
    self.log(name="val_epoch_end",value={"val_loss": first_batch_dict["val_loss"]})

    tb = self.logger.experiment
    # assert mat.shape[0] == label_img.shape[0],'#images should equal with #data points'
    mat = torch.cat([o["recon_batch"] for o in outputs])
    Metadata = torch.cat([o["label_batch"] for o in outputs]).cpu()
    label_img = torch.cat([o["label_img"] for o in outputs]).cpu()
    tb.add_embedding(
        mat=mat,)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...