问题描述
我创建了一个散点图,类似于可变自动编码器(VAE)的 mnist 散点图。
图的生成方式如下(z_mean的维数为(?,2)
):
plt.scatter(z_mean[:,0],z_mean[:,1],c=y_test)
现在,我想绘制属于同一类的所有图的均值。 例如,如果我绘制了10个类别的10000个测试样本,我想绘制所有属于同一类别的图的平均值。因此,只需绘制10个点,每个点表示一个类(从0到9)即可。
解决方法
您可以计算每个y值的均值并将其存储到数组中。
以下代码首先创建一些随机测试数据,显示具有一定透明度的散点图,计算每组的均值,然后使用大点显示这些均值。它还使用“ tab10”作为颜色图,其颜色比默认的“ viridis”更鲜明。
import numpy as np
import matplotlib.pyplot as plt
N = 100
M = 10
z_mean = np.random.normal(np.tile(np.random.uniform(1,10,2 * M),N)).reshape(-1,2)
y_test = np.tile(np.arange(M),N)
plt.scatter(z_mean[:,0],z_mean[:,1],c=y_test,cmap='tab10',alpha=.3)
ys = np.unique(y_test)
means = np.array([np.mean(z_mean[y_test == y,:],axis=0) for y in ys])
plt.scatter(means[:,means[:,c=ys,s=200,edgecolors='black',cmap='tab10')
plt.colorbar()
plt.show()