Understanding Branch Prediction Optimizations in Python

I was recently working on a simple LeetCode problem: write a function that returns True if a list of numbers contains a duplicate, and False otherwise. The code I submitted was so simple that I was sure it must be the fastest solution, but it came back at roughly the 40th percentile for performance. I looked at one of the faster solutions, and the only difference was an else that I had omitted: the corresponding if branch already returns, so logically the else is unnecessary. Benchmarking this code locally, I see a consistent speed improvement of about 20% with the else in place, but only for certain (random) test cases. The script below includes test cases that do and do not show this improvement.

import random
import time

def containsDuplicate1(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            return True
        seen.add(num)
    return False

def containsDuplicate2(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            return True
        else:
            seen.add(num)
    return False

# When duplicates are rare or non-existent, we get no difference
time1 = 0
time2 = 0
for i in range(10000):
    nums = list(range(1000))
    start = time.time()
    containsDuplicate1(nums)
    time1 += (time.time() - start)
    start = time.time()
    containsDuplicate2(nums)
    time2 += (time.time() - start)
print("Test case 1: No duplicates")
print(f"Execution time for 'no else clause':    {time1}")
print(f"Execution time for 'else clause':       {time2}")
print(f"Percentage speed improvement:           {100 * (time1 - time2) / time1:.2f}%")

# When the numbers are random but are the same on every run, the difference is negligible (easier to predict?)
time1 = 0
time2 = 0
nums = [random.randint(0, 1000) for _ in range(1000)]
for i in range(10000):
    start = time.time()
    containsDuplicate1(nums)
    time1 += (time.time() - start)
    start = time.time()
    containsDuplicate2(nums)
    time2 += (time.time() - start)
print("Test case 2: Random numbers, but same every time")
print(f"Execution time for 'no else clause':    {time1}")
print(f"Execution time for 'else clause':       {time2}")
print(f"Percentage speed improvement:           {100 * (time1 - time2) / time1:.2f}%")

# But when numbers are random and change on every run (with frequent duplicates), the difference is significant
time1 = 0
time2 = 0
for i in range(10000):
    nums = [random.randint(0, 1000) for _ in range(1000)]
    start = time.time()
    containsDuplicate1(nums)
    time1 += (time.time() - start)
    start = time.time()
    containsDuplicate2(nums)
    time2 += (time.time() - start)
print("Test case 3: Random numbers, different every time")
print(f"Execution time for 'no else clause':    {time1}")
print(f"Execution time for 'else clause':       {time2}")
print(f"Percentage speed improvement:           {100 * (time1 - time2) / time1:.2f}%")

Output:

Test case 1: No duplicates
Execution time for 'no else clause':    0.19391775131225586
Execution time for 'else clause':       0.19402170181274414
Percentage speed improvement:           -0.05%
Test case 2: Random numbers, but same every time
Execution time for 'no else clause':    0.0033860206604003906
Execution time for 'else clause':       0.0034241676330566406
Percentage speed improvement:           -1.13%
Test case 3: Random numbers, different every time
Execution time for 'no else clause':    0.02005624771118164
Execution time for 'else clause':       0.016332626342773438
Percentage speed improvement:           18.57%

I'm guessing that this improvement has something to do with the CPU's branch prediction, but I have a few questions about that:

  1. When I generate the bytecode for both of these functions, it looks identical. I don't know enough about Python compilation/execution to understand when the source code would be used instead of the bytecode, but something must be referring to the source code in some way to explain the difference when the bytecode is identical.
  2. I understand why test case 1 would not benefit from branch prediction, since there is never a duplicate and we always miss the if condition. I assume that test case 2 is repetitive enough that the compiler or processor spots a pattern and we no longer benefit from branch prediction. But for test case 3, where the inputs are random every time, I don't really see how branch prediction helps. Is it the case that in the first example, the lack of an else prevents branch prediction from happening at all? Or is the first version always starting on the set operation that must be discarded sometimes because it can't predict the branching? I'm just looking for some intuition around what might be going on to cause the consistent improvement for test case 3.
asked Jun 24 '26 13:06 by tgarvz


2 Answers

Your script gives me similar results: consistently about a 26% speed difference in the third case, even when I call the same function both times. I still don't know why, but I eliminated some potential culprits.

I run the same function ten times for each input and get results like these (with CPython 3.13):

64.5 ms
54.6 ms  (-15.35%)
51.9 ms  (-4.90%)
50.7 ms  (-2.42%)
49.6 ms  (-2.15%)
49.3 ms  (-0.59%)
49.5 ms  (+0.37%)
49.5 ms  (-0.01%)
49.2 ms  (-0.54%)
49.1 ms  (-0.20%)

The second, third and fourth calls fairly consistently get faster, by about 15%, 5%, and 2% respectively. After that, there's no clear pattern.

Things I did:

  • Use time.perf_counter for better timing.
  • Disable the garbage collector, so it doesn't interfere.
  • Use just one function, as your two functions aren't what makes the difference.
  • Pre-generate all inputs, so their generation doesn't interfere.
  • Sort the list of inputs by the index of the first duplicate, so that consecutive inputs are similar. The first call for each input might be at a disadvantage because something like branch prediction is still tuned to the previous input, so I want consecutive inputs to behave alike.
  • Iterate each input before measurements, to get it into caches.

And despite all that, I still see the speedups shown above. Without my for _ in nums: pass warm-up loop, I get about a 26% speedup from the first to the second call instead of 15%, so caching the data does seem to play a role. The remaining 15%? Maybe it does have to do with branch prediction, though not, as you thought, because of your two different Python source codes, but at a deeper level, inside the interpreter.

My script (Attempt This Online!):

import random
from time import perf_counter as time
import gc

gc.disable()

def containsDuplicate(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            return True
        seen.add(num)
    return False

def duplicateIndex(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            break
        seen.add(num)
    return len(seen)

m = 10
times = [0] * m

inputs = [
    [random.randint(0, 1000) for _ in range(1000)]
    for i in range(10000)
]
inputs.sort(key=duplicateIndex)

for nums in inputs:

    for _ in nums:
        pass

    for i in range(m):
        start = time()
        containsDuplicate(nums)
        times[i] += time() - start

prev = None
for t in times:
    print(f"{t * 1e3 :.1f} ms" + (f"  ({(t - prev) / prev :+.2%})" if prev else ''))
    prev = t
answered Jun 27 '26 02:06 by no comment


def containsDuplicate1(nums): ...

def containsDuplicate2(nums): ...

In CPython 3.13.1, these are identical. Compare the dis output for both of them to see that the slightly different source code produced identical bytecode.
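A minimal way to run that comparison yourself, sketched with the standard dis module. The helper same_bytecode is hypothetical (not from either post); it compares opcode names and arguments while ignoring line numbers, which necessarily differ because the else occupies its own source line. Whether it prints True may depend on your CPython version; the answer above reports identical bytecode on 3.13.1.

```python
import dis


def containsDuplicate1(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            return True
        seen.add(num)
    return False


def containsDuplicate2(nums):
    seen = set([])
    for num in nums:
        if num in seen:
            return True
        else:
            seen.add(num)
    return False


def same_bytecode(f, g):
    """Hypothetical helper: True if f and g compile to the same opcodes and arguments.

    Line numbers and source positions are deliberately left out of the comparison.
    """
    def ops(func):
        return [(ins.opname, ins.argval) for ins in dis.get_instructions(func)]
    return ops(f) == ops(g)


print(same_bytecode(containsDuplicate1, containsDuplicate2))
```

To see the listings side by side, dis.dis(containsDuplicate1) and dis.dis(containsDuplicate2) print the full disassembly, including the line-number column that this helper ignores.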

Benchmarking is hard. Your benching setup produced an artifact. The two functions did not start in the identical state: one needed an extra GC or found different data populating the L3 cache.

Please seed() your PRNG, as an aid to reproducibility. And follow PEP 8's naming advice.
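A sketch of a less artifact-prone harness along those lines, assuming the goal is only to compare the two variants. It seeds a private PRNG, pre-generates the inputs, disables the garbage collector while timing, uses time.perf_counter, and alternates which function runs first so neither one consistently benefits from warm-up done by the other. The name bench, its parameters, and the snake_case copies of the functions are illustrative, not taken from either post.

```python
import gc
import random
from time import perf_counter


def contains_duplicate_1(nums):
    seen = set()
    for num in nums:
        if num in seen:
            return True
        seen.add(num)
    return False


def contains_duplicate_2(nums):
    seen = set()
    for num in nums:
        if num in seen:
            return True
        else:
            seen.add(num)
    return False


def bench(funcs, n_inputs=1000, size=1000, seed=42):
    """Hypothetical harness: total seconds per function over identical, seeded inputs."""
    rng = random.Random(seed)  # private, seeded PRNG so every run sees the same inputs
    inputs = [[rng.randint(0, size) for _ in range(size)] for _ in range(n_inputs)]
    totals = {f.__name__: 0.0 for f in funcs}
    gc_was_enabled = gc.isenabled()
    gc.disable()  # keep collections out of the timed region
    try:
        for i, nums in enumerate(inputs):
            # Alternate the call order so each function runs first half the time.
            order = funcs if i % 2 == 0 else funcs[::-1]
            for f in order:
                start = perf_counter()
                f(nums)
                totals[f.__name__] += perf_counter() - start
    finally:
        if gc_was_enabled:
            gc.enable()  # restore the caller's GC state
    return totals


if __name__ == "__main__":
    for name, t in bench([contains_duplicate_1, contains_duplicate_2]).items():
        print(f"{name}: {t * 1e3:.1f} ms")
```

If the two totals converge once the order alternates, that points to a setup artifact rather than anything about the else itself.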


This is an odd assignment: seen = set([]). Prefer seen = set(), as passing a zero-length iterable has no effect.

Take advantage of that iterable. This runs at compiled ("C") speed, rather than at interpreted bytecode speed:

def contains_duplicate3(nums: list[float]) -> bool:
    seen = set(nums)
    return len(seen) < len(nums)

This, alas, does not have an early exit, so version 1 will beat it for certain workloads. Consider having version 1 define a chunk_size so it will add maybe a hundred elements at a time, bailing early if that chunk introduced a duplicate. Play with the size setting to find the sweet spot that works well for your favorite workload. It's related to the ratio of compiled speed to interpreted speed.
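One possible reading of that chunking suggestion, as a sketch: feed the set one slice at a time with set.update (which loops in C), and stop as soon as the set grows by fewer elements than the slice holds, since that can only happen if the slice repeated a value. The name contains_duplicate_chunked and the default chunk_size=100 are illustrative choices, not tuned values.

```python
def contains_duplicate_chunked(nums, chunk_size=100):
    """Hypothetical chunked variant: C-speed set.update per chunk, early exit between chunks."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    seen = set()
    for start in range(0, len(nums), chunk_size):
        chunk = nums[start:start + chunk_size]
        before = len(seen)
        seen.update(chunk)  # adds the whole slice without a Python-level loop
        # The set grew by less than the slice length only if some value in the
        # slice was already in `seen` or appeared twice within the slice.
        if len(seen) - before < len(chunk):
            return True
    return False


print(contains_duplicate_chunked(list(range(1000))))          # False
print(contains_duplicate_chunked(list(range(1000)) + [999]))  # True
```

Smaller chunks bail out sooner when duplicates are frequent; larger chunks spend less time in the interpreted loop, which is the trade-off the sweet-spot tuning is about.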


Timsort is very fast. If you're willing to sort nums in place, or produce a sorted copy of it, you might prefer the simplicity of walking the sorted list and comparing each entry against the previous one. Though again, "early exit" is attractive: it can beat this "unconditionally sort everything" approach.
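A sketch of the sort-then-compare-neighbours idea, assuming a sorted copy is acceptable; the name contains_duplicate_sorted is hypothetical. Calling nums.sort() instead of sorted(nums) would avoid the copy if mutating the caller's list is fine.

```python
def contains_duplicate_sorted(nums):
    """Sort a copy, then compare each entry with its predecessor."""
    ordered = sorted(nums)  # Timsort; leaves the caller's list untouched
    for prev, cur in zip(ordered, ordered[1:]):
        if prev == cur:  # equal values end up adjacent after sorting
            return True
    return False


print(contains_duplicate_sorted([3, 1, 2, 3]))  # True
print(contains_duplicate_sorted([3, 1, 2]))     # False
```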

answered Jun 27 '26 01:06 by J_H


