
How to assign 0 to the minimum values by row in a matrix (in a fast/efficient way)?

I have a matrix Lambda with n rows and p columns. For each row, I want to set every value to 0 except the value in the first column and the maximum of the values in the other columns (in other words, zero out the p - 2 smallest values once the first column is set aside).

At the moment I am doing this with a for loop, as follows:

set.seed(60)
(Lambda = matrix(sample.int(30),5))
     [,1] [,2] [,3] [,4] [,5] [,6]
[1,]   19   20   27   18   15   25
[2,]   16   28    1    4   22    7
[3,]    2   10    8   23    3   12
[4,]    5    6    9   17   11   29
[5,]   26   30   24   13   14   21

m <- ncol(Lambda) - 2
for(ir in seq_len(nrow(Lambda))){
    Lambda[ir, match(tail(sort(abs(Lambda[ir, 2:ncol(Lambda)]), decreasing = TRUE), m), Lambda[ir,])] <- 0
}
Lambda
     [,1] [,2] [,3] [,4] [,5] [,6]
[1,]   19    0   27    0    0    0
[2,]   16   28    0    0    0    0
[3,]    2    0    0   23    0    0
[4,]    5    0    0    0    0   29
[5,]   26   30    0    0    0    0

This gives the desired result, but it would become a problem with many rows. Is there a solution that does not use a for loop? It could be done with lapply, but I'm not sure that would really be more efficient. Maybe with data.table, after converting the matrix?

Thank you!

iago asked Sep 01 '25 20:09


2 Answers

Here is one option that seems a bit faster than proposal() on 600k rows (roughly 8% by median time in the benchmark below):

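foo <- function(Lambda) {
  nr <- nrow(Lambda)
  # Linear (column-major) indices to keep: all of column 1,
  # plus each row's maximum among columns 2..p
  keep <- c(seq_len(nr), apply(Lambda[, -1], 1, which.max)*nr + seq_len(nr))
  # Set every other element to zero
  replace(Lambda, -keep, 0L)
}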
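
The `*nr + seq_len(nr)` arithmetic in foo() relies on R storing a matrix column by column, so element (i, j) sits at linear position (j - 1) * nr + i. Here is a small check of that idea on a made-up matrix (M, i, j, k and lin are illustrative names, not part of the answer):

```r
# Illustrative 5 x 6 matrix filled column by column with 1..30
M <- matrix(1:30, nrow = 5)
nr <- nrow(M)

# Element (i, j) lives at linear position (j - 1) * nr + i
i <- 3; j <- 4
lin <- (j - 1) * nr + i
stopifnot(M[lin] == M[i, j])

# which.max() on M[, -1] returns a position k counted from column 2,
# i.e. actual column k + 1, so its linear index is k * nr + row
k <- apply(M[, -1], 1, which.max)
keep <- k * nr + seq_len(nr)
stopifnot(all(M[keep] == apply(M[, -1], 1, max)))
```

Because dropping the first column shifts every column position down by one, the "+ 1" for the column and the "- 1" in the formula cancel, which is why foo() can multiply the which.max() result by nr directly.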

Edit

A vast improvement is to replace the apply() + which.max() combo with max.col(), as suggested by markus:

foo2 <- function(Lambda) {
  nr <- nrow(Lambda)
  keep <- c(seq_len(nr), max.col(Lambda[, -1], ties.method = "first")*nr + seq_len(nr))
  replace(Lambda, -keep, 0L)  
}
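
One caveat: the loop in the question compares abs() values, while foo() and foo2() compare the raw values. On the all-positive example data the results coincide, but they can differ once negative entries appear. Below is a minimal sketch of a variant that follows the question's abs() logic; the name foo2_abs is hypothetical, and the sketch assumes the matrix has at least two columns (drop = FALSE keeps a one-row input as a matrix):

```r
# Hypothetical variant of foo2() that ranks by absolute value
foo2_abs <- function(Lambda) {
  nr <- nrow(Lambda)
  # Columns 2..p in absolute value; drop = FALSE keeps matrix shape
  rest <- abs(Lambda[, -1, drop = FALSE])
  # Linear indices to keep: column 1 plus each row's largest |value|
  keep <- c(seq_len(nr),
            max.col(rest, ties.method = "first") * nr + seq_len(nr))
  replace(Lambda, -keep, 0L)
}

# The matrix printed in the question, entered column by column
Lambda <- matrix(c(19, 16,  2,  5, 26,
                   20, 28, 10,  6, 30,
                   27,  1,  8,  9, 24,
                   18,  4, 23, 17, 13,
                   15, 22,  3, 11, 14,
                   25,  7, 12, 29, 21), nrow = 5)
foo2_abs(Lambda)

# A row with a large negative entry: the abs() variant keeps -9
foo2_abs(matrix(c(1, -9, 5, 2), nrow = 1))
```

On the question's matrix this reproduces the output shown there; on the second input it keeps 1 and -9, whereas a raw-value comparison would keep 1 and 5.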

(Updated) Benchmark:

set.seed(60)
Lambda = matrix(sample.int(36e5), ncol = 6)
bench::mark(
  foo(Lambda),
  proposal(Lambda),
  foo2(Lambda),
  relative = TRUE
)[1:5]

  expression         min median `itr/sec` mem_alloc
  <bch:expr>       <dbl>  <dbl>     <dbl>     <dbl>
1 foo(Lambda)       17.7   12.1      1.09      3.67
2 proposal(Lambda)  19.3   13.1      1         1   
3 foo2(Lambda)       1      1       13.5       3.75

sindri_baldur answered Sep 03 '25 21:09


Regarding your question, "Is there a solution not using a for loop?":

For some algorithms, you just have to write a for loop. And that's okay! Swapping a for loop for something like lapply is not really a performance improvement (see https://stackoverflow.com/a/42440872/4917834).
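
As a rough illustration of that point, here is the per-row task written once with a for loop and once with apply(). Both evaluate the same R function once per row, so the apply() form mainly changes the syntax, not the amount of work. The names zero_row, with_for and with_apply are made up for this sketch:

```r
# Per-row work: keep the first entry and the largest |value| of the rest
zero_row <- function(row) {
  m <- which.max(abs(row[-1L]))   # position counted from the 2nd entry
  row[-c(1L, m + 1L)] <- 0
  row
}

# Explicit for loop over the rows
with_for <- function(Lambda) {
  for (i in seq_len(nrow(Lambda))) {
    Lambda[i, ] <- zero_row(Lambda[i, ])
  }
  Lambda
}

# Same thing via apply(); apply() returns one column per input row,
# hence the transpose
with_apply <- function(Lambda) t(apply(Lambda, 1, zero_row))

# The matrix printed in the question, entered column by column
Lambda <- matrix(c(19, 16,  2,  5, 26,
                   20, 28, 10,  6, 30,
                   27,  1,  8,  9, 24,
                   18,  4, 23, 17, 13,
                   15, 22,  3, 11, 14,
                   25,  7, 12, 29, 21), nrow = 5)
stopifnot(all(with_for(Lambda) == with_apply(Lambda)))
```

Timing the two with bench::mark() on your own data is the way to see how they compare in practice.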

It is possible to speed up your code though:

# your example
set.seed(60)
Lambda = matrix(sample.int(30),5)

original <- function(Lambda) {
  m <- ncol(Lambda) - 2
  for (ir in seq_len(nrow(Lambda))){
    Lambda[ir, match(tail(sort(abs(Lambda[ir, 2:ncol(Lambda)]), decreasing = TRUE), m), Lambda[ir,])] <- 0
  }
  Lambda
}
original(Lambda)

# a faster alternative
proposal <- function(Lambda) {

  nc <- ncol(Lambda)
  for (i in seq_len(nrow(Lambda))) {
    m <- which.max(abs(Lambda[i, -1L]))
    Lambda[i, (2:nc)[-m]] <- 0
  }
  Lambda
}
proposal(Lambda)

Let's benchmark the two approaches:

bch <- bench::mark(
  original(Lambda),
  proposal(Lambda)
)
summary(bch, relative = TRUE)
# A tibble: 2 x 13
  expression              min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc 
  <bch:expr>            <dbl>  <dbl>     <dbl>     <dbl>    <dbl> <int> <dbl> 
1 original(Lambda)       25.7   24.1       1           1     1     1447     4
2 proposal(Lambda)        1      1        23.6         1     2.57  9997     3

So proposal is about 24 times faster than your original solution (the median time for original is 313.8µs; for proposal it's 13.1µs). If that's not fast enough, it might be worthwhile to look for a package that has implemented this in C or C++. I played around with matrixStats but had no luck. Alternatively, you could port this to C++ with Rcpp, which should also speed up the code.
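
For the Rcpp route, a minimal sketch could look like the following. It assumes the Rcpp package and a working C++ compiler are installed and that the matrix has at least one column; the function name keep_row_max_cpp is made up, and the sketch has not been benchmarked here:

```r
# Assumes Rcpp and a C++ toolchain are available
Rcpp::cppFunction('
NumericMatrix keep_row_max_cpp(NumericMatrix x) {
  int n = x.nrow(), p = x.ncol();
  NumericMatrix out(n, p);          // filled with zeros on construction
  for (int i = 0; i < n; ++i) {
    out(i, 0) = x(i, 0);            // always keep the first column
    if (p < 2) continue;
    int best = 1;                   // zero-based column of the largest |value| so far
    for (int j = 2; j < p; ++j) {
      if (std::fabs(x(i, j)) > std::fabs(x(i, best))) best = j;
    }
    out(i, best) = x(i, best);      // keep the winner, leave the rest at zero
  }
  return out;
}')

# The matrix printed in the question, entered column by column
Lambda <- matrix(c(19, 16,  2,  5, 26,
                   20, 28, 10,  6, 30,
                   27,  1,  8,  9, 24,
                   18,  4, 23, 17, 13,
                   15, 22,  3, 11, 14,
                   25,  7, 12, 29, 21), nrow = 5)
keep_row_max_cpp(Lambda)
```

The strict `>` comparison keeps the first of several tied maxima, which matches what which.max() does in proposal(). Note that the result is always a double matrix, even for integer input.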

Vandenman answered Sep 03 '25 21:09