Suppose I have a data frame as follows:
> foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
> foo
  x id
1 1  1
2 2  1
3 3  2
4 4  2
5 5  2
6 6  3
7 7  3
8 8  3
9 9  3
I want a very efficient implementation of h(a, b) that computes the sum of all products (a - xi)*(b - xj) over pairs xi, xj belonging to the same id class. My current implementation is:
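h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff%*%t(b.diff)
  # dist() is 0 for same-id pairs, which ifelse() maps to 1; as.matrix() leaves
  # the diagonal at 0, so diag() adds the i == j pairs back in
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = TRUE, upper = TRUE), 0, 1)) + diag(nrow(foo))
  return(sum(prod*id.indicator))
}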
For example, with (a, b) = (0, 1), here is the output of each step of the function:
> a.diff
[1] -1 -2 -3 -4 -5 -6 -7 -8 -9
> b.diff
[1]  0 -1 -2 -3 -4 -5 -6 -7 -8
> prod
      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
 [1,]    0    1    2    3    4    5    6    7    8
 [2,]    0    2    4    6    8   10   12   14   16
 [3,]    0    3    6    9   12   15   18   21   24
 [4,]    0    4    8   12   16   20   24   28   32
 [5,]    0    5   10   15   20   25   30   35   40
 [6,]    0    6   12   18   24   30   36   42   48
 [7,]    0    7   14   21   28   35   42   49   56
 [8,]    0    8   16   24   32   40   48   56   64
 [9,]    0    9   18   27   36   45   54   63   72
> id.indicator
  1 2 3 4 5 6 7 8 9
1 1 1 0 0 0 0 0 0 0
2 1 1 0 0 0 0 0 0 0
3 0 0 1 1 1 0 0 0 0
4 0 0 1 1 1 0 0 0 0
5 0 0 1 1 1 0 0 0 0
6 0 0 0 0 0 1 1 1 1
7 0 0 0 0 0 1 1 1 1
8 0 0 0 0 0 1 1 1 1
9 0 0 0 0 0 1 1 1 1
In reality, there can be up to 1000 id clusters, and each cluster will have at least 40 observations; with 1000 clusters of 40 rows, n = 40,000 and each dense n-by-n matrix has 1.6 billion entries. That makes this method far too inefficient: id.indicator is almost entirely zeros, and most of prod lies off the block diagonal, so it is computed only to be thrown away.
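For reference, the quantity being computed can be written out explicitly. In the notation below (introduced here for illustration), G_k is the set of rows with id k, n_k is its size, and S_k is the sum of x over those rows. Because each within-class double sum is the sum of an outer product, it factorizes into a product of two single sums:

```latex
h(a,b) = \sum_{k} \sum_{i \in G_k} \sum_{j \in G_k} (a - x_i)(b - x_j)
       = \sum_{k} \Big( \sum_{i \in G_k} (a - x_i) \Big) \Big( \sum_{j \in G_k} (b - x_j) \Big)
       = \sum_{k} (n_k a - S_k)(n_k b - S_k)
```

For the example data with (a, b) = (0, 1), the three classes give (-3)(-1) + (-12)(-9) + (-30)(-26) = 3 + 108 + 780 = 891, matching the block sums of prod * id.indicator above. So per-class sums of x are enough in principle, and the off-diagonal blocks of prod never need to be formed.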
I played around a bit. First, your implementation:
foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff%*%t(b.diff)
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T),0,1)) + 
     diag(nrow(foo))
  return(sum(prod*id.indicator))
}
h(a = 1, b = 0, foo = foo)
#[1] 891
Next, I tried a variant that uses a proper sparse matrix implementation (via the Matrix package) and builds the block-diagonal index matrix with bdiag. I also use tcrossprod, which I often find to be a bit faster than a %*% t(b).
library("Matrix")
h2 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  prod <- tcrossprod(a.diff, b.diff) # the same as a.diff%*%t(b.diff)
  id.indicator <- do.call(bdiag, lapply(table(foo$id), function(n) matrix(1,n,n)))
  return(sum(prod*id.indicator))
}
h2(a = 1, b = 0, foo = foo)
#[1] 891
Note that this function relies on the rows of foo being sorted by id, because bdiag lays out the blocks in the (sorted) order of table(foo$id).
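If the rows might not be sorted, one option is to sort them first with foo[order(foo$id), ]. Another option, sketched below under a hypothetical name (h.outer is not part of the answer above), is a dense base-R mask from outer(). It does not depend on row order, but it still builds full n-by-n matrices, so it shares the memory problem of h:

```r
# Hypothetical variant h.outer: the same-id mask comes from outer(),
# so the row order of foo does not matter.
h.outer <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  same.id <- outer(foo$id, foo$id, "==")     # logical n x n mask, TRUE where ids match
  sum(tcrossprod(a.diff, b.diff)[same.id])  # sum only the products of same-id pairs
}

foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
set.seed(1)
foo.shuffled <- foo[sample(nrow(foo)), ]    # same rows, ids no longer sorted
h.outer(a = 1, b = 0, foo = foo)            # 891
h.outer(a = 1, b = 0, foo = foo.shuffled)   # 891 as well
```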
Lastly, I tried to avoid creating the full n-by-n matrix.
h3 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  ids <- unique(foo$id)
  res <- 0
  for (i in seq_along(ids)) {
    indx <- which(foo$id == ids[i])
    res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
  }
  return(res)
}
h3(a = 1, b = 0, foo = foo)
#[1] 891
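One detail in h3: which(foo$id == ids[i]) scans all n rows once per id, so with 1000 clusters of 40 rows the index lookups alone take 1000 * 40,000 = 4 * 10^7 comparisons. A sketch of a variant that groups the row indices in a single pass with split() follows; the name h3.split is hypothetical, and whether it is measurably faster on real data is an assumption that would need benchmarking:

```r
# Hypothetical h3 variant: split() groups the row indices by id once,
# instead of one which() scan per id; ids need not be sorted.
h3.split <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  groups <- split(seq_len(nrow(foo)), foo$id)   # list: one vector of row indices per id
  block.sums <- vapply(groups,
                       function(indx) sum(tcrossprod(a.diff[indx], b.diff[indx])),
                       numeric(1))
  sum(block.sums)
}

foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
h3.split(a = 1, b = 0, foo = foo)   # 891
```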
Benchmarking on your example:
library("microbenchmark")
microbenchmark(h(a = 1, b = 0, foo = foo), 
               h2(a = 1, b = 0, foo = foo),
               h3(a = 1, b = 0, foo = foo))
# Unit: microseconds
#                        expr      min        lq      mean    median        uq       max neval
#  h(a = 1, b = 0, foo = foo)  248.569  261.9530  493.2326  279.3530  298.2825 21267.890   100
# h2(a = 1, b = 0, foo = foo) 4793.546 4893.3550 5244.7925 5051.2915 5386.2855  8375.607   100
# h3(a = 1, b = 0, foo = foo)  213.386  227.1535  243.1576  234.6105  248.3775   334.612   100
Now, in this example, h3 is the fastest and h2 is really slow. But I guess that both will be faster than h for larger examples. Probably h3 will still win for larger examples, though. While there is plenty of room for more optimization, h3 should be faster and more memory efficient. So I think you should go for a variant of h3 that does not create unnecessarily large matrices.
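Taking that one step further: inside each cluster, sum(tcrossprod(u, v)) equals sum(u) * sum(v), because the sum of an outer product factorizes. So even the per-cluster matrices in h3 can be skipped, and only per-id sums are needed, which is roughly linear in the number of rows. A minimal sketch using base R's rowsum(); the name h4 and the random test data are illustrative assumptions, not part of the answer above:

```r
# Hypothetical h4: per id, sum_{i,j} (a - x_i) * (b - x_j)
# = sum_i (a - x_i) * sum_j (b - x_j), so no matrix is ever built.
h4 <- function(a, b, foo) {
  a.sums <- rowsum(a - foo$x, foo$id)   # one row per id: sum of (a - x_i)
  b.sums <- rowsum(b - foo$x, foo$id)   # same id order: sum of (b - x_j)
  sum(a.sums * b.sums)
}

foo <- data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
h4(a = 1, b = 0, foo = foo)   # 891

# Cross-check against a brute-force version on small random, unsorted data
set.seed(42)
small <- data.frame(x = rnorm(200), id = sample(rep(1:10, each = 20)))
brute <- sum(tcrossprod(0.5 - small$x, -1 - small$x)[outer(small$id, small$id, "==")])
all.equal(h4(a = 0.5, b = -1, foo = small), brute)

# Data of the size described in the question: 1000 clusters of 40 rows
big <- data.frame(x = rnorm(40000), id = sample(rep(1:1000, each = 40)))
h4(a = 0.5, b = -1, foo = big)
```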