问题描述
我是 GMM 的新手,出于学习目的,我需要了解是否可以将通过 GMM 的聚类应用于具有缩减维度 (PCA) 的 CIFAR10 数据集。具体来说,我需要绘制数据集的密度分布。
这是我目前的代码。
# Load Dataset from disk
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
train_set,train_label = tfds.as_numpy(tfds.load(
'cifar10',split='train',batch_size=-1,as_supervised=True,))
test_set,test_label = tfds.as_numpy(tfds.load(
'cifar10',split='test',))
# normalize
norm_train = np.float32(train_set)/ 255.
norm_test = np.float32(test_set)/ 255.
n_images = train_set.shape[0]
indexes = np.random.randint(0,norm_train.shape[0],size=n_images)
images = norm_train[indexes]
# Center dataset for PCA
images = np.float32(images)
mu = np.mean(images)
images -= mu
images /= std
pca_input = np.reshape(images,(-1,n_images))
# Do PCA
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error
# Use the first 500 components
n = 500
pca_ = PCA(n_components=n)
principalComponents = pca_.fit_transform(pca_input)
recon = pca_.inverse_transform(principalComponents)
mse = mean_squared_error(pca_input,recon,squared=True)
print(f"MSE: {mse} with {n} components")
# Attempt on GMM
from sklearn.mixture import GaussianMixture as GMM
gmm_train = pca_.components_.T
n = 10 # Use 10 clusters representing the 10 classes
gmm_model = GMM(n_components=n,covariance_type='full')
gmm_input = pca_.components_.T # Reduced components of the dataset
gmm = gmm_model.fit(gmm_input)
我正在尝试将 this 应用到 CIFAR10 并查看使用 PCA-GMM 的聚类是否有效。
我试图构建一个网格,但出现尺寸不匹配的错误,在这种情况下是预期特征的数量。
下面的代码产生此错误:
ValueError: Expected the input data X have 500 features,but got 2 features
x = np.linspace(np.min(gmm_input),np.max(gmm_input),gmm_input.shape[0])
y = np.linspace(np.min(gmm_input),gmm_input.shape[0])
X,Y = np.meshgrid(x,y)
XX = np.array([X.ravel(),Y.ravel()]).T
Z = -gmm_model.score_samples(XX)
Z = Z.reshape(X.shape)
CS = plt.contour(X,Y,Z,norm=Lognorm(vmin=1.0,vmax=1000.0),levels=np.logspace(0,3,10))
CB = plt.colorbar(CS,shrink=0.8,extend='both')
虽然我知道我需要匹配特征的数量,但我不知道如何匹配。
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)