【问题标题】:Random Forest Regression: Extracting the training samples in the terminal nodes of each tree随机森林回归:在每棵树的终端节点中提取训练样本
【发布时间】:2020-12-13 06:06:34
【问题描述】:

我想实施 Bertsimas et al. (2020)Predictive Prescription 方法,他们将机器学习方法与优化相结合。为此,我需要查看森林中每棵决策树的终端节点(分离区域)。

具体来说,我想知道每棵树的以下几点:

  1. 训练样本属于哪个区域?
  2. 测试样本属于哪个地区?

我希望通过下面一张决策树的图片让我的问题变得更清楚:

Regression Tree Example

这里,对于第一个终端节点,我对预测 m 不感兴趣,而是对构成预测基础的值 y1、y4 和 y5 感兴趣。


完美的结果将是一个类似矩阵的结构,其中每一列代表一棵树,每一行代表一个训练(测试)样本。对于每个样本和树,结构应该给我可以找到样本的区域/终端节点的 ID!

我查看了randomForestranger 包,但没有找到任何相关的东西......一些论文提到使用caret 包实现该方法,但他们没有提及如何绕过预测。

这是一个使用 ranger 的可重现回归示例:

library(MASS)
library(e1071)
library(ranger)

#load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))

train <- Boston[ind == 1,]
test <- Boston[ind == 2,]

#train random forest
boston.rf <- ranger(medv ~ ., data = train) 

非常感谢任何帮助。干杯!

【问题讨论】:

  • randomForest 没有一棵树(默认为 500 棵),因此有 500 个不同的区域组合在一起以获得最终答案。您的问题没有简单的答案。

标签: r regression random-forest


【解决方案1】:

到目前为止,我发现获取此信息的一种方法是使用带有选项 keep.inbag=TrandomForest 包 - 这允许您检索用于创建每棵树的样本的信息 - 以及方法getTree 检索森林中每棵树的树结构。

我创建了一个函数来检索给定来自getTree 的树结构的终端节点 ID。

# function to retrieve the terminal node id given a rf tree structure and a sample (with numerical only features)
get_terminal_node_id_for_sample <- function(tree, sample){
  node_id=1
  search <- TRUE
  while(search){
    if(tree$status[node_id]=="-1"){
      search <- FALSE
      break
    }
    if(sample[as.character(tree$split.var[node_id])] < tree$split.point[node_id]){
      node_id <- as.numeric(tree$left.daughter[node_id])
    } else {
      node_id <- as.numeric(tree$right.daughter[node_id])
    }
  }
  return(node_id)
}

并像这样使用它:

library(randomForest)
library(MASS)
library(e1071)

# load data
data(Boston)
set.seed(111)
ind <- sample(2, nrow(Boston), replace = TRUE, prob=c(0.8, 0.2))

train <- Boston[ind == 1,]
test <- Boston[ind == 2,]

# train random forest and keep inbag information
model = randomForest(medv~.,data = train,
                     keep.inbag=T)

# get the first tree of the forest
treeind <- 1
tree <- data.frame(getTree(model, k=treeind, labelVar=TRUE))

# loop over each sample in inbag of the first tree
for (sampleind in which(model$inbag[,treeind]>0)){
  sample <- train[sampleind,]
  node_id <- get_terminal_node_id_for_sample(tree,sample)
  
  ##########################
  # do whatever with node_id
  ##########################
  
  print(paste("sample",sampleind,"is in terminal node",node_id,sep=" "))
}

需要说明的是:我仅针对数字特征对此进行了测试。

【讨论】:

    猜你喜欢
    • 2014-04-20
    • 2018-12-06
    • 1970-01-01
    • 2013-08-30
    • 1970-01-01
    • 2018-05-27
    • 2019-10-23
    • 2020-02-12
    相关资源
    最近更新 更多