【问题标题】:Calculate number of observations in each node in a decision tree in R?计算R中决策树中每个节点的观察数?
【发布时间】:2022-02-06 10:24:17
【问题描述】:

已提出类似问题,例如herehere,但其他问题均不适用于我的问题。我试图确定和计算决策树中每个节点中的观察值。但是,树结构来自我自己从 BART 包创建的树数据框。我从BART 包中提取树信息并将其转换为类似于下图所示的数据框(即df)。但我需要使用提供的数据框结构。另外:我相信我使用的方法,与我的数据框中的树如何绘制/排序有关,被称为“深度优先”。

例如,我的树数据框如下所示:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

在视觉上,这些树看起来像:

在向下遍历df 时,正在左前绘制树木。此外,所有拆分都是二进制拆分。所以每个节点都会有 2 个孩子。

所以,如果我们创建一些如下所示的数据:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)

我试图找出dat 的哪些观察结果属于哪个节点?

尝试回答: 这并没有真正的帮助,但为了清楚起见(因为我仍在尝试解决这个问题),为树号 3 硬编码如下所示:

lists <- df %>% group_by(treeNo) %>% group_split()
tree<- lists[[3]]

 namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
    dataLeft <- dat[dat[, namesDf] <= tree[1,]$splitValue, ]
    dataRight <- dat[dat[, namesDf] > tree[1,]$splitValue, ]
    
    namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
    dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2,]$splitValue, ]
    dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2,]$splitValue, ]
    
    namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
    dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5,]$splitValue, ]
    dataRight2 <- dataRight[dataRight[, namesDf] > tree[5,]$splitValue, ]

我一直试图把它变成一个循环。但事实证明,锻炼很有挑战性。 而且我(显然)不能为每棵树硬编码它。关于如何解决这个问题的任何建议?

【问题讨论】:

  • 为什么不在树的生长过程中计算给定节点中每个观察值的数量,并将其作为新变量添加到数据框中?我在编写决策树时就是这样做的。此外,我们不知道每个变量应该从树中取出的路径。例如,我们从df 看到,第一棵树在x1x2 上分裂了一次,但从数据帧中并不清楚按照哪个顺序,如果按顺序,ecc。
  • 那么你应该提供更多关于你实际在做什么的信息。根据我们所拥有的,即使只绘制您所附的图片也是不可能的。你可以编辑你的帖子添加一个可重现的例子吗?
  • 对于初学者,我将添加您从中提取树数据的几个不同的决策树包,以及至少一个您如何提取它们的示例。
  • 我应该在我的问题中提到这一点。它是一棵二叉树,所以每个节点都会有两个孩子。我会将该信息添加到问题中
  • 如果您使用 BART 生成树,请发布您的 BART 代码

标签: r decision-tree


【解决方案1】:

似乎我们可以通过“滚动拆分”来获得您想要的东西。逻辑如下。

  1. 从只有一个数据帧 dat 的堆栈开始。
  2. 对于每对variableNamesplitValue,如果它们不是NAs,则将该堆栈上的顶部数据帧拆分为两个由variableName &lt;= splitValuevariableName &gt; splitValue 标识的子数据帧(前者在顶部后者);如果它们是NAs,则只需弹出顶部数据框。

这里是代码。请注意,这种依赖于状态的计算很难向量化。因此,这不是 R 擅长的。如果你有很多树并且代码性能成为一个严重的问题,我建议使用Rcpp重写下面的代码。

eval_node <- function(df, x, v) {
  out <- vector("list", length(x))
  stk <- vector("list", sum(is.na(x)))
  pos <- 1L
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      subs <- pos + c(0L, 1L)
      stk[subs] <- split(stk[[pos]], stk[[pos]][[x[[i]]]] <= v[[i]])
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])
      pos <- pos + 1L
    } else {
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}

然后你可以像这样应用函数。

library(dplyr)

df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))

输出

# A tibble: 15 x 4
# Groups:   treeNo [3]
   variableName splitValue treeNo node            
   <chr>             <dbl>  <dbl> <list>          
 1 x2                0.542      1 <named list [2]>
 2 x1                0.126      1 <named list [2]>
 3 NA               NA          1 <named list [1]>
 4 NA               NA          1 <named list [1]>
 5 NA               NA          1 <named list [1]>
 6 x2                0.655      2 <named list [2]>
 7 NA               NA          2 <named list [1]>
 8 NA               NA          2 <named list [1]>
 9 x5                0.418      3 <named list [2]>
10 x4                0.234      3 <named list [2]>
11 NA               NA          3 <named list [1]>
12 NA               NA          3 <named list [1]>
13 x3                0.747      3 <named list [2]>
14 NA               NA          3 <named list [1]>
15 NA               NA          3 <named list [1]>

node 看起来像这样

[[1]]
[[1]]$`x2<=0.542`
          x1        x2        x3        x4        x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[1]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034

[[2]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034


[[4]]
[[4]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[5]]
[[5]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[6]]
[[6]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[6]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[7]]
[[7]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139


[[8]]
[[8]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[9]]
[[9]]$`x5<=0.418`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9  0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[9]]$`x5>0.418`
          x1        x2        x3        x4        x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[10]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[12]]
