
How is np.repeat so fast?

I am implementing the Poisson bootstrap in Rust and wanted to benchmark my repeat function against numpy's. Briefly, repeat takes in two arguments, data and weight, and repeats each element of data by the weight, e.g. [1, 2, 3], [1, 2, 0] -> [1, 2, 2]. My naive version was around 4.5x slower than np.repeat.

pub fn repeat_by(arr: &[f64], repeats: &[u64]) -> Vec<f64> {
    // Use flat_map to create a single iterator of all repeated elements
    let result: Vec<f64> = arr
        .iter()
        .zip(repeats.iter())
        .flat_map(|(&value, &count)| std::iter::repeat_n(value, count as usize))
        .collect();

    result
}

I also tried a couple more versions, e.g. one where I pre-allocated a vector with the necessary capacity (roughly the sketch shown below), but all performed similarly.
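
For reference, the pre-allocated variant looked roughly like this (reconstructed sketch, not the exact code; the name repeat_by_prealloc is just for illustration):

// Reconstructed sketch of the pre-allocated variant: reserve the full output
// size up front, then push element by element.
pub fn repeat_by_prealloc(arr: &[f64], repeats: &[u64]) -> Vec<f64> {
    let total: usize = repeats.iter().map(|&c| c as usize).sum();
    let mut result = Vec::with_capacity(total);
    for (&value, &count) in arr.iter().zip(repeats.iter()) {
        for _ in 0..count {
            result.push(value);
        }
    }
    result
}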

While investigating further, though, I found that np.repeat is actually way faster than other numpy functions that I expected to perform similarly. For example, we can build a list of indices and use numpy slicing / take to perform the same operation as np.repeat. However, even doing this (and removing the list construction from the timings), np.repeat is around 3x faster than numpy slicing / take.

import timeit

import numpy as np

N_ROWS = 100_000
x = np.random.rand(N_ROWS)

weight = np.random.poisson(1, len(x))


# pre-compute the indices so slow python looping doesn't affect the timing
indices = []
for i, w in enumerate(weight):
    for _ in range(w):
        indices.append(i)


print(timeit.timeit(lambda: np.repeat(x, weight), number=1_000))  # 0.8337333500003297
print(timeit.timeit(lambda: np.take(x, indices), number=1_000))  # 3.1320624930012855

My C is not so good, but it seems like the relevant implementation is here: https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/item_selection.c#L785. It would be amazing if someone could help me understand at a high level what this code is doing--on the surface, it doesn't look like anything particularly special (SIMD, etc.), and looks pretty similar to my naive Rust version (memcpy vs repeat_n). In addition, I am struggling to understand why it performs so much better than even numpy slicing.

asked Aug 31 '25 by stressed


1 Answer

TL;DR: the gap is certainly due to Numpy using wider loads/stores than your Rust code, and you should avoid indexing if you can, for the sake of performance.


Performance of the Numpy code VS your Rust code

First of all, we can analyse the assembly code generated from your Rust code (I am not very familiar with Rust but I am with assembly). The generated code is quite big, but here is the main part (see it on Godbolt):

example::repeat_by::hf03ad1ea376407dc:
        push    rbp
        push    r15
        push    r14
        push    r13
        push    r12
        push    rbx
        sub     rsp, 72
        mov     r12, rdx
        cmp     r8, rdx
        cmovb   r12, r8
        test    r12, r12
        je      .LBB2_4
        mov     r14, rcx
        mov     r15, r12
        neg     r15
        mov     ebx, 1
.LBB2_2:
        mov     r13, qword ptr [r14 + 8*rbx - 8]
        test    r13, r13
        jne     .LBB2_5
        lea     rax, [r15 + rbx]
        inc     rax
        inc     rbx
        cmp     rax, 1
        jne     .LBB2_2
.LBB2_4:
        mov     qword ptr [rdi], 0
        mov     qword ptr [rdi + 8], 8
        mov     qword ptr [rdi + 16], 0
        jmp     .LBB2_17
.LBB2_5:
        mov     qword ptr [rsp + 48], rsi
        mov     qword ptr [rsp + 56], rdi
        cmp     r13, 5
        mov     ebp, 4
        cmovae  rbp, r13
        lea     rcx, [8*rbp]
        mov     rax, r13
        shr     rax, 61
        jne     .LBB2_6
        mov     qword ptr [rsp + 8], 0
        movabs  rax, 9223372036854775800
        cmp     rcx, rax
        ja      .LBB2_7
        mov     rax, qword ptr [rsp + 48]
        mov     rax, qword ptr [rax + 8*rbx - 8]
        mov     qword ptr [rsp + 16], rax
        mov     rax, qword ptr [rip + __rust_no_alloc_shim_is_unstable@GOTPCREL]
        movzx   eax, byte ptr [rax]
        mov     eax, 8
        mov     qword ptr [rsp + 8], rax
        mov     esi, 8
        mov     rdi, rcx
        mov     qword ptr [rsp + 64], rcx
        call    qword ptr [rip + __rust_alloc@GOTPCREL]
        mov     rcx, qword ptr [rsp + 64]
        test    rax, rax
        je      .LBB2_7
        mov     rcx, qword ptr [rsp + 16]
        mov     qword ptr [rax], rcx
        mov     qword ptr [rsp + 24], rbp
        mov     qword ptr [rsp + 32], rax
        mov     qword ptr [rsp + 40], 1
        mov     ebp, 1
        jmp     .LBB2_11
.LBB2_22:
        mov     rcx, qword ptr [rsp + 16]
        mov     qword ptr [rax + 8*rbp], rcx
        inc     rbp
        mov     qword ptr [rsp + 40], rbp
.LBB2_11:
        dec     r13
        je      .LBB2_12
        cmp     rbp, qword ptr [rsp + 24]
        jne     .LBB2_22
.LBB2_20:
        lea     rdi, [rsp + 24]
        mov     rsi, rbp
        mov     rdx, r13
        call    alloc::raw_vec::RawVecInner<A>::reserve::do_reserve_and_handle::hd90f8297b476acb7
        mov     rax, qword ptr [rsp + 32]
        jmp     .LBB2_22
.LBB2_12:
        cmp     rbx, r12
        jae     .LBB2_16
        inc     rbx
.LBB2_14:
        mov     r13, qword ptr [r14 + 8*rbx - 8]
        test    r13, r13
        jne     .LBB2_18
        lea     rcx, [r15 + rbx]
        inc     rcx
        inc     rbx
        cmp     rcx, 1
        jne     .LBB2_14
        jmp     .LBB2_16
.LBB2_18:
        mov     rcx, qword ptr [rsp + 48]
        mov     rcx, qword ptr [rcx + 8*rbx - 8]
        mov     qword ptr [rsp + 16], rcx
        cmp     rbp, qword ptr [rsp + 24]
        jne     .LBB2_22
        jmp     .LBB2_20
.LBB2_16:
        mov     rax, qword ptr [rsp + 40]
        mov     rdi, qword ptr [rsp + 56]
        mov     qword ptr [rdi + 16], rax
        movups  xmm0, xmmword ptr [rsp + 24]
        movups  xmmword ptr [rdi], xmm0
.LBB2_17:
        mov     rax, rdi
        add     rsp, 72
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        pop     rbp
        ret
.LBB2_6:
        mov     qword ptr [rsp + 8], 0
.LBB2_7:
        lea     rdx, [rip + .L__unnamed_2]
        mov     rdi, qword ptr [rsp + 8]
        mov     rsi, rcx
        call    qword ptr [rip + alloc::raw_vec::handle_error::h5290ea7eaad4c986@GOTPCREL]
        mov     rbx, rax
        mov     rsi, qword ptr [rsp + 24]
        test    rsi, rsi
        je      .LBB2_25
        mov     rdi, qword ptr [rsp + 32]
        shl     rsi, 3
        mov     edx, 8
        call    qword ptr [rip + __rust_dealloc@GOTPCREL]
.LBB2_25:
        mov     rdi, rbx
        call    _Unwind_Resume@PLT

We can see that there is only a single use of SIMD (xmm, ymm or zmm) registers, and it is not in a loop. There is also no call to memcpy. This means the Rust computation is certainly not vectorised using SIMD instructions. The loops seem to move at most 64-bit items. The SSE (SIMD) instruction set can move 128-bit vectors and the AVX (SIMD) one 256-bit vectors (512-bit for AVX-512, which is only supported on a few recent PC CPUs and most recent server ones). As a result, the Rust code is certainly sub-optimal because it performs scalar moves.

On the other hand, Numpy basically calls memcpy in nested loops (in the linked code) as long as needs_custom_copy is false, which I think is the case for all basic contiguous native arrays like the one computed in your code (i.e. no pure-Python objects in the array). memcpy is generally aggressively optimised, so it benefits from SIMD instructions on platforms where it is worth it. For very small copies, it can be slower than scalar moves though (due to the call and sometimes some checks).

I expect the Rust code to be about 4 times slower than Numpy on a CPU supporting AVX2 (assuming the target CPU actually has a 256-bit-wide data path, which is AFAIK the case on relatively recent mainstream CPUs), as long as the copied slices are rather big (e.g. at least a few dozen double-precision items).

In short, the gap is certainly due to the (indirect) use of wide SIMD loads/stores in Numpy, as opposed to the Rust code using less-efficient scalar loads/stores.
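
If you want to see whether wider stores help on the Rust side, one option is to give the compiler the chance to emit them by filling each run in bulk instead of pushing elements one by one. The sketch below is untested and the name repeat_by_fill is mine; whether it actually gets vectorised depends on the target, but Vec::resize on a pre-reserved buffer lets LLVM generate wide (SIMD) stores for each run:

// A sketch, not a benchmarked fix: pre-allocate the whole output, then append
// each run of identical values with a single resize call (a splat/fill)
// rather than `count` scalar pushes.
pub fn repeat_by_fill(arr: &[f64], repeats: &[u64]) -> Vec<f64> {
    let total: usize = repeats.iter().map(|&c| c as usize).sum();
    let mut out = Vec::with_capacity(total);
    for (&value, &count) in arr.iter().zip(repeats.iter()) {
        // The buffer is already large enough, so this only writes the new run.
        out.resize(out.len() + count as usize, value);
    }
    out
}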


Performance of np.repeat VS np.take

I found that np.repeat is actually way faster than other numpy functions that I expected to perform similarly. [...] np.repeat is around 3x faster than numpy slicing / take.

Regarding np.take, it is more expensive because it cannot really benefit from SIMD instructions and Numpy also needs to read the indices from memory. To be more precise, on x86-64 CPUs, AVX2 and AVX-512 provide gather instructions for this, but they are not so fast compared to scalar loads (possibly even slower, depending on the actual target micro-architecture). For example, on AMD Zen+/Zen2/Zen3/Zen4 CPUs, gather instructions are not worth it (not faster), mainly because the underlying hardware implementation is not efficient yet (micro-coded). On relatively recent Intel CPUs supporting AVX2, gather instructions are a bit faster, especially for 32-bit items and 32-bit addresses; they are not really worth it for 64-bit ones (which is your use-case). On Intel CPUs supporting AVX-512 (mainly Ice Lake CPUs and server CPUs), they are worth it for both 32-bit and 64-bit items. x86-64 CPUs not supporting AVX2 (i.e. old ones) do not have gather instructions at all. Even the best (x86-64) gather implementation cannot compete with the (256-bit or 512-bit) packed loads/stores typically done by memcpy in np.repeat on wide slices, simply because all mainstream CPUs perform the gathered loads as scalar loads (i.e. <=64-bit) internally, saturating the load ports. Some memcpy implementations use rep movsb, which is very well optimised on quite recent x86-64 CPUs (adapting the load/store granularity to the use-case and even using streaming stores when needed on wide arrays).
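
To make the access-pattern difference concrete, here is a schematic Rust illustration (my own sketch, not Numpy's actual code) of what take-style indexing boils down to: every output element needs an index load followed by a data load at an unpredictable location, which maps to scalar loads or gathers, while the repeat path only writes contiguous runs.

// Schematic illustration of take-style indexing (not Numpy's code).
pub fn take(src: &[f64], indices: &[usize]) -> Vec<f64> {
    // One index load plus one data-dependent load per output element;
    // the data loads hit unpredictable addresses, so they cannot be merged
    // into wide contiguous loads the way a memcpy-style run copy can.
    indices.iter().map(|&i| src[i]).collect()
}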

Even on GPUs (which have an efficient gather implementation), gather instructions are still generally more expensive than packed loads. They are at best equally fast, but one also needs to account for the overhead of reading the indices from memory, so indexing can never be faster.

In short, you should avoid indexing if you can, since it is not very SIMD-friendly.

answered Sep 02 '25 by Jérôme Richard