问题描述
以下代码拟合具有多项式核的 SVM,并绘制鸢尾花(iris)数据和决策边界。输入 X 使用数据的前 2 列,即萼片长度和宽度。但是,当我把 X 换成第 3 列和第 4 列(即花瓣的长度和宽度)时,却无法再得到同样的输出图。如何更改代码的绘图函数以使其正常工作?提前致谢。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
iris= datasets.load_iris()
y= iris.target
#X= iris.data[:,:2] # sepal length and width
X= iris.data[:,2:] # I tried a different X but it Failed.
# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Plot the fitted classifier's decision regions with the data on top.

    Reads the module-level ``X`` (two feature columns), ``y`` (labels) and
    ``svm_mod`` (fitted model).

    Args:
        title: Text used as the plot title.
    """
    # create a mesh to plot in
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # NOTE: kept exactly as posted -- this quotient is the grid step the
    # question is about; it is not finite when x_min is 0.
    h = (x_max / x_min)/100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Classify every grid point, then fold the labels back into grid shape.
    Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.xlim(xx.min(), xx.max())
    plt.title(title)
    plt.show()
svm= SVC(C= 10,kernel='poly',degree=2,coef0=1,max_iter=500000)
svm_mod= svm.fit(X,y)
plotSVC('kernel='+ str('polynomial'))
错误:
import sys
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-57-4515f111e34d> in <module>()
2 svm= SVC(C= 10,kernel='poly',degree=2,coef0=1,max_iter=500000)
3 svm_mod= svm.fit(X,y)
----> 4 plotSVC('kernel='+ str('polynomial'))
<ipython-input-56-556d4a22026a> in plotSVC(title)
10 Z = svm_mod.predict(np.c_[xx.ravel(),yy.ravel()])
11 Z = Z.reshape(xx.shape)
---> 12 plt.contourf(xx,yy,Z,cmap=plt.cm.Paired,alpha=0.8)
13 plt.scatter(X[:,0],X[:,1],c=y,cmap=plt.cm.Paired)
14 plt.xlabel('Sepal length')
~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in contourf(*args,**kwargs)
2931 mplDeprecation)
2932 try:
-> 2933 ret = ax.contourf(*args,**kwargs)
2934 finally:
2935 ax._hold = washold
~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax,*args,**kwargs)
1853 "the Matplotlib list!)" % (label_namer,func.__name__),1854 RuntimeWarning,stacklevel=2)
-> 1855 return func(ax,**kwargs)
1856
1857 inner.__doc__ = _add_data_doc(inner.__doc__,~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in contourf(self,**kwargs)
6179 self.cla()
6180 kwargs['filled'] = True
-> 6181 contours = mcontour.QuadContourSet(self,**kwargs)
6182 self.autoscale_view()
6183 return contours
~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in __init__(self,ax,**kwargs)
844 self._transform = kwargs.pop('transform',None)
845
--> 846 kwargs = self._process_args(*args,**kwargs)
847 self._process_levels()
848
~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _process_args(self,**kwargs)
1414 self._corner_mask = mpl.rcParams['contour.corner_mask']
1415
-> 1416 x,y,z = self._contour_args(args,kwargs)
1417
1418 _mask = ma.getmask(z)
~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _contour_args(self,args,kwargs)
1472 args = args[1:]
1473 elif Nargs <= 4:
-> 1474 x,z = self._check_xyz(args[:3],kwargs)
1475 args = args[3:]
1476 else:
~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _check_xyz(self,kwargs)
1508 raise TypeError("Input z must be a 2D array.")
1509 elif z.shape[0] < 2 or z.shape[1] < 2:
-> 1510 raise TypeError("Input z must be at least a 2x2 array.")
1511 else:
1512 Ny,Nx = z.shape
TypeError: Input z must be at least a 2x2 array.
解决方法
工作代码
(完整的可运行代码见下文“原因:”之后的代码清单。)
出:
原因:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC
iris= datasets.load_iris()
y= iris.target
#X= iris.data[:,:2] # sepal length and width
X= iris.data[:,2:] # I tried a different X but it failed.
# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Plot the fitted classifier's decision regions with the data on top.

    Reads the module-level ``X`` (two feature columns), ``y`` (labels) and
    ``svm_mod`` (fitted model).

    Args:
        title: Text used as the plot title.
    """
    # create a mesh to plot in
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    # Step is 1/100 of the x-range. Use the difference, not the quotient:
    # x_max / x_min is not finite when x_min is 0.
    h = (x_max - x_min) / 100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # Classify every grid point, then fold the labels back into grid shape.
    z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    z = z.reshape(xx.shape)
    plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    # X is built from iris.data[:, 2:], i.e. the petal columns.
    plt.xlabel('Petal length')
    plt.ylabel('Petal width')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(title)
    plt.show()
svm= SVC(C= 10,kernel='poly',degree=2,coef0=1,max_iter=500000)
svm_mod= svm.fit(X,y)
plotSVC('kernel='+ str('polynomial'))
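原问题中 h = (x_max / x_min)/100 这一行发生了除以零(花瓣长度的最小值是 1.0,减 1 之后 x_min 为 0),使 h 变成 inf。
需要改为 h = (x_max - x_min)/100
我是通过阅读所报的异常发现这一点的:
TypeError: Input z must be at least a 2x2 array.(类型错误:输入 z 必须至少为 2x2 数组。)
然后往回追溯:z 的形状来自 xx 的形状,而 xx 依赖于 h,h 却是 inf,这显然不合理,于是问题就很容易解决了。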
我认为您应该学习如何更好地使用调试器。