I've written this very simple Rust function:
fn iterate(nums: &Box<[i32]>) -> i32 {
    let mut total = 0;
    let len = nums.len();
    for i in 0..len {
        if nums[i] > 0 {
            total += nums[i];
        } else {
            total -= nums[i];
        }
    }
    total
}
I've written a basic benchmark that invokes the method with an ordered array and a shuffled one:
use criterion::{criterion_group, criterion_main, Criterion};
use rand::seq::SliceRandom;
use rand::thread_rng;

fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: i32 = 1024 * 1024;
    let mut group = c.benchmark_group("Branch Prediction");

    // setup benchmarking for an ordered array
    let mut ordered_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        ordered_nums.push(i - SIZE / 2);
    }
    let ordered_nums = ordered_nums.into_boxed_slice();
    group.bench_function("ordered", |b| b.iter(|| iterate(&ordered_nums)));

    // setup benchmarking for a shuffled array
    let mut shuffled_nums: Vec<i32> = vec![];
    for i in 0..SIZE {
        shuffled_nums.push(i - SIZE / 2);
    }
    let mut rng = thread_rng();
    let mut shuffled_nums = shuffled_nums.into_boxed_slice();
    shuffled_nums.shuffle(&mut rng);
    group.bench_function("shuffled", |b| b.iter(|| iterate(&shuffled_nums)));

    group.finish();
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
I'm surprised that the two benchmarks have almost exactly the same runtime, while a similar benchmark in Java shows a distinct difference between the two, presumably due to branch prediction failure in the shuffled case.
I've seen mention of conditional move instructions, but if I otool -tv the executable (I'm running on a Mac), I don't see any in the iterate method output.
Can anyone shed light on why there's no perceptible performance difference between the ordered and the unordered cases in Rust?
Summary: LLVM was able to remove/hide the branch by using either the cmov instruction or a really clever combination of SIMD instructions.
I used Godbolt to view the full assembly (with -C opt-level=3). I will explain the important parts of the assembly below.
It starts like this:
mov r9, qword ptr [rdi + 8] ; r9 = nums.len()
test r9, r9 ; if len == 0
je .LBB0_1 ; goto LBB0_1
mov rdx, qword ptr [rdi] ; rdx = base pointer (first element)
cmp r9, 7 ; if len > 7
ja .LBB0_5 ; goto LBB0_5
xor eax, eax ; eax = 0
xor esi, esi ; esi = 0
jmp .LBB0_4 ; goto LBB0_4
.LBB0_1:
xor eax, eax ; return 0
ret
Here, the function differentiates between three different "states":

- the slice is empty: return 0 immediately (.LBB0_1)
- the slice has at most 7 elements: use a simple scalar loop (.LBB0_4)
- the slice has more than 7 elements: use a vectorized SIMD loop (.LBB0_5)

So let's take a look at the two different kinds of algorithms!
Remember that rsi (esi) and rax (eax) were set to 0 and that rdx is the base pointer to the data.
.LBB0_4:
mov ecx, dword ptr [rdx + 4*rsi] ; ecx = nums[rsi]
add rsi, 1 ; rsi += 1
mov edi, ecx ; edi = ecx
neg edi ; edi = -edi
cmovl edi, ecx ; if ecx >= 0 { edi = ecx }
add eax, edi ; eax += edi
cmp r9, rsi ; if rsi != len
jne .LBB0_4 ; goto LBB0_4
ret ; return eax
This is a simple loop iterating over all elements of nums. In the loop's body there is a little trick though: from the original element ecx, a negated copy is stored in edi. By using cmovl, edi is overwritten with the original value if that original value is positive. That means edi always ends up non-negative (i.e. it contains the absolute value of the original element). It is then added to eax (which is returned at the end).
So your if branch was hidden in the cmov instruction. As you can see in this benchmark, the time required to execute a cmov instruction is independent of the probability of the condition. It's a pretty amazing instruction!
The SIMD version consists of quite a few instructions that I won't fully paste here. The main loop handles 16 integers at once!
movdqu xmm5, xmmword ptr [rdx + 4*rdi]
movdqu xmm3, xmmword ptr [rdx + 4*rdi + 16]
movdqu xmm0, xmmword ptr [rdx + 4*rdi + 32]
movdqu xmm1, xmmword ptr [rdx + 4*rdi + 48]
They are loaded from memory into the registers xmm0, xmm1, xmm3 and xmm5. Each of those registers contains four 32 bit values, but to follow along more easily, just imagine each register contains exactly one value. All following instructions operate on each value of those SIMD registers individually, so that mental model is fine! My explanation below will also sound as if xmm registers would only contain a single value.
The main trick is now in the following instructions (which handle xmm5):
movdqa xmm6, xmm5 ; xmm6 = xmm5 (make a copy)
psrad xmm6, 31 ; arithmetic right shift by 31 bits (see below)
paddd xmm5, xmm6 ; xmm5 += xmm6
pxor xmm5, xmm6 ; xmm5 ^= xmm6
The arithmetic right shift fills the "empty high-order bits" (the ones "shifted in" on the left) with the value of the sign bit. By shifting by 31, we end up with only the sign bit in every position! So any positive number will turn into 32 zeroes and any negative number will turn into 32 ones. So xmm6 is now either 000...000 (if xmm5 is positive) or 111...111 (if xmm5 is negative).
Next this artificial xmm6 is added to xmm5. If xmm5 was positive, xmm6 is 0, so adding it won't change xmm5. If xmm5 was negative, however, we add 111...111 which is equivalent to subtracting 1. Finally, we xor xmm5 with xmm6. Again, if xmm5 was positive in the beginning, we xor with 000...000 which does not have an effect. If xmm5 was negative in the beginning we xor with 111...111, meaning we flip all the bits. So for both cases:
- if xmm5 was positive, the add and xor didn't have any effect;
- if xmm5 was negative, we subtracted 1 and then flipped all bits, which is exactly how negation works in two's complement.

So with these 4 instructions we calculated the absolute value of xmm5! Here again, there is no branch because of this bit-fiddling trick. And remember that xmm5 actually contains 4 integers, so it's quite speedy!
This absolute value is now added to an accumulator and the same is done with the three other xmm registers that contain values from the slice. (We won't discuss the remaining code in detail.)
If we allow LLVM to emit AVX2 instructions (via -C target-feature=+avx2), it can even use the pabsd instruction instead of the four "hacky" instructions:
vpabsd ymm2, ymmword ptr [rdx + 4*rdi]
It loads the values directly from memory, calculates the absolute value and stores it in ymm2 in one instruction! And remember that ymm registers are twice as large as xmm registers (fitting eight 32-bit values)!