问题描述
存在27个类别的多分类问题。
y_predict=[0 0 0 20 26 21 21 26 ....]
y_true=[1 10 10 20 26 21 18 26 ...]
名为“ answer_vocabulary”的列表为每个索引存储了相应的27个单词。 answer_vocabulary = [0 1 10 11 2 3农业商业东部,居住在北部.....]
cm = confusion_matrix(y_true = y_true,y_pred = y_predict)
我对混淆矩阵的顺序感到困惑。它是按升序排列的吗?而且,如果我想用标签序列= [0 1 2 3 10 11居住在东北部的农业商品...]对混乱矩阵进行重新排序,该如何实现?
这是我尝试绘制混淆矩阵的函数。
def plot_confusion_matrix(cm,classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
normalization can be applied by setting `normalize=True`.
"""
plt.imshow(cm,interpolation='nearest',cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks,rotation=45)
plt.yticks(tick_marks,classes)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:,np.newaxis]
print("normalized confusion matrix")
else:
print('Confusion matrix,without normalization')
print(cm)
thresh = cm.max() / 2.
for i,j in itertools.product(range(cm.shape[0]),range(cm.shape[1])):
plt.text(j,i,cm[i,j],horizontalalignment="center",color="white" if cm[i,j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
解决方法
sklearn的混淆矩阵不存储有关如何创建矩阵的信息(类排序和规范化):这意味着创建后必须立即使用混淆矩阵,否则信息将丢失。
默认情况下,sklearn.metrics.confusion_matrix(y_true,y_pred)按照类在y_true中出现的顺序创建矩阵。
如果您将此数据传递给 sklearn.metrix.confusion_matrix :
discord.ext.commands.errors.CommandInvokeError: Command raised an exception: DownloadError: ERROR: 'w' is not a valid URL. Set --default-search "ytsearch" (or run youtube-dl "ytsearch:w" ) to search YouTube
Scikit-leart将创建此混淆矩阵(省略零):
+--------+--------+
| y_true | y_pred |
+--------+--------+
| A | B |
| C | C |
| D | B |
| B | A |
+--------+--------+
它将返回此numpy矩阵给您:
+-----------+---+---+---+---+
| true\pred | A | C | D | B |
+-----------+---+---+---+---+
| A | | | | 1 |
| C | | 1 | | |
| D | | | | 1 |
| B | 1 | | | |
+-----------+---+---+---+---+
如果您要选择类或对其重新排序,则可以将'labels'参数传递给+---+---+---+---+
| 0 | 0 | 0 | 1 |
| 0 | 0 | 1 | 0 |
| 0 | 0 | 0 | 1 |
| 1 | 0 | 0 | 0 |
+---+---+---+---+
。
要重新排序:
confusion_matrix()
或者,如果您只想关注某些标签(如果您有很多标签,则很有用):
labels = ['D','C','B','A']
mat = confusion_matrix(true_y,pred_y,labels=labels)
另外,请查看sklearn.metrics.plot_confusion_matrix。对于小(
如果您有> 100个类,则将用白色绘制矩阵。
,生成的混淆矩阵中列/行的顺序与 sklearn.utils.unique_labels() 返回的相同,它提取“唯一标签的有序数组”。在confusion_matrix()
(main,git-hash 7e197fd)的source code中,感兴趣的行如下
if labels is None:
labels = unique_labels(y_true,y_pred)
else:
labels = np.asarray(labels)
此处,labels
是 confusion_matrix()
的可选参数,用于自行规定标签的排序/子集:
cm = confusion_matrix(true_y,labels=labels)
因此,如果 labels = [0,10,3]
,cm 将具有形状 (3,3),并且可以直接使用 labels
索引行/列。如果您了解熊猫:
import pandas as pd
cm = pd.DataFrame(cm,index=labels,columns=labels)
请注意,unique_labels()
的文档声明不支持混合类型的标签(数字和字符串)。在这种情况下,我建议使用 LabelEncoder
。这将使您无需维护自己的查找表。
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
y = encoder.fit_transform(y)
# y have now values between 0 and n_labels-1.
# Do some ops here...
...
# To convert back:
y_pred = encoder.inverse_transform(y_pred)
y = encoder.inverse_transform(y)
正如前面提到的 previous answer,plot_confusion_matrix()
可以方便地可视化混淆矩阵。