【问题标题】:Plot SVM decision boundary(绘制 SVM 决策边界)
【发布时间】:2021-07-30 22:53:57
【问题描述】:

以下代码用多项式核拟合一个 SVM,并绘制鸢尾花(iris)数据及其决策边界。输入 X 使用数据的前 2 列,即萼片长度和宽度。但是,当我改用第 3 列和第 4 列(即花瓣长度和宽度)作为 X 时,却无法得到同样的输出。应该如何修改绘图函数才能让代码正常运行?提前致谢。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.svm import SVC

# Load the iris dataset; y holds the class label of every sample.
iris= datasets.load_iris()
y= iris.target
#X= iris.data[:, :2]  # sepal length and width
X= iris.data[:, 2:]   # petal length and width (columns 3 and 4) -- I tried a different X but it failed.

# Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
def plotSVC(title):
    """Plot the SVC decision regions over the two columns of ``X``.

    Uses the module-level names ``X`` (features), ``y`` (labels used to
    colour the scatter points) and ``svm_mod`` (the fitted model).

    NOTE(review): this is the failing version reported in the question.
    The mesh step ``h`` divides ``x_max`` by ``x_min`` instead of taking
    their difference. When ``x_min`` is 0 the result is ``inf``, so the mesh
    degenerates and ``Z`` ends up smaller than 2x2. ``plt.contourf`` then
    raises ``TypeError: Input z must be at least a 2x2 array.`` The intended
    step is ``(x_max - x_min)/100``.
    """
    # create a mesh to plot in, padded by 1 unit around the data range
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    h = (x_max / x_min)/100  # BUG: a step should be a span, i.e. (x_max - x_min)/100
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    plt.subplot(1, 1, 1)
    # predict a class for every mesh point, then reshape back onto the grid
    Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
    # NOTE(review): labels say "Sepal", but X = iris.data[:, 2:] selects the petal columns
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.xlim(xx.min(), xx.max())
    plt.title(title)
    plt.show()

# Fit a degree-2 polynomial-kernel SVC on (X, y) and plot its decision regions.
# plotSVC reads the fitted model through the global name svm_mod.
svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
svm_mod= svm.fit(X,y)
plotSVC('kernel='+ str('polynomial'))

错误:

  import sys
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-57-4515f111e34d> in <module>()
      2 svm= SVC(C= 10, kernel='poly', degree=2, coef0=1, max_iter=500000)
      3 svm_mod= svm.fit(X,y)
----> 4 plotSVC('kernel='+ str('polynomial'))

<ipython-input-56-556d4a22026a> in plotSVC(title)
     10     Z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
     11     Z = Z.reshape(xx.shape)
---> 12     plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
     13     plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
     14     plt.xlabel('Sepal length')

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in contourf(*args, **kwargs)
   2931                       mplDeprecation)
   2932     try:
-> 2933         ret = ax.contourf(*args, **kwargs)
   2934     finally:
   2935         ax._hold = washold

~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1853                         "the Matplotlib list!)" % (label_namer, func.__name__),
   1854                         RuntimeWarning, stacklevel=2)
