绘制属于同一类别的散点的平均值

问题描述

我创建了一个散点图,类似于可变自动编码器(VAE) mnist 散点图。

MNIST Mean plot for VAE

图的生成方式如下(z_mean的维数为(?,2)):

plt.scatter(z_mean[:,0],z_mean[:,1],c=y_test)

现在,我想绘制属于同一类的所有图的均值。 例如,如果我绘制了10个类别的10000个测试样本,我想绘制所有属于同一类别的图的平均值。因此,只需绘制10个点,每个点表示一个类(从0到9)即可。

解决方法

您可以计算每个y值的均值并将其存储到数组中。

以下代码首先创建一些随机测试数据,显示具有一定透明度的散点图,计算每组的均值,然后使用大点显示这些均值。它还使用“ tab10”作为颜色图,其颜色比默认的“ viridis”更鲜明。

import numpy as np
import matplotlib.pyplot as plt

N = 100
M = 10
z_mean = np.random.normal(np.tile(np.random.uniform(1,10,2 * M),N)).reshape(-1,2)
y_test = np.tile(np.arange(M),N)

plt.scatter(z_mean[:,0],z_mean[:,1],c=y_test,cmap='tab10',alpha=.3)

ys = np.unique(y_test)
means = np.array([np.mean(z_mean[y_test == y,:],axis=0) for y in ys])
plt.scatter(means[:,means[:,c=ys,s=200,edgecolors='black',cmap='tab10')
plt.colorbar()
plt.show()

example plot