
Calculate number of observations in each node in a decision tree in R?

Similar questions have been asked, for example here and here, but none of them apply to my issue. I'm trying to determine and count which observations fall into each node of a decision tree. However, the tree structure comes from a data frame of trees that I build myself: I extract tree information from the BART package and turn it into a data frame that resembles the one shown below (i.e., df), and I need to work with that data frame structure. Aside: I believe the way the trees are ordered in my data frame is called 'depth-first'.

For example, my data frame of trees looks like this:

library(dplyr)
df <- tibble(variableName = c("x2", "x1", NA, NA, NA, "x2", NA, NA, "x5", "x4", NA, NA, "x3", NA, NA),
             splitValue = c(0.542, 0.126, NA, NA, NA, 0.6547, NA, NA, 0.418, 0.234, NA, NA, 0.747, NA, NA),
             treeNo = c(1,1,1,1,1,2,2,2,3,3,3,3,3,3,3))

Visually, these trees would look like:

[image: decision trees]

The trees are being drawn left-first when traversing down df. Additionally, all splits are binary splits. So each node will have 2 children.
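To make the layout concrete, here is a minimal sketch that recovers the depth of every node from the depth-first, left-first ordering described above. The helper name node_depth is hypothetical (it is not part of BART or dplyr), and the sketch assumes a well-formed tree in which every split row is followed by exactly two subtrees and NA marks a leaf.

```r
# Hypothetical helper (an assumption for illustration): depth of every node
# of ONE tree whose rows are stored in depth-first, left-first order.
# variableName: character vector, NA marks a leaf.
node_depth <- function(variableName) {
  depth <- integer(length(variableName))
  todo <- 0L  # stack of depths still to be visited; the root has depth 0
  for (i in seq_along(variableName)) {
    d <- todo[length(todo)]        # depth of the node stored in row i
    todo <- todo[-length(todo)]    # pop it from the stack
    depth[i] <- d
    if (!is.na(variableName[i])) {
      # A split has two children, both one level deeper.
      todo <- c(todo, d + 1L, d + 1L)
    }
  }
  depth
}

# Tree 3 from df: x5, x4, leaf, leaf, x3, leaf, leaf
node_depth(c("x5", "x4", NA, NA, "x3", NA, NA))
# 0 1 2 2 1 2 2
```

Rows 2 and 5 both have depth 1, which matches the hardcoded attempt further down, where rows 2 and 5 of tree 3 split the left and right halves of the root.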

So, if we create some data that looks like this:

set.seed(100)
dat <- data.frame( x1 = runif(10),
                   x2 = runif(10),
                   x3 = runif(10),
                   x4 = runif(10),
                   x5 = runif(10)
)

I'm trying to find which observations of dat fall into which node.

Attempt at an answer: This isn't really helpful, but for clarity (as I am still trying to solve this), hardcoding it for tree number three would look like this:

lists <- df %>% group_by(treeNo) %>% group_split()
tree <- lists[[3]]

# Row 1 (x5): split the full data at the root
namesDf <- names(dat[grepl(tree[1, ]$variableName, names(dat))])
dataLeft <- dat[dat[, namesDf] <= tree[1, ]$splitValue, ]
dataRight <- dat[dat[, namesDf] > tree[1, ]$splitValue, ]

# Row 2 (x4): split the left child of the root
namesDf <- names(dat[grepl(tree[2, ]$variableName, names(dat))])
dataLeft1 <- dataLeft[dataLeft[, namesDf] <= tree[2, ]$splitValue, ]
dataRight1 <- dataLeft[dataLeft[, namesDf] > tree[2, ]$splitValue, ]

# Row 5 (x3): split the right child of the root
namesDf <- names(dat[grepl(tree[5, ]$variableName, names(dat))])
dataLeft2 <- dataRight[dataRight[, namesDf] <= tree[5, ]$splitValue, ]
dataRight2 <- dataRight[dataRight[, namesDf] > tree[5, ]$splitValue, ]

I have been trying to turn this into a loop, but it is proving challenging to work out, and I (obviously) can't hardcode it for every tree. Any suggestions as to how I could solve this?

Electrino asked Oct 15 '25 04:10



1 Answer

It seems that we can do "rolling splits" to get what you are looking for. The logic is as follows.

  1. Start with a stack with only one dataframe dat.
  2. For each pair of variableName and splitValue, if they are not NAs, split the top dataframe on that stack into two sub dataframes identified by variableName <= splitValue and variableName > splitValue (the former on top of the latter); if they are NAs, then simply pop the top dataframe.

Here is the code. Note that this kind of state-dependent computation is hard to vectorize. It's thus not what R is good at. If you have a lot of trees and the code performance becomes a serious concern, I'd suggest rewriting the code below using Rcpp.

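eval_node <- function(df, x, v) {
  # x: split variable names, v: split values, both in depth-first order
  # (NA marks a leaf). Returns one list element per row of the tree.
  out <- vector("list", length(x))
  stk <- vector("list", sum(is.na(x)))  # stack of data frames still to visit
  pos <- 1L                             # index of the top of the stack
  stk[[pos]] <- df
  for (i in seq_along(x)) {
    if (!is.na(x[[i]])) {
      # Split node: replace the top data frame by its two children.
      subs <- pos + c(0L, 1L)
      # Fixed factor levels keep both children even if one side is empty.
      is_left <- factor(stk[[pos]][[x[[i]]]] <= v[[i]], levels = c(FALSE, TRUE))
      stk[subs] <- split(stk[[pos]], is_left)  # order: FALSE (">"), TRUE ("<=")
      names(stk)[subs] <- trimws(paste0(
        names(stk[pos]), ",", x[[i]], c(">", "<="), v[[i]]
      ), "left", ",")
      out[[i]] <- rev(stk[subs])  # report the "<=" child first
      pos <- pos + 1L             # the "<=" child is now on top
    } else {
      # Leaf: record the top data frame and pop it.
      out[[i]] <- stk[pos]
      stk[[pos]] <- NULL
      pos <- pos - 1L
    }
  }
  out
}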

Then you can apply the function like this.

library(dplyr)

df %>% group_by(treeNo) %>% mutate(node = eval_node(dat, variableName, splitValue))

Output

# A tibble: 15 x 4
# Groups:   treeNo [3]
   variableName splitValue treeNo node            
   <chr>             <dbl>  <dbl> <list>          
 1 x2                0.542      1 <named list [2]>
 2 x1                0.126      1 <named list [2]>
 3 NA               NA          1 <named list [1]>
 4 NA               NA          1 <named list [1]>
 5 NA               NA          1 <named list [1]>
 6 x2                0.655      2 <named list [2]>
 7 NA               NA          2 <named list [1]>
 8 NA               NA          2 <named list [1]>
 9 x5                0.418      3 <named list [2]>
10 x4                0.234      3 <named list [2]>
11 NA               NA          3 <named list [1]>
12 NA               NA          3 <named list [1]>
13 x3                0.747      3 <named list [2]>
14 NA               NA          3 <named list [1]>
15 NA               NA          3 <named list [1]>

where the node column looks like this:

[[1]]
[[1]]$`x2<=0.542`
          x1        x2        x3        x4        x5
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[1]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[2]]
[[2]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034

[[2]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[3]]
[[3]]$`x2<=0.542,x1<=0.126`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034


[[4]]
[[4]]$`x2<=0.542,x1>0.126`
         x1        x2        x3        x4        x5
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
7 0.8124026 0.2046122 0.7703016 0.1804072 0.7803585
8 0.3703205 0.3575249 0.8819536 0.6293909 0.8842270
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[5]]
[[5]]$`x2>0.542`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[6]]
[[6]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139

[[6]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[7]]
[[7]]$`x2<=0.6547`
          x1        x2        x3        x4        x5
1 0.30776611 0.6249965 0.5358112 0.4883060 0.3306605
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
9 0.54655860 0.3594751 0.5490967 0.9895641 0.2077139


[[8]]
[[8]]$`x2>0.6547`
          x1        x2        x3        x4        x5
2  0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
5  0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6  0.4837707 0.6690217 0.1714202 0.8894535 0.4912318
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[9]]
[[9]]$`x5<=0.418`
          x1        x2        x3        x4        x5
1  0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9  0.5465586 0.3594751 0.5490967 0.9895641 0.2077139
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[9]]$`x5>0.418`
          x1        x2        x3        x4        x5
2 0.25767250 0.8821655 0.7108038 0.9285051 0.8651205
3 0.55232243 0.2803538 0.5383487 0.3486920 0.7775844
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
5 0.46854928 0.7625511 0.4201015 0.6952741 0.6033244
6 0.48377074 0.6690217 0.1714202 0.8894535 0.4912318
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[10]]
[[10]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859

[[10]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[11]]
[[11]]$`x5<=0.418,x4<=0.234`
          x1        x2        x3        x4        x5
10 0.1702621 0.6902905 0.2777238 0.1302889 0.3070859


[[12]]
[[12]]$`x5<=0.418,x4>0.234`
         x1        x2        x3        x4        x5
1 0.3077661 0.6249965 0.5358112 0.4883060 0.3306605
9 0.5465586 0.3594751 0.5490967 0.9895641 0.2077139


[[13]]
[[13]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318

[[13]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270


[[14]]
[[14]]$`x5>0.418,x3<=0.747`
         x1        x2        x3        x4        x5
2 0.2576725 0.8821655 0.7108038 0.9285051 0.8651205
3 0.5523224 0.2803538 0.5383487 0.3486920 0.7775844
5 0.4685493 0.7625511 0.4201015 0.6952741 0.6033244
6 0.4837707 0.6690217 0.1714202 0.8894535 0.4912318


[[15]]
[[15]]$`x5>0.418,x3>0.747`
          x1        x2        x3        x4        x5
4 0.05638315 0.3984879 0.7489722 0.9541577 0.8273034
7 0.81240262 0.2046122 0.7703016 0.1804072 0.7803585
8 0.37032054 0.3575249 0.8819536 0.6293909 0.8842270
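
If only the counts are needed rather than the data frames themselves, the same stack idea can be applied to row indices. The following is a minimal sketch, not the code above: count_nodes is a hypothetical helper, it reports the number of observations reaching each node (one number per row of the tree, rather than the two children of each split), and it assumes a well-formed depth-first tree and no missing values in the split variables.

```r
# Hypothetical helper (an assumption for illustration): number of
# observations of `data` that reach each node of ONE tree.
# variableName / splitValue: depth-first, left-first order, NA marks a leaf.
count_nodes <- function(data, variableName, splitValue) {
  n <- integer(length(variableName))
  stk <- list(seq_len(nrow(data)))  # stack of row-index vectors; top is last
  for (i in seq_along(variableName)) {
    rows <- stk[[length(stk)]]      # rows that reach the node in row i
    stk[[length(stk)]] <- NULL      # pop
    n[i] <- length(rows)
    if (!is.na(variableName[i])) {
      left <- data[[variableName[i]]][rows] <= splitValue[i]
      # Push the right child first so the left child is visited next.
      stk <- c(stk, list(rows[!left]), list(rows[left]))
    }
  }
  n
}

# Small made-up example: root splits on x1, its left child splits on x2.
toy <- data.frame(x1 = c(0.1, 0.2, 0.6, 0.9),
                  x2 = c(0.5, 0.1, 0.7, 0.3))
counts <- count_nodes(toy,
                      variableName = c("x1", "x2", NA, NA, NA),
                      splitValue   = c(0.5, 0.4, NA, NA, NA))
counts
# 4 2 1 1 2
```

With the data from the question, this could be applied per tree in the same way as eval_node, e.g. df %>% group_by(treeNo) %>% mutate(n = count_nodes(dat, variableName, splitValue)). Judging by the row counts in the output above, tree 1 should then give 10, 5, 1, 4, 5.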
ekoam answered Oct 16 '25 17:10