-> 1855             return func(ax, *args, **kwargs)
   1856 
   1857         inner.__doc__ = _add_data_doc(inner.__doc__,

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in contourf(self, *args, **kwargs)
   6179             self.cla()
   6180         kwargs['filled'] = True
-> 6181         contours = mcontour.QuadContourSet(self, *args, **kwargs)
   6182         self.autoscale_view()
   6183         return contours

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in __init__(self, ax, *args, **kwargs)
    844         self._transform = kwargs.pop('transform', None)
    845 
--> 846         kwargs = self._process_args(*args, **kwargs)
    847         self._process_levels()
    848 

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _process_args(self, *args, **kwargs)
   1414                 self._corner_mask = mpl.rcParams['contour.corner_mask']
   1415 
-> 1416             x, y, z = self._contour_args(args, kwargs)
   1417 
   1418             _mask = ma.getmask(z)

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _contour_args(self, args, kwargs)
   1472             args = args[1:]
   1473         elif Nargs <= 4:
-> 1474             x, y, z = self._check_xyz(args[:3], kwargs)
   1475             args = args[3:]
   1476         else:

~/anaconda3/lib/python3.6/site-packages/matplotlib/contour.py in _check_xyz(self, args, kwargs)
   1508             raise TypeError("Input z must be a 2D array.")
   1509         elif z.shape[0] < 2 or z.shape[1] < 2:
-> 1510             raise TypeError("Input z must be at least a 2x2 array.")
   1511         else:
   1512             Ny, Nx = z.shape

TypeError: Input z must be at least a 2x2 array.

【问题讨论】:

    标签: python numpy matplotlib scikit-learn classification


    【解决方案1】:

    工作代码

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import datasets
    from sklearn.svm import SVC
    
    # Load iris and keep only columns 3 and 4 (petal length and width) as features.
    iris = datasets.load_iris()
    X = iris.data[:, 2:]  # petal length and width
    y = iris.target       # class label of each sample
    
    # Ref: https://medium.com/all-things-ai/in-depth-parameter-tuning-for-svc-758215394769
    def plotSVC(title, xlabel='Petal length', ylabel='Petal width', steps=100):
        """Plot the decision regions of the fitted SVC over the 2-D data ``X``.

        Reads the module-level names ``X`` (two feature columns), ``y`` (class
        labels used to colour the points) and ``svm_mod`` (the fitted model
        whose ``predict`` is evaluated on the mesh).

        Args:
            title: Title of the figure.
            xlabel: Label of the first column of ``X``. The default matches
                ``X = iris.data[:, 2:]`` (petal length).
            ylabel: Label of the second column of ``X`` (petal width by default).
            steps: Number of mesh divisions along the x range. The same step
                size is reused for the y axis.

        Raises:
            ValueError: if ``steps`` is less than 2, which would produce a mesh
                too small for ``contourf``.
        """
        if steps < 2:
            raise ValueError("steps must be at least 2")
        # Pad the data range by 1 unit on every side so no sample sits on the border.
        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
        # Step = span / number of divisions. Dividing x_max by x_min instead
        # gives inf when x_min == 0, which collapses the mesh below 2x2.
        h = (x_max - x_min) / steps
        xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
        plt.subplot(1, 1, 1)
        # Classify every mesh point, then reshape back onto the grid for contourf.
        z = svm_mod.predict(np.c_[xx.ravel(), yy.ravel()])
        z = z.reshape(xx.shape)
        plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
        plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.xlim(xx.min(), xx.max())
        plt.ylim(yy.min(), yy.max())
        plt.title(title)
        plt.show()
    
    # Degree-2 polynomial-kernel SVC; plotSVC reads the fitted model as svm_mod.
    svm = SVC(kernel='poly', degree=2, coef0=1, C=10, max_iter=500000)
    svm_mod = svm.fit(X, y)
    plotSVC('kernel=' + 'polynomial')
    

    输出:


    原因:

    在 `h = (x_max / x_min)/100` 这一行中,x_min 为 0,除以零使 h 变成了 inf。

    应改为 `h = (x_max - x_min)/100`。


    我是通过阅读抛出的异常信息发现这一点的:

    TypeError: 输入 z 必须至少是一个 2x2 数组。

    然后回头检查发现:z 的形状来自 xx 的形状,而 xx 的形状又取决于 h;h 为 inf 显然没有意义,于是问题就很容易解决了。

    我认为你应该学习如何更好地使用调试器。

    【讨论】:

      猜你喜欢
      • 2016-07-13
      • 2018-12-20
      • 2021-05-17
      • 2018-12-31
      • 2016-01-15
      • 2014-02-26
      • 2019-09-07
      • 2013-12-13
      • 2013-10-03
      相关资源
      最近更新 更多