Seaborn Confusion Matrix热图2种配色方案正确的对角线与错误的其余部分

问题描述

背景

在混淆矩阵中,对角线表示预测标签与正确标签匹配的情况。因此对角线是好的,而其他所有单元格都是坏的。为了弄清楚对于非专家而言,CM的优点和缺点,我想为对角线提供与其余颜色不同的颜色。我想通过 Python&Seaborn 实现此目标。

基本上,我正在尝试实现此问题在R(ggplot2 Heatmap 2 Different Color Schemes - Confusion Matrix: Matches in Different Color Scheme than Missclassifications)中的作用

带有热图的普通Seaborn混淆矩阵

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50,2,38],[7,43,32],[9,4,76]])

sns.heatmap(cf_matrix,annot=True,cmap='Blues')  # cmap='OrRd'

这张图片的结果是

Seaborn Confusion Matrix with colormap 'Blues'

目标

我想用例如cmap='OrRd'。所以我想象会有2个颜色条,对角线1个蓝色,其他单元格1个。优选地,两个颜色条的值都匹配(因此例如0-70而不是0-70和0-40)。 我将如何处理?

以下内容不是使用代码,而是使用照片编辑软件进行的:

Desired Confusion Matrix color scheme

解决方法

您可以在对mask=的调用中使用heatmap()选择要显示的单元格。对角和off_diagonal单元使用两个不同的蒙版,可以获得所需的输出:

import numpy as np
import seaborn as sns

cf_matrix = np.array([[50,2,38],[7,43,32],[9,4,76]])

vmin = np.min(cf_matrix)
vmax = np.max(cf_matrix)
off_diag_mask = np.eye(*cf_matrix.shape,dtype=bool)

fig = plt.figure()
sns.heatmap(cf_matrix,annot=True,mask=~off_diag_mask,cmap='Blues',vmin=vmin,vmax=vmax)
sns.heatmap(cf_matrix,mask=off_diag_mask,cmap='OrRd',vmax=vmax,cbar_kws=dict(ticks=[]))

enter image description here

如果想花哨的话,可以使用GridSpec创建轴以具有更好的布局:

将numpy导入为np 将seaborn导入为sns

fig = plt.figure()
gs0 = matplotlib.gridspec.GridSpec(1,width_ratios=[20,2],hspace=0.05)
gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(1,subplot_spec=gs0[1],hspace=0)

ax = fig.add_subplot(gs0[0])
cax1 = fig.add_subplot(gs00[0])
cax2 = fig.add_subplot(gs00[1])

sns.heatmap(cf_matrix,ax=ax,cbar_ax=cax2)
sns.heatmap(cf_matrix,cbar_ax=cax1,cbar_kws=dict(ticks=[]))

enter image description here

,

您可以先用颜色图“ OrRd”绘制热图,然后用颜色图“ Blues”将热图覆盖,并用NaN代替上下三角值,请参见以下示例:

def diagonal_heatmap(m):

    vmin = np.min(m)
    vmax = np.max(m)    
    
    sns.heatmap(cf_matrix,vmax=vmax)

    diag_nan = np.full_like(m,np.nan,dtype=float)
    np.fill_diagonal(diag_nan,np.diag(m))
    
    sns.heatmap(diag_nan,cbar_kws={'ticks':[]}) 




cf_matrix = np.array([[50,76]])

diagonal_heatmap(cf_matrix)