问题描述
官方文档仅给出了如下示例:
>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
        [1., 1.]])
这并未展示如何在框架中使用指标。
def __init__(...):
    """Create the validation confusion-matrix metric (excerpt; the argument list is elided in the post)."""
    # One row/column per cluster, sized from the config's n_clusters.
    # NOTE(review): self._config must already be set before this line runs.
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
def validation_step(self, batch, batch_index):
    """Run one validation batch and accumulate it into the confusion matrix.

    Args:
        batch: the validation batch (unpacked by the code elided below).
        batch_index: index of ``batch`` within the epoch (unused).

    Returns:
        The criterion loss for this batch.
    """
    ...  # elided in the post; presumably unpacks ``batch`` into orig_batch / label_batch
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
    # Only accumulate metric state here. Handing the metric object to
    # ``self.log(..., on_step=True)`` makes Lightning read the metric's
    # ``_forward_cache``, which ``update()`` never fills; that is the
    # "'NoneType' object has no attribute 'detach'" crash. A C x C matrix is
    # not a scalar either, so it is reported once per epoch instead.
    self.val_confusion.update(log_probs, label_batch)
    return loss
def validation_step_end(self, outputs):
    """Hand the per-step ``outputs`` back to Lightning untouched."""
    passthrough = outputs
    return passthrough
def validation_epoch_end(self, outs):
    """Report the confusion matrix accumulated over the validation epoch.

    ``self.log`` only accepts scalars, so the C x C matrix is written to the
    experiment logger as a grayscale image instead. NOTE(review): assumes
    ``self.logger`` is a TensorBoard logger whose ``experiment`` provides
    ``add_image`` (torch ``SummaryWriter``).

    Args:
        outs: per-step outputs collected by Lightning (unused).
    """
    conf_mat = self.val_confusion.compute()
    # The metric is fed through update(), so nothing resets it automatically;
    # without this the next epoch's matrix would also count this epoch.
    self.val_confusion.reset()
    counts = conf_mat.float()
    # Scale counts into [0, 1] for display; clamp avoids 0/0 on an empty epoch.
    image = counts / counts.max().clamp(min=1.0)
    self.logger.experiment.add_image(
        "validation_confusion_epoch",
        image,
        global_step=self.current_epoch,
        dataformats="HW",
    )
在第 0 个 epoch 结束后,出现了以下错误:
Traceback (most recent call last):
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py",line 521,in train
self.train_loop.run_training_epoch()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py",line 588,in run_training_epoch
self.trainer.run_evaluation(test_mode=False)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py",line 613,in run_evaluation
self.evaluation_loop.log_evaluation_step_metrics(output,batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py",line 346,in log_evaluation_step_metrics
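self.__log_result_step_metrics(step_log_metrics,step_pbar_metrics,batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py",line 350,in __log_result_step_metrics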
cached_batch_pbar_metrics,cached_batch_log_metrics = cached_results.update_logger_connector()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 378,in update_logger_connector
batch_log_metrics = self.get_latest_batch_log_metrics()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 418,in get_latest_batch_log_metrics
batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 414,in run_batch_from_func_name
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 414,in <listcomp>
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 122,in get_batch_log_metrics
return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics",*args,**kwargs)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 115,in run_latest_batch_metrics_with_func_name
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",in <listcomp>
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py",line 100,in get_latest_from_func_name
results.update(func(*args,add_dataloader_idx=add_dataloader_idx,**kwargs))
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py",line 298,in get_batch_log_metrics
result[dl_key] = self[k]._forward_cache.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
它确实在训练前通过了健全性验证检查。
失败发生在 validation_step_end 中的 return 处,这让我很困惑。
完全相同的 metrics 用法在 Accuracy 指标上可以正常工作。
如何得到正确的混淆矩阵?
解决方法
您可以使用 self.logger.experiment.add_figure(*tag*, *figure*) 将该图(figure)记录到 TensorBoard。
变量 self.logger.experiment
实际上是一个 SummaryWriter
(来自 PyTorch,而不是 Lightning)。此类具有方法 add_figure
(documentation)。
您可以按如下方式使用它:(MNIST 示例)
def validation_step(self, batch, batch_idx):
    """Score one validation batch and keep its predictions for epoch-end plotting.

    Args:
        batch: ``(inputs, labels)`` pair.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        dict with the NLL ``"loss"``, the model output ``"preds"`` and the
        ground-truth ``"target"``.
    """
    inputs, labels = batch
    log_probs = self(inputs)
    return {
        "loss": F.nll_loss(log_probs, labels),
        "preds": log_probs,
        "target": labels,
    }
def validation_epoch_end(self, outputs):
    """Plot the epoch's validation confusion matrix and log it to TensorBoard.

    Args:
        outputs: list of the dicts returned by ``validation_step``; each holds
            ``"preds"`` (model output, one column per class, as required by
            ``F.nll_loss``) and ``"target"`` (class indices).
    """
    if not outputs:
        # torch.cat([]) raises; there is nothing to plot for an empty epoch.
        return
    preds = torch.cat([out["preds"] for out in outputs])
    targets = torch.cat([out["target"] for out in outputs])
    # Take the class count from the prediction width instead of hard-coding 10.
    num_classes = preds.shape[1]
    confusion_matrix = pl.metrics.functional.confusion_matrix(preds, targets, num_classes=num_classes)
    # .cpu() is required before .numpy() when validating on a GPU. Counts are
    # integral, so cast them and annotate with "d" rather than the default ".2g".
    counts = confusion_matrix.detach().cpu().numpy().astype(int)
    df_cm = pd.DataFrame(counts, index=range(num_classes), columns=range(num_classes))
    plt.figure(figsize=(10, 7))
    fig_ = sns.heatmap(df_cm, annot=True, fmt="d", cmap="Spectral").get_figure()
    # Detach the figure from pyplot's manager so figures do not pile up per epoch.
    plt.close(fig_)
    self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
这花了我很多时间才找到。
这是我能粘贴的最少的代码,它仍然可读和可重现。
我不想把整个模型数据集和参数放在这里,因为这个问题的读者对它们没有兴趣,只是噪音。
也就是说,这里是创建每个时期的混淆矩阵并在 Tensorboard 中显示所需的代码
这是一个单一的框架,例如:
# Outlook "MAPI" namespace object used to open saved items.
# NOTE(review): win32com.client is not imported in the visible code.
outlook = win32com.client.Dispatch("Outlook.Application").GetNamespace("MAPI")


def getMailBody(msgFile):
    """Open the Outlook item saved at ``msgFile`` and return its ``Body`` attribute."""
    return outlook.OpenSharedItem(msgFile).Body
以及训练器(Trainer)的调用代码:
import pytorch_lightning as pl
import seaborn as sn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
def __init__(self, config, trained_vae, latent_dim):
    """Store the configuration and create the validation confusion-matrix metric.

    NOTE(review): excerpt; how ``trained_vae`` and ``latent_dim`` are used is
    not shown in the post.

    Args:
        config: experiment configuration; ``n_clusters`` is read here and
            ``dataset`` in ``validation_step``.
        trained_vae: not used in the visible code.
        latent_dim: not used in the visible code.
    """
    # nn.Module rejects submodule assignment (the metric is a Module) until
    # Module.__init__ has run; assumes the class derives from LightningModule.
    super().__init__()
    # self._config is read below and by the other hooks, so it must be stored.
    self._config = config
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
    # self.logger is not assigned here: the Trainer attaches the logger, and
    # newer Lightning releases expose it as a read-only property.
def forward(self, x):
    """Return ``log_probs`` computed from the input batch ``x``.

    The computation is elided in the post. NOTE(review): the result is passed
    to ``self._criterion`` and ``self.val_confusion.update`` as predictions,
    so it is presumably per-class log-probabilities; confirm.
    """
    ...
    return log_probs
def validation_step(self, batch, batch_index):
    """Run one validation batch and accumulate predictions into the confusion matrix.

    Args:
        batch: ``(images, labels)`` pair from the validation dataloader.
        batch_index: index of ``batch`` within the epoch (unused).

    Returns:
        dict with the batch ``"loss"`` and the ground-truth ``"labels"``.

    Raises:
        ValueError: if ``self._config.dataset`` is not ``"mnist"``, the only
            input pipeline implemented here.
    """
    # Lightning calls validation_step(batch, batch_idx); ``batch`` must be a
    # parameter because the body unpacks it.
    if self._config.dataset != "mnist":
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"unsupported dataset: {self._config.dataset!r}")
    orig_batch, label_batch = batch
    # Flatten each 28x28 image into a 784-long vector for the model.
    orig_batch = orig_batch.reshape(-1, 28 * 28)
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
    self.val_confusion.update(log_probs, label_batch)
    return {"loss": loss, "labels": label_batch}
def validation_step_end(self, outputs):
    """Forward the dict produced by ``validation_step`` unchanged."""
    step_result = outputs
    return step_result
def validation_epoch_end(self, outs):
    """Render this epoch's validation confusion matrix and log it to TensorBoard.

    The matrix accumulated in ``self.val_confusion`` is drawn as an annotated
    seaborn heatmap and written to ``self.logger.experiment`` under the tag
    ``"val_confusion_matrix"`` at step ``self.current_epoch``.

    Args:
        outs: per-step outputs collected by Lightning (unused).
    """
    tb = self.logger.experiment
    n_clusters = self._config.n_clusters
    # np.int was removed in NumPy 1.24; the builtin int is the equivalent dtype.
    conf_mat = self.val_confusion.compute().detach().cpu().numpy().astype(int)
    # The metric keeps accumulating until reset; clear it so consecutive
    # epochs are not folded into one matrix.
    self.val_confusion.reset()
    df_cm = pd.DataFrame(conf_mat, index=np.arange(n_clusters), columns=np.arange(n_clusters))
    fig = plt.figure()
    sn.set(font_scale=1.2)
    # annot=True is what makes annot_kws / fmt take effect.
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt="d")
    # add_figure renders the figure and closes it (close=True by default), so
    # no figure leaks per epoch; it also replaces the JPEG/PIL/torchvision
    # round trip, whose io and torchvision modules were never imported.
    tb.add_figure("val_confusion_matrix", fig, global_step=self.current_epoch)