问题描述
我为我的 3 个高斯数据集的混合实现了期望 - 最大化算法,但没有使用“from sklearn.mixture import GaussianMixture”。我也发现了均值和协方差值,但是在收敛后尝试绘制集群分配(具有不同颜色)时失败了,我需要将绘图函数添加到我的 GMM 类中,但我无法实现这一点,当我遇到不同的错误时尝试不同的方法,有人可以帮我吗?在我的模型收敛后,我需要使用 mean 、 sigma 、 phi 值来绘制集群分配...
这是我的均值、sigma 和 phi(又名高斯大小)结果:
gmm.mean:
matrix([[ 6.64610246,6.44711899],[-0.14003791,0.56576422],[ 1.6167569,0.65034397]])
gmm.sigma:
array([[[10.86679748,7.45436406],[ 7.45436406,7.89496679]],[[ 1.73652337,-0.04891332],[-0.04891332,0.55221408]],[[ 0.92895187,-0.32794522],[-0.32794522,0.56913115]]])
gmm.phi:
array([0.70564128,0.15347464,0.14088408])
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline
class GMM(object):
def __init__(self,X,k=3):
# dimension
X = np.asarray(X)
self.m,self.n = X.shape
self.data = X.copy()
# number of mixtures
self.k = k
def _init(self):
# init mixture means/sigmas
self.mean_arr = np.asmatrix(np.random.random((self.k,self.n)))
self.sigma_arr = np.array([np.asmatrix(np.identity(self.n)) for i in range(self.k)])
self.phi = np.ones(self.k)/self.k
self.w = np.asmatrix(np.empty((self.m,self.k),dtype=float))
#print(self.mean_arr)
#print(self.sigma_arr)
def fit(self,tol=1e-4):
self._init()
num_iters = 0
ll = 1
prevIoUs_ll = 0
while(ll-prevIoUs_ll > tol):
prevIoUs_ll = self.loglikelihood()
self._fit()
num_iters += 1
ll = self.loglikelihood()
print('Iteration %d: log-likelihood is %.6f'%(num_iters,ll))
print('Terminate at %d-th iteration:log-likelihood is %.6f'%(num_iters,ll))
def loglikelihood(self):
ll = 0
for i in range(self.m):
tmp = 0
for j in range(self.k):
#print(self.sigma_arr[j])
tmp += sp.stats.multivariate_normal.pdf(self.data[i,:],self.mean_arr[j,:].A1,self.sigma_arr[j,:]) *\
self.phi[j]
ll += np.log(tmp)
return ll
def _fit(self):
self.e_step()
self.m_step()
def e_step(self):
# calculate w_j^{(i)}
for i in range(self.m):
den = 0
for j in range(self.k):
num = sp.stats.multivariate_normal.pdf(self.data[i,self.mean_arr[j].A1,self.sigma_arr[j]) *\
self.phi[j]
den += num
self.w[i,j] = num
self.w[i,:] /= den
assert self.w[i,:].sum() - 1 < 1e-4
def m_step(self):
for j in range(self.k):
const = self.w[:,j].sum()
self.phi[j] = 1/self.m * const
_mu_j = np.zeros(self.n)
_sigma_j = np.zeros((self.n,self.n))
for i in range(self.m):
_mu_j += (self.data[i,:] * self.w[i,j])
_sigma_j += self.w[i,j] * ((self.data[i,:] - self.mean_arr[j,:]).T * (self.data[i,:]))
#print((self.data[i,:]))
self.mean_arr[j] = _mu_j / const
self.sigma_arr[j] = _sigma_j / const
#print(self.sigma_arr)
然后您可以运行以下命令以获取结果:
X = np.load('dataset.npy')
X.shape
gmm = GMM(X)
gmm.fit()
gmm.mean_arr
gmm.sigma_arr
gmm.phi
我的 X 是:
[[ 4.26555433e+00 3.21024579e+00]
[ 8.81761228e+00 9.86718640e+00]
[ 4.41543618e+00 4.22265208e+00]
[-3.68363931e-01 5.13191991e-01]
[ 8.83806033e+00 9.49401994e+00]
[ 7.67936393e+00 9.66004924e+00]
[ 9.60085462e-01 2.52018203e+00]
[ 9.48979390e+00 8.34650617e+00]
[-1.98198674e+00 3.78513885e-01]
[-1.92968889e+00 3.52613969e-02]
[ 4.41356495e+00 3.77543273e+00]
[-1.21022618e+00 9.26792166e-01]
[ 5.72291741e+00 4.34389425e+00]
[ 6.49908168e+00 5.05117749e+00]
[ 8.81199730e+00 8.69221815e+00]
[-1.04290858e+00 1.11563874e+00]
[ 4.76292352e+00 3.81392694e+00]
[ 1.09917009e+01 8.20098503e+00]
[-5.41372600e-01 2.53994061e-01]
[ 7.51685369e+00 1.04303631e+01]
[ 5.11096931e-01 -1.11722204e-01]
[ 1.09768856e+01 9.40961836e+00]
[ 3.80754300e+00 4.24268539e+00]
[ 1.14809672e+00 1.96133604e-01]
[ 9.13365646e-01 1.07153243e+00]
[ 1.05310144e+01 9.36354718e+00]
[ 5.68329933e+00 4.31274985e+00]
[-4.37161502e-01 7.71881867e-01]
[ 3.17248172e+00 4.81192547e+00]
[ 6.66097921e+00 9.58794805e+00]
[ 3.99195683e+00 4.06491044e+00]
[ 6.22531929e+00 5.15769022e+00]
[ 1.96320967e+00 -3.08888136e-02]
[ 4.17966391e-01 1.42152696e+00]
[ 2.19619015e+00 6.80947066e-01]
[ 1.09676683e+01 8.82831035e+00]
[ 6.31064883e+00 4.89614185e+00]
[ 4.98529659e+00 3.40819359e+00]
[ 1.86284412e-01 5.22104703e-01]
[ 9.90470279e+00 8.69623216e+00]
[ 2.83957912e+00 -6.12682946e-03]
[ 8.71491284e-02 6.06743766e-01]
[ 1.04342680e+00 1.42283140e-01]
[-3.74280561e+00 1.54204389e-01]
[ 1.14760051e+00 3.76256123e-01]
[ 8.42656755e+00 9.82636153e+00]
[-7.54666096e-01 -1.15765296e+00]
[ 1.10159056e+01 8.60101641e+00]
[ 2.64179923e+00 4.34958382e+00]
[ 9.24632799e+00 1.04522140e+01]
[ 6.81220311e+00 2.57877922e+00]
[ 1.09396362e+01 8.46709485e+00]
[ 1.25448516e+01 7.48599827e+00]
[ 1.08263103e+01 7.55723936e+00]
[ 1.20008410e+01 7.58448909e+00]
[ 4.13533655e+00 5.37192873e+00]
[ 2.54984916e+00 6.49562290e-01]
[ 4.06723307e+00 5.13787115e+00]
[ 2.32700836e+00 8.03748523e-01]
[ 7.57363297e+00 1.11520151e+01]
[-4.00180965e-01 1.64078380e+00]
[ 6.04436025e+00 9.56398642e+00]
[ 8.35795466e+00 9.83648542e+00]
[-1.79871724e+00 1.83286134e+00]
[-9.09222281e-01 1.45233791e+00]
[ 2.38665530e+00 1.25116989e+00]
[ 8.03552258e+00 9.52642714e+00]
[ 1.79232243e+00 4.11493465e-01]
[ 2.54783339e+00 4.11223663e+00]
[ 8.79163087e+00 1.00479530e+01]
[ 4.18734899e+00 5.56972080e+00]
[ 9.52914968e+00 8.69759961e+00]
[ 1.01973587e+01 9.52744569e+00]
[ 8.17840429e+00 1.00499742e+01]
[ 1.21625095e+01 9.12341710e+00]
[ 8.64879859e+00 9.22217024e+00]
[ 6.08532468e+00 5.49757049e+00]
[ 6.21846361e+00 4.41768001e+00]
[-1.36275580e+00 1.18683462e+00]
[ 3.74812444e+00 5.54985857e+00]
[ 1.18918369e+00 -5.64789491e-01]
[-2.85826034e-01 6.00118823e-01]
[ 9.46511887e+00 1.00844761e+01]
[ 1.83493301e+00 1.83997442e-01]
[ 8.73175554e+00 8.97661000e+00]
[ 2.85593323e+00 5.29043685e+00]
[ 7.80986531e+00 3.48094883e+00]
[ 1.58212581e+00 6.10855579e-01]
[ 1.04212728e+01 8.48889791e+00]
[ 3.81704323e+00 4.71847967e+00]
[ 2.03502226e+00 -9.42890059e-03]
[-4.17499326e-01 6.19846114e-01]
[ 1.70936199e+00 5.78953056e-01]
[ 1.42521995e+00 -1.85981029e-01]
[ 9.88971787e+00 8.03016106e+00]
[ 1.59457226e+00 -6.26462391e-03]
[ 6.09829665e+00 4.00624288e+00]
[ 8.76239511e+00 4.77053822e+00]
[ 6.70449518e+00 1.05150775e+01]
[ 5.66536390e-01 4.18474877e+00]
[ 3.15777519e+00 6.23124274e-01]
[ 1.03724661e+01 7.52777602e+00]
[-7.74913392e-01 8.28572283e-01]
[ 5.48793852e-01 1.36953570e+00]
[ 8.33377016e+00 8.15383787e+00]
[ 1.08182133e+01 9.57596239e+00]
[ 4.89879512e-01 1.45321617e+00]
[ 1.50720393e-01 8.75142681e-01]
[ 3.99009796e+00 5.28686806e+00]
[-7.54695219e-01 9.28244657e-01]
[ 3.22281101e+00 3.80406773e+00]
[ 9.63994240e+00 9.26402307e+00]
[ 5.74277199e-01 8.45388567e-01]
[ 2.53768281e+00 6.69658731e-01]
[ 2.84912409e+00 2.95825492e+00]
[ 3.68112677e+00 5.42654100e+00]
[-8.54105140e-01 1.45307947e-01]
[ 2.08143788e-01 5.29727332e-01]
[ 1.08175848e+01 9.19557481e+00]
[ 1.23837091e+01 9.24367756e+00]
[ 7.70940947e+00 1.00174566e+01]
[ 1.20073878e+01 8.28556817e+00]
[ 1.02865158e+00 4.54454040e+00]
[ 3.23484087e+00 4.14596179e+00]
[ 1.71501609e+00 -6.23605664e-01]
[ 4.02304019e+00 5.44749368e+00]
[ 1.95696398e-01 1.93745828e+00]
[ 5.06935138e+00 3.98881882e+00]
[ 1.01292235e+01 9.39419820e+00]
[-2.06650385e+00 -6.24012538e-01]
[ 2.96366558e+00 4.79472501e+00]
[ 2.15822240e+00 6.36036913e-01]
[ 1.17625031e+00 1.58130199e+00]
[-1.15373789e-01 -6.45838458e-01]
[ 8.63144914e-01 1.79461536e-01]
[ 1.05752670e+01 9.78830615e+00]
[ 1.14179845e+01 1.00002874e+01]
[-4.07395396e-01 1.84782000e+00]
[ 1.33936146e-01 7.74535193e-01]
[ 8.85461386e-01 3.08315409e-01]
[ 7.07608487e+00 1.03087883e+01]
[ 5.44575573e+00 3.64482067e+00]
[ 6.86595421e-01 -6.62564748e-01]
[ 3.81348961e+00 3.84796905e+00]
[ 1.06206602e+01 8.23485574e+00]
[ 2.59351909e+00 3.71681774e+00]
[ 1.08223650e+01 8.64833422e+00]
[ 1.42056415e+00 1.57964939e+00]
[ 3.58386017e+00 4.65330522e+00]
[ 9.92007963e+00 9.07292954e+00]
[ 3.45296848e+00 3.97780252e+00]
[ 2.17594024e+00 1.75095942e-01]
[ 4.78701452e+00 2.81095239e+00]
[ 5.77216897e+00 3.26754694e+00]
[ 1.02803103e+01 9.83331246e+00]
[-8.13893383e-02 2.34455257e+00]
[ 5.68531037e+00 3.41488624e+00]
[ 1.54680319e+00 4.96658225e+00]
[ 1.03612530e+01 8.37139059e+00]
[ 7.10560773e-01 2.21029718e+00]
[ 1.03590332e+01 8.39057628e+00]
[ 5.18982402e-01 1.19967651e+00]
[ 1.06176324e+01 8.06513920e+00]
[-1.20184236e-01 6.55499782e-01]
[ 4.34286141e+00 3.41298363e+00]
[ 5.22644879e+00 4.09507681e+00]
[ 2.61087183e+00 5.94263424e+00]
[ 1.32877048e+00 -1.75509467e-01]
[-1.06464461e+00 3.93116823e-01]
[ 4.46520373e+00 4.92162509e+00]
[ 4.26688063e+00 3.82543087e+00]
[ 3.60059080e+00 4.79495178e+00]
[ 2.19491471e+00 -3.27189229e-01]
[ 9.66680087e+00 8.21816238e+00]
[ 6.73590028e+00 4.45599688e+00]
[ 8.65623285e+00 9.68635473e+00]
[ 1.19035424e+01 8.90803229e+00]
[ 4.85763264e+00 5.71543969e+00]
[ 8.48589895e+00 9.89286260e+00]
[ 7.53815671e+00 1.15632634e+01]
[ 8.49158632e-01 9.57054571e-01]
[ 4.65752544e+00 3.51654423e+00]
[ 7.48995404e+00 9.30216335e+00]
[ 7.28081736e+00 4.78077929e+00]
[ 3.67578666e+00 4.15162963e+00]
[ 8.08577389e+00 8.94865350e+00]
[ 5.44218600e+00 4.27403957e+00]
[ 7.92727577e+00 9.08631366e+00]
[ 9.01850012e+00 9.31847799e+00]
[-7.02575092e-01 1.34242943e+00]
[ 6.25001597e+00 2.64143613e+00]
[ 7.40480317e+00 9.07743497e+00]
[ 2.30332490e+00 9.77438340e-01]
[ 9.36878048e+00 7.77607427e+00]
[ 4.22227284e+00 3.34618963e+00]
[ 3.46464886e+00 5.41950642e+00]
[ 3.33752146e+00 4.73535724e+00]
[ 1.20781115e+01 9.35190748e+00]
[-1.48034500e+00 -4.04821577e-01]
[ 7.67874458e+00 1.00482565e+01]
[ 2.52641776e-01 -1.15089951e+00]
[ 3.89248354e+00 4.77677886e+00]
[ 4.01307566e+00 -1.34161544e+00]
[ 3.06765747e+00 5.44929204e+00]
[ 3.71783013e+00 6.48753375e+00]
[ 9.70309891e+00 9.88674067e+00]
[ 1.20270515e+01 7.11825252e+00]
[ 5.55782927e+00 4.38940297e+00]
[ 1.56553858e+00 2.32981697e+00]
[ 4.47772890e+00 3.92569671e+00]
[ 8.05420107e+00 1.09281576e+01]
[ 7.98557520e+00 9.36668519e+00]
[ 1.49221877e+00 1.08003354e+00]
[ 4.48616166e+00 4.24168209e+00]
[ 7.10230667e+00 4.97862647e+00]
[ 1.00816523e+01 8.51087431e+00]
[-2.56573053e+00 6.17517739e-01]
[ 9.15823921e+00 9.52582924e+00]
[ 2.23548349e+00 4.56227172e+00]
[ 7.69918419e-01 1.78551125e+00]
[ 5.93181782e+00 4.71020734e+00]
[ 1.49423458e+00 1.19730995e+00]
[ 1.10829390e+01 9.83702720e+00]
[ 2.52760659e+00 3.29250129e+00]
[ 1.69752160e+00 3.58768779e+00]
[ 5.04026536e+00 3.84931324e+00]
[ 2.58270988e+00 -1.82650586e-01]
[ 7.72019245e+00 3.72913467e+00]
[ 1.04131732e+01 8.21158198e+00]
[ 6.22962897e+00 4.67524150e+00]
[ 5.23640606e+00 3.89967606e+00]
[ 1.13162453e+00 5.63923881e+00]
[ 5.19394450e+00 3.98823827e+00]
[ 2.45884435e+00 4.64253057e+00]
[ 1.56661083e+00 -6.03363169e-01]
[ 9.19558228e+00 1.02944795e+01]
[ 9.40178766e+00 8.54665019e+00]
[ 9.73842218e+00 1.15861484e+01]
[ 1.95433400e+00 1.31804940e+00]
[ 8.32855554e+00 8.97412497e+00]
[ 3.38108205e+00 2.90282131e-01]
[ 1.96327073e+00 4.73505828e+00]
[ 9.83885025e+00 9.11524058e+00]
[ 8.50117957e+00 9.51565194e+00]
[ 1.59361419e+00 8.97407052e-01]
[ 5.12614219e+00 5.52323652e+00]
[ 8.72208118e+00 8.14683151e+00]
[ 5.49741545e+00 4.38845158e+00]
[ 2.74603427e+00 1.26234557e+00]
[ 5.33927262e-01 8.86132685e-01]
[ 1.08331028e+01 8.66435121e+00]
[ 2.94623051e+00 5.46887909e+00]
[ 1.01488622e+01 9.18532193e+00]
[ 3.79393001e+00 4.89652927e+00]
[ 1.20836818e+00 1.73361919e+00]
[ 1.03225652e+01 7.75857553e+00]
[ 1.01308018e+01 7.38193342e+00]
[ 2.45959321e+00 3.98485827e-01]
[ 9.56628275e-01 1.94404474e+00]
[ 7.23829377e+00 4.74308915e+00]
[ 8.04754974e+00 4.24559518e+00]
[ 1.01860285e+01 8.42255662e+00]
[ 1.17184895e+00 7.19628765e-01]
[ 5.10741707e+00 3.38787566e+00]
[ 3.72699169e+00 4.58681604e+00]
[-2.22223809e-01 6.64162775e-01]
[ 8.33334024e-01 9.71853331e-01]
[ 8.90887591e-01 -8.21656127e-01]
[ 2.54971650e+00 -5.70526622e-02]
[ 4.17801377e+00 4.26268188e+00]
[ 3.00649443e+00 3.81460187e+00]
[ 1.12914483e+01 1.04322460e+01]
[-1.82955327e+00 1.06414016e+00]
[ 1.56756297e+00 3.91615304e+00]
[ 8.81617446e+00 1.02933993e+01]
[ 9.51666723e+00 9.11983616e+00]
[ 2.40125500e+00 3.34083275e+00]
[ 1.16635008e+00 1.88611393e+00]
[ 3.37421839e+00 3.67412357e+00]
[ 6.09528330e-01 7.76040583e-01]
[ 4.48331616e+00 4.06025425e+00]
[ 4.32460655e+00 5.46087064e+00]
[ 1.01694561e+01 8.70499109e+00]
[ 3.53764929e+00 3.31280405e+00]
[ 1.11291257e+01 8.70832459e+00]
[ 2.25597312e+00 3.17495374e+00]
[ 8.99776788e+00 9.52188265e+00]
[ 4.49736361e+00 4.38241658e+00]
[ 9.70869172e+00 8.40789673e+00]
[ 7.62254366e+00 4.07997970e+00]
[ 8.57681210e+00 1.06543292e+01]
[ 8.81789222e+00 9.24772320e+00]
[ 2.48914947e+00 1.44552220e+00]
[ 9.21039763e+00 9.87400407e+00]
[-1.91720742e+00 1.06916487e+00]
[ 1.19076617e+01 8.11919161e+00]
[ 1.00361711e+01 8.30256107e+00]
[ 8.63406664e+00 9.55195351e+00]
[ 1.13661260e+01 8.97654040e+00]
[ 3.14093184e+00 6.29076677e-01]]
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)