【Question title】: How to plot the intersection of a hyperplane and a plane in R
【Posted】: 2015-02-12 20:57:43
【Question】:

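I have a set of (two-dimensional) data points that I run through a classifier which uses a higher-order polynomial transformation. I would like to visualize the result as a 2D scatter plot of the points with the classifier overlaid on top, preferably using ggplot2, since all my other visualizations are done with it. It should look much like the plot used in the CaltechX machine learning online course (the background colors are optional).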

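Showing the points with colors and symbols is easy, but I cannot figure out how to draw something like the classifier (the intersection of the classifying hyperplane with the plane that represents my threshold). The only thing I have found is stat_function, which only accepts a function of a single argument.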
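A common workaround for a function of two arguments is to evaluate it on a regular grid and draw only its zero contour. The following is a minimal sketch of that idea, not the asker's classifier: `decision` is a hypothetical stand-in function and the grid range is an arbitrary choice.

```r
library(ggplot2)

# Hypothetical two-argument decision function; z = 0 marks the boundary.
decision <- function(x, y) x^2 + y^2 - 4

# Evaluate the function on every x/y combination of a regular grid.
grid <- expand.grid(x = seq(-3, 3, length.out = 100),
                    y = seq(-3, 3, length.out = 100))
grid$z <- decision(grid$x, grid$y)

# Draw only the contour at z = 0, i.e. the decision boundary.
p <- ggplot(grid, aes(x = x, y = y, z = z)) +
    geom_contour(breaks = 0)
```

Printing `p` draws the boundary; a `geom_point` layer with the actual data can be added on top of it.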

Edit:

An example, as requested in the comments:

Sample data:

"","x","y","x*y","x^2","y^2","value"
"1",4.17338115745224,0.303530843229964,1.26674990184152,17.4171102853774,0.0921309727918932,-1
"2",4.85514814266935,3.452660451876,16.7631779801937,23.5724634872656,11.9208641959486,1
"3",3.51938610081561,3.41200957307592,12.0081790673332,12.3860785266141,11.6418093267617,1
"4",3.18545089452527,0.933340128976852,2.97310914874565,10.1470974014319,0.87112379635852,-1
"5",2.77556006214581,2.49701633118093,6.93061880335166,7.70373365857888,6.23509055818427,-1
"6",2.45974169578403,4.56341833807528,11.2248303614692,6.05032920997851,20.8247869282818,1
"7",2.73947941488586,3.35344674880616,9.18669833727041,7.50474746458339,11.2456050970786,-1
"8",2.01721803518012,3.55453519499861,7.17027250203368,4.06916860145595,12.6347204524838,-1
"9",3.52376445778646,1.47073399974033,5.1825201951431,12.4169159539591,2.1630584979922,-1
"10",3.77387718763202,0.509284208528697,1.92197605658768,14.2421490273294,0.259370405056702,-1
"11",4.15821685106494,1.03675272315741,4.31104264382058,17.2907673804804,1.0748562089743,-1
"12",2.57985028671101,3.88512040604837,10.0230289934507,6.65562750184287,15.0941605694935,1
"13",3.99800728890114,2.39457673509605,9.5735352407471,15.9840622821066,5.73399774026327,1
"14",2.10979392635636,4.58358959294856,9.67042948411309,4.45123041169019,21.0092935565863,1
"15",2.26988795562647,2.96687697409652,6.73447830932721,5.15239133109813,8.80235897942413,-1
"16",1.11802248633467,0.114183261757717,0.127659454208164,1.24997427994995,0.0130378172656312,-1
"17",0.310411276295781,2.09426849964075,0.650084557879535,0.0963551604515758,4.38596054858751,-1
"18",1.93197490065359,1.72926536411978,3.340897280049,3.73252701675543,2.99035869954433,-1
"19",3.45879891654477,1.13636834081262,3.93046958599847,11.9632899450912,1.29133300600123,-1
"20",0.310697768582031,0.730971727753058,0.227111284709427,0.0965331034018534,0.534319666774291,-1
"21",3.88408110360615,0.915658151498064,3.55649052359657,15.0860860193904,0.838429850404852,-1
"22",0.287852146429941,2.16121324687265,0.622109872005114,0.0828588582043242,4.67084269845782,-1
"23",2.80277011333965,1.22467750683427,3.4324895146344,7.85552030822994,1.4998349957458,-1
"24",0.579150241101161,0.57801398797892,0.334756940497835,0.335415001767533,0.334100170299295,-1
"25",2.37193428212777,1.58276639413089,3.7542178708388,5.62607223873297,2.50514945839009,-1
"26",0.372461311053485,2.51207412336953,0.935650421453748,0.138727428231681,6.31051640130279,-1
"27",3.56567220995203,1.03982002707198,3.70765737388213,12.7140183088242,1.08122568869998,-1
"28",0.634770628530532,2.26303249713965,1.43650656059435,0.402933750845047,5.12131608311011,-1
"29",2.43812176748179,1.91849716124125,4.67752968967431,5.94443775306852,3.68063135769073,-1
"30",1.08741064323112,3.01656032912433,3.28023980783858,1.18246190701233,9.0996362192467,-1
"31",0.98,2.74,2.6852,0.9604,7.5076,1
"32",3.16,1.78,5.6248,9.9856,3.1684,1
"33",4.26,4.28,18.2328,18.1476,18.3184,-1
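
The third to fifth data columns appear to be the second-order terms x·y, x² and y²; the values in rows 31 and 32 are consistent with that reading. A small sketch of how such columns could be derived from the raw coordinates follows. The column names `xy`, `xx` and `yy` are placeholders chosen for this example, not names taken from the question.

```r
# Two raw points taken from the sample above (rows 31 and 32).
pts <- data.frame(x = c(0.98, 3.16), y = c(2.74, 1.78))

# Second-order polynomial terms; xy, xx and yy are placeholder names.
features <- transform(pts, xy = x * y, xx = x^2, yy = y^2)
print(features)
```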

Code that generates the classifier:

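perceptron_train <- function(data, maxIter = 10000) {
    set.seed(839)                      # fixed seed so that runs are reproducible
    X <- as.matrix(data[1:5])          # features: x, y and the second-order terms
    Y <- data[["value"]]               # labels (+1 / -1) as a plain vector
    X <- cbind(1, X)                   # prepend the bias column
    W <- rep(0, ncol(X))
    count <- 0
    while (count < maxIter) {
        H <- sign(X %*% W)
        indexs <- which(H != Y)        # currently misclassified points
        if (length(indexs) == 0) {
            break
        }
        # sample() on a single number n draws from 1:n, so index explicitly
        i <- indexs[sample.int(length(indexs), 1)]
        W <- W + 0.1 * (X[i, ] * Y[i])
        count <- count + 1
        point <- as.data.frame(data[i, ])
        plot_it(data, point, W, paste("plot", sprintf("%05d", count), ".png", sep = ""))
    }
    W
}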

Code that generates the plot:

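library(ggplot2)

plot_it <- function(data, point, weights, name = "plot.png") {
    # weights_to_line() is not shown in the question; from its use below it
    # is expected to return c(slope, intercept)
    line <- weights_to_line(weights)
    png(name)
    # size is a fixed setting, so it belongs outside aes()
    p <- ggplot() +
        geom_point(data = data, aes(x, y, color = value), size = 2) +
        theme(legend.position = "none") +
        geom_abline(intercept = line[2], slope = line[1])
    print(p)
    dev.off()
}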

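【Comments】:

  • Have a look at stat_contour(...). Beyond that, you need to post a reproducible example that includes (at least) a representative sample of your data and the code you use to generate the classification. See: stackoverflow.com/questions/5963269/…
  • I'm not sure how this helps answer the question, since it is not about the data or the code that generates the classifier, but here you go.
  • This example is still incomplete. It is missing function definitions such as weights_to_line. It is also odd that you put the plotting inside the classifier's iteration. As jlhoward said, a plot like this will end up looking like a contour plot. You will need a response value for every x/y input combination to see where the boundary lies. There is not enough information here to help you.
  • @MrFlick The current example will be used in a presentation to show how these boundaries evolve over the iterations. I was hoping to avoid computing thousands of data points over and over, since that seems very inefficient and impractical, but I will give it a try. Thanks.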
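
The "response value for every x/y combination" suggested above can be sketched for a second-order perceptron as follows. This is a hedged illustration, not the asker's code: `boundary_grid` is a hypothetical helper, the weight order (bias, x, y, x·y, x², y²) is assumed to match the column order of the sample data, and the weights are made up so that the boundary is a circle of radius 2 around (2.5, 2.5).

```r
library(ggplot2)

# Evaluate w0 + w1*x + w2*y + w3*x*y + w4*x^2 + w5*y^2 on a regular grid.
boundary_grid <- function(weights, xlim = c(0, 5), ylim = c(0, 5), n = 100) {
    grid <- expand.grid(x = seq(xlim[1], xlim[2], length.out = n),
                        y = seq(ylim[1], ylim[2], length.out = n))
    features <- cbind(1, grid$x, grid$y, grid$x * grid$y, grid$x^2, grid$y^2)
    grid$z <- as.vector(features %*% weights)
    grid
}

# Made-up weights: (x - 2.5)^2 + (y - 2.5)^2 - 4, expanded term by term.
w <- c(8.5, -5, -5, 0, 1, 1)
grid <- boundary_grid(w)

# The decision boundary is the contour where the response is zero.
p <- ggplot() +
    geom_contour(data = grid, aes(x = x, y = y, z = z), breaks = 0)
```

Only the weights change between iterations, so the feature matrix could be computed once and reused; each frame then costs a single matrix product over the grid.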

Tags: r ggplot2 classification


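【Solution 1】:

I solved this using material from the question and answers in "Issues plotting a fitted SVM model's decision boundary using ggplot2's stat_contour()". I skipped the call to geom_point for the whole grid as well as some aesthetic definitions such as scale_fill_manual and scale_colour_manual. In my case, removing the points for the grid entries fixed the problem of the contour line disappearing.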

library(e1071)
library(ggplot2)

# note: gamma is accepted as an argument but is not forwarded to svm()
train_and_plot_svm <- function(train, kernel = "sigmoid", type = "C", cost, gamma) {
    fit <- svm(as.factor(value) ~ x + y, data = train, kernel = kernel, type = type, cost = cost)
    grid <- expand.grid(x = seq(from = -0.1, to = 15, length = 100), y = seq(from = -0.1, to = 15, length = 100))
    # signed decision value for every grid point
    decisionValues <- as.vector(attr(predict(fit, grid, decision.values = TRUE), "decision.values"))
    grid$value <- predict(fit, grid)
    grid$z <- decisionValues
    # the decision boundary is the contour where the decision value is zero
    p <- ggplot() + stat_contour(data = grid, aes(x = x, y = y, z = z), breaks = c(0))
    p <- p + geom_point(data = train, aes(x, y, colour = as.factor(value)), alpha = 0.7)
    p <- p + xlim(0, 15) + ylim(0, 15) + theme(legend.position = "none")
    p  # return the plot visibly
}

Note that this function returns the ggplot2 object, not the result of the svm training.

This is what I got:

【Comments】:
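
Because a plot object is returned rather than drawn, the caller decides what to do with it. The sketch below shows that pattern with a hypothetical stand-in function, `make_plot`, so that it runs without a trained model; the output path in a temporary directory is likewise an arbitrary choice.

```r
library(ggplot2)

# Hypothetical stand-in for any function that returns a ggplot object.
make_plot <- function(df) {
    ggplot(df, aes(x, y)) + geom_point()
}

p <- make_plot(data.frame(x = 1:3, y = c(2, 4, 8)))

# Write the returned plot to a file; print(p) would draw it on screen instead.
out <- file.path(tempdir(), "boundary.png")
ggsave(out, plot = p, width = 5, height = 5)
```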
