【发布时间】:2020-01-06 06:49:21
【问题描述】:
我正在尝试基于此documentation 在sklearn Python 中使用LDA 绘制鸢尾花数据集 的边界线。
对于二维数据,我们可以使用LDA.coef_ 和LDA.intercept_ 轻松绘制线条。
但是对于已减少为两个组件的多维数据,LDA.coef_ 和 LDA.intercept 有很多维度,我不知道如何使用这些维度来绘制 2D 中的边界线降维图。
我尝试仅使用 LDA.coef_ 和 LDA.intercept 的前两个元素进行绘图,但没有成功。
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
iris = datasets.load_iris()
X = iris.data
y = iris.target
target_names = iris.target_names
lda = LinearDiscriminantAnalysis(n_components=2)
X_r2 = lda.fit(X, y).transform(X)
x = np.array([-10,10])
y_hyperplane = -1*(lda.intercept_[0]+x*lda.coef_[0][0])/lda.coef_[0][1]
plt.figure()
colors = ['navy', 'turquoise', 'darkorange']
lw = 2
plt.plot(x,y_hyperplane,'k')
for color, i, target_name in zip(colors, [0, 1, 2], target_names):
plt.scatter(X_r2[y == i, 0], X_r2[y == i, 1], alpha=.8, color=color,
lw=lw,
label=target_name)
plt.legend(loc='best', shadow=False, scatterpoints=1)
plt.title('LDA of IRIS dataset')
plt.show()
lda.coef_[0] 和 lda.intercept[0] 生成的边界线结果显示了一条不太可能在两个类之间分开的线
我尝试使用 np.meshgrid 来绘制类的区域。但是我收到这样的错误
ValueError: X 每个样本有 2 个特征;期待 4
它需要 4 维的原始数据,而不是来自网格网格的 2D 点。
【问题讨论】:
-
看起来你的代码是从this改编的,对吗?我敢打赌,问题在于您的绘图点已被转换(以最大化类分离),而分离平面位于原始坐标中。此外,考虑模拟一些简单的二维数据,并确保您了解输出,并且可以在这种情况下绘制分离平面。 然后继续调试。
-
@bogovicj,是的,它是来自 sklearn 文档的代码。我已经知道如何使用 coef_ 和 intercept_ 的 2D 元素从 2D 数据中绘制分离平面;然后根据上面代码中的 y_hyperplane 方程直接绘制。是的,问题是文档将四维虹膜数据的原始数据转换为二维 LDA转换,以便我们可以可视化结果。但是coef_和intercept_的元素也是四维,所以我很困惑如何使用这些元素来绘制超平面。
标签: python plot scikit-learn svm linear-discriminant