
Efficiently apply sample() in R

I need to sample an outcome variable given a matrix with row-wise outcome probabilities.

set.seed(1010) #reproducibility

#create a matrix of probabilities
#three possible outcomes, 10,000 cases
probabilities <- matrix(runif(10000 * 3), nrow = 10000, ncol = 3)
#normalise each row so that it sums to 1 (base rowSums() is sufficient here)
probabilities <- probabilities / rowSums(probabilities)

The fastest way I could come up with is a combination of apply() and sample().

#row-wise sampling using these probabilities
classification <- apply(probabilities, 1, function(x) sample(1:3, 1, prob = x))

However, in my workflow this step is the computational bottleneck. Is there a way to speed this code up, or to sample more efficiently?

Thanks!

yrx1702 asked Oct 28 '25 10:10


1 Answer

RLave's comment that Rcpp could be the way to go is spot on (you also need RcppArmadillo for sample()); I used the following C++ code to create such a function:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadilloExtensions/sample.h>

using namespace Rcpp;

// Draw one element of choice_set per row of x, using that row of x
// as the sampling probabilities.
// [[Rcpp::export]]
IntegerVector sample_matrix(NumericMatrix x, IntegerVector choice_set) {
    int n = x.nrow();
    IntegerVector result(n);
    for ( int i = 0; i < n; ++i ) {
        // size = 1, replace = false, prob = row i of x
        result[i] = RcppArmadillo::sample(choice_set, 1, false, x(i, _))[0];
    }
    return result;
}

I then made that function available in my R session via

Rcpp::sourceCpp("sample_matrix.cpp")
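
Before benchmarking, it can be worth checking that a sampler actually reproduces the row probabilities. The following is only a minimal sketch of such a check: it uses the apply() approach from the question as the sampler under test, and the names n, draws, observed, expected and max_abs_diff are illustrative. The share of draws falling on each outcome should be close to the corresponding column mean of the probability matrix.

```r
set.seed(1010) #reproducibility

n <- 10000
probabilities <- matrix(runif(n * 3), nrow = n, ncol = 3)
probabilities <- probabilities / rowSums(probabilities)

# sampler under test: one draw per row, using that row's probabilities
draws <- apply(probabilities, 1, function(x) sample(1:3, 1, prob = x))

# observed share of each outcome (factor() keeps outcomes with zero draws)
observed <- as.numeric(table(factor(draws, levels = 1:3))) / n

# expected share of each outcome is the average probability in that column
expected <- colMeans(probabilities)

# largest deviation between observed and expected shares
max_abs_diff <- max(abs(observed - expected))
```

Assuming the compiled function behaves like the apply() version, replacing the draws line with sample_matrix(probabilities, 1:3) should give a similarly small max_abs_diff.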

Now we can test it in R against your initial approach, as well as the other suggestions to use purrr::map() and lapply():

set.seed(1010) #reproducibility

#create a matrix of probabilities
#three possible outcomes, 10,000 cases
probabilities <- matrix(runif(10000 * 3), nrow = 10000, ncol = 3)
#normalise each row so that it sums to 1 (base rowSums() is sufficient here)
probabilities <- probabilities / rowSums(probabilities)
probabilities_list <- split(probabilities, seq(nrow(probabilities)))

library(purrr)
library(microbenchmark)

microbenchmark(
    apply = apply(probabilities, 1, function(x) sample(1:3, 1, prob = x)),
    map = map(probabilities_list, function(x) sample(1:3, 1, prob = x)),
    lapply = lapply(probabilities_list, function(x) sample(1:3, 1, prob = x)),
    rcpp = sample_matrix(probabilities, 1:3),
    times = 100
)

Unit: milliseconds
   expr       min        lq      mean    median        uq       max neval cld
  apply 307.44702 321.30051 339.85403 342.36421 350.86090 434.56007   100   c
    map 254.69721 265.10187 282.85592 286.21680 295.48886 363.95898   100  b
 lapply 249.68224 259.70178 280.63066 279.87273 287.10062 691.21359   100  b
   rcpp  12.16787  12.55429  13.47837  13.81601  14.25198  16.84859   100 a

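The time savings are considerable: the Rcpp version has a median of about 14 ms, compared with roughly 280 to 340 ms for the three pure-R approaches, so it is around 20 to 25 times faster.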
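
If compiling C++ is not an option, a fully vectorised base R alternative may also be worth trying. The sketch below is not part of the benchmark above and has not been timed here. It swaps in the inverse-CDF method: take one uniform draw per row and count how many of that row's cumulative probabilities fall below it. sample_rows_vectorised is a hypothetical helper name, and the function assumes every row of the matrix sums to 1.

```r
# Hypothetical helper: row-wise categorical sampling via the inverse-CDF method.
# Assumes each row of `probabilities` is non-negative and sums to 1.
sample_rows_vectorised <- function(probabilities) {
  n <- nrow(probabilities)
  k <- ncol(probabilities)

  # row-wise cumulative probabilities, built column by column
  cum <- probabilities
  for (j in seq_len(k)[-1]) {
    cum[, j] <- cum[, j - 1] + probabilities[, j]
  }

  # one uniform draw per row; u is recycled down the columns of cum,
  # so row i of cum is always compared against u[i]
  u <- runif(n)

  # outcome = 1 + number of cumulative probabilities strictly below u
  outcome <- 1L + as.integer(rowSums(u > cum))

  # guard against rows whose cumulative sum ends slightly below 1
  # because of floating-point error
  pmin(outcome, k)
}

set.seed(1010) #reproducibility
probabilities <- matrix(runif(10000 * 3), nrow = 10000, ncol = 3)
probabilities <- probabilities / rowSums(probabilities)
classification <- sample_rows_vectorised(probabilities)
```

Note that this returns column indices (1 to the number of columns) rather than elements of an arbitrary choice set, and that it consumes random numbers differently from sample(), so results will not match the other approaches draw for draw even with the same seed.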

duckmayr answered Oct 31 '25 02:10