[[12]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[13]]
[[13]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318

[[13]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[14]]
[[14]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318


[[15]]
[[15]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270

【讨论】:

  • 这是一个很好的答案。我正在尝试弄清楚在使用您的代码时如何从node 中提取行索引。棘手的部分是node 对象有重复项。例如df$node[[2]] 包含与df$node[[3]]df$node[[4]] 组合的相同信息。我正在尝试删除任何重复项。我最终想看看有多少次观察对(即行索引)出现在同一个节点中。例如,观察 3 和 4 一起出现 3 次。但是由于重复,观察 3 和 4 在 node 对象中出现了 4 次。但这是一个很好的开始
  • 我不认为它们是重复的。尽管node 在某种程度上是一个名称错误的列,但它实际上只是为您提供每个variableName 所需的所有信息。请注意,variableName 可以表示树中的节点或叶(即NA)。节点的信息是每个分支的数据,叶子的信息是分支本身;这就是为什么我们需要一些重复。如果你想“隐藏”叶子的信息,注释掉这行out[[i]] &lt;- stk[pos]。如果你想从数据中删除叶子,那么df %&gt;% filter(!is.na(variableName))。 @Electrino
  • 你是对的!当我评论重复项时,我的想法不正确。你没有重复。我将此标记为正确答案。它解决了我的问题。感谢您的时间和精力
【解决方案2】:

还有很大的优化空间,不过这是我的尝试。您的树似乎以深度优先的方式构造,左子节点始终跟随父节点:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

给定要匹配的数据:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)
dat
##>           x1        x2        x3        x4        x5
##>1  0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
##>2  0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
##>3  0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
##>4  0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
##>5  0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
##>6  0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
##>7  0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
##>8  0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
##>9  0.54655860 0.3594751 0.5490967 0.9895641 0.2077139
##>10 0.17026205 0.6902905 0.2777238 0.1302889 0.3070859

makeTree 是一个高阶函数,它返回一个函数,该函数又将一行值映射到一个节点:

makeTree <- function(dat, r = 1) {
  ## the argument dat is a dataframe representation
  ## of a single tree as in the example
  ## return a list of two elements: size and fn. 
  ## - size is the number of cells taken by the 
  ##   node and its descendants. 
  ## - fn is a function of one argument (either a list or
  ##   a row of a dataframe) that returns the index of the 
  ##   node matching argument. More precisely the column Id
  ##   in dat.    
  stopifnot(r <= nrow(dat))
  vname <- pull(dat,variableName)[r]
  splitVal <- pull(dat, splitValue)[r]
  if (is.na(vname)) {
    ## terminal node
    ## print(sprintf("terminal node: %i", r))
    res <- list(size = 1, # offset to access right node
                fn = function(z) {
                  pull(dat, "id")[r]
                })
    return(res)
  } else {
    ##print(sprintf("node: %i, varName: %s, splitVal: %f", r, vname, splitVal ))
    ## compute the left and right functions
    ## note that the tree is traversed depth-first 
    fnleft <- makeTree(dat, r + 1) #fnleft is always positoned next to the
                                   #caller
    fnright <- makeTree(dat, r + fnleft$size + 1 )
    return(list(size = fnleft$size + fnright$size + 1,
                fn = function(z) {
                  if (z[vname] <= splitVal)
                    fnleft$fn(z)
                  else
                    fnright$fn(z)
                }))
  }
}

现在makeTree 应用于每棵树以生成匹配函数列表:

treefns <- df |>
  mutate(id = row_number()) %>%
  group_by(treeNo) |>
  group_split()    |>
  purrr::map(makeTree) |>
  purrr::map("fn")

最后,数据框dat 的每一行都匹配到树的一个节点:

apply(dat,1, function(z) sapply(treefns, function(fn) fn(z))) |>
  t() |>
  data.frame() |>
  rename_with(function(z) paste0("TREE", gsub("X", "", z))) |>
  cbind(dat) |>
  pivot_longer(cols = starts_with("TREE"),
               names_to = "TREE",
               values_to = "NODE")  |>
  sample_n(10)

##> A tibble: 10 x 7
##>       x1    x2    x3    x4    x5 TREE   NODE
##>    <dbl> <dbl> <dbl> <dbl> <dbl> <chr> <int>
##> 1 0.170  0.690 0.278 0.130 0.307 TREE3    11
##> 2 0.170  0.690 0.278 0.130 0.307 TREE2     8
##> 3 0.370  0.358 0.882 0.629 0.884 TREE2     7
##> 4 0.308  0.625 0.536 0.488 0.331 TREE1     5
##> 5 0.370  0.358 0.882 0.629 0.884 TREE1     4
##> 6 0.552  0.280 0.538 0.349 0.778 TREE3    14
##> 7 0.547  0.359 0.549 0.990 0.208 TREE1     4
##> 8 0.370  0.358 0.882 0.629 0.884 TREE3    15
##> 9 0.547  0.359 0.549 0.990 0.208 TREE2     7
##>10 0.0564 0.398 0.749 0.954 0.827 TREE2     7

【讨论】:

    猜你喜欢
    • 2012-11-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-01-28
    • 2019-02-08
    • 2014-07-20
    • 2015-01-31
    相关资源
    最近更新 更多