我不太确定 gbm 的工作原理，也不清楚为什么预测时必须指定树的数量（`n.trees`），但下面是一个使用 pdp 和 gridExtra 包的可运行示例：
library(pdp)
ntrees <- 250 # Number of gbm trees that `pred` asks predict() to use

# Prediction wrapper for pdp::partial(): predicts `newdata` with `object`
# using `ntrees` trees and returns the average of those predictions.
pred <- function(object, newdata) {
  mean(predict(object, newdata, n.trees = ntrees))
}
# Partial dependence plots for the predictors in columns 5-7 of db.chao,
# one list per model: pdps1 holds mod0's plots, pdps2 mod1's, pdps3 mod2's.
# Every model is predicted through `pred`.
pdp_vars <- names(db.chao)[5:7]

# Draw the partial dependence plot of `pred_var` for `model`.
# recursive = FALSE selects pdp's brute-force computation instead of gbm's
# weighted tree traversal; presumably required for `pred.fun` to be used --
# confirm against the pdp::partial() documentation.
plot_partial <- function(model, pred_var) {
  partial(model, pred.var = pred_var, train = db.chao, plot = TRUE,
          pred.fun = pred, recursive = FALSE)
}

pdps1 <- lapply(pdp_vars, plot_partial, model = mod0)
pdps2 <- lapply(pdp_vars, plot_partial, model = mod1)
pdps3 <- lapply(pdp_vars, plot_partial, model = mod2)
# Draw each model's partial dependence plots as a separate one-row figure,
# in the order mod0 (pdps1), mod1 (pdps2), mod2 (pdps3).
invisible(lapply(
  list(pdps1, pdps2, pdps3),
  function(model_pdps) gridExtra::grid.arrange(grobs = model_pdps, nrow = 1)
))
希望这对你有帮助！
编辑
按照 OP 的要求，将所有 PDP 合并到三张图中（每个预测变量一张图，三个模型的曲线叠加在同一张图上），并让每个模型使用不同数量的树进行预测：
library(pdp)
# Number of gbm trees each model uses when predicting.
ntrees1 <- 150 # model1 (predicted through pred1)
ntrees2 <- 250 # model2 (predicted through pred2)
ntrees3 <- 50  # model3 (predicted through pred3)

# Build a prediction function for pdp::partial(). The returned function
# predicts `newdata` with `object` using `n_trees` trees and returns the mean
# prediction. The tree count is captured when the function is built (force()
# evaluates it immediately), so pred1/pred2/pred3 no longer look up the
# global ntrees* variables at call time.
make_pred <- function(n_trees) {
  force(n_trees)
  function(object, newdata) {
    mean(predict(object, newdata, n.trees = n_trees))
  }
}

pred1 <- make_pred(ntrees1)
pred2 <- make_pred(ntrees2)
pred3 <- make_pred(ntrees3)
# Extract the legend grob from a ggplot so it can be drawn on its own, e.g.
# as an extra panel in gridExtra::grid.arrange().
#
# myggplot: a ggplot object that draws a legend.
# Returns the legend grob; stops with an error if the plot has no legend.
get_legend <- function(myggplot) {
  tmp <- ggplot2::ggplot_gtable(ggplot2::ggplot_build(myggplot))
  # Older ggplot2 names the legend cell "guide-box"; ggplot2 >= 3.5.0 uses
  # "guide-box-<position>" and fills unused positions with empty zeroGrobs,
  # so match by prefix and skip the empty cells.
  is_legend <- startsWith(tmp$layout$name, "guide-box") &
    !vapply(tmp$grobs, inherits, logical(1), what = "zeroGrob")
  if (!any(is_legend)) {
    stop("`myggplot` has no legend to extract.", call. = FALSE)
  }
  tmp$grobs[[which(is_legend)[1]]]
}
# Obtain partial dependence data (instead of ready-made plots) for every
# predictor in columns 5-7 of db.chao and every model, then draw one panel per
# predictor with the three models' curves overlaid, plus a shared legend.
library(ggplot2)

pdp_vars <- names(db.chao)[5:7]

# Partial dependence data for one model/predictor pair; `pred_fun` decides how
# many trees are used for prediction.
partial_data <- function(model, pred_var, pred_fun) {
  partial(model, pred.var = pred_var, train = db.chao, plot = FALSE,
          pred.fun = pred_fun, recursive = FALSE)
}

pdps1 <- lapply(pdp_vars, partial_data, model = mod0, pred_fun = pred1)
pdps2 <- lapply(pdp_vars, partial_data, model = mod1, pred_fun = pred2)
pdps3 <- lapply(pdp_vars, partial_data, model = mod2, pred_fun = pred3)

# One ggplot per predictor, overlaying "y1" (mod0), "y2" (mod1), "y3" (mod2).
# Columns are referenced through the `.data` pronoun on the plot's own data,
# so no local()/`<<-` is needed to stop later iterations from overwriting
# the variables that earlier plots refer to.
plotlist <- lapply(seq_along(pdp_vars), function(i) {
  parts <- list(pdps1[[i]], pdps2[[i]], pdps3[[i]])
  pdp_df <- do.call(rbind, parts)
  # Label rows by model using each part's own row count, so labels stay
  # aligned even if the models' grids differ in length.
  pdp_df[["#output"]] <- rep(c("y1", "y2", "y3"),
                             times = vapply(parts, nrow, integer(1)))
  x_var <- names(pdp_df)[1]
  y_var <- names(pdp_df)[2]
  ggplot(pdp_df) +
    geom_line(aes(x = .data[[x_var]], y = .data[[y_var]],
                  group = .data[["#output"]], color = .data[["#output"]])) +
    xlab(x_var) + ylab("yhat") +
    ggtitle(paste0("PDP of ", x_var)) +
    labs(color = "#output")
})

# Every panel has the same legend: extract it once, hide it on each panel and
# append it as the last, narrower column of the figure.
legend <- get_legend(plotlist[[length(plotlist)]])
plotlist <- lapply(plotlist, function(p) p + theme(legend.position = "none"))
plotlist[[length(plotlist) + 1]] <- legend
gridExtra::grid.arrange(grobs = plotlist, nrow = 1,
                        widths = c(rep(2.3, length(pdp_vars)), 0.8))