
Is there an efficient way to determine if a sum of floats will be order invariant?

Due to precision limitations in floating point numbers, the order in which numbers are summed can affect the result.

>>> 0.3 + 0.4 + 2.8
3.5
>>> 2.8 + 0.4 + 0.3
3.4999999999999996

This small error can become a bigger problem if the results are then rounded.

>>> round(0.3 + 0.4 + 2.8)
4
>>> round(2.8 + 0.4 + 0.3)
3
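
Printing the exact stored values makes the cause visible. The following is a minimal sketch using decimal.Decimal, which converts a float without any rounding; the names a and b are illustrative only.

```python
from decimal import Decimal

a = 0.3 + 0.4 + 2.8
b = 2.8 + 0.4 + 0.3

# Decimal(float) shows the stored binary value exactly, with no rounding.
print(Decimal(a))  # lands exactly on the .5 tie
print(Decimal(b))  # lands just below the tie

# round() sends the exact tie to the nearest even integer (4) and the
# slightly smaller value down to 3.
print(round(a), round(b))
```

One sum sits exactly on the rounding boundary and the other just under it, which is why such a tiny difference changes the rounded result.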

I would like to generate a list of random floats such that their rounded sum does not depend on the order in which the numbers are summed. My current brute force approach is O(n!). Is there a more efficient method?

import random
import itertools
import math


def gen_sum_safe_seq(func, length: int, precision: int) -> list[float]:
    """
    Return a list of floats that has the same sum when rounded to the given
    precision regardless of the order in which its values are summed.
    """
    invalid = True
    while invalid:
        invalid = False
        nums = [func() for _ in range(length)]
        first_sum = round(sum(nums), precision)
        for p in itertools.permutations(nums):
            if round(sum(p), precision) != first_sum:
                invalid = True
                print(f"rejected {nums}")
                break

    return nums

for _ in range(3):
    nums = gen_sum_safe_seq(
        func=lambda: round(random.gauss(3, 0.5), 3),
        length=10,
        precision=2,
    )
    print(f"{nums} sum={sum(nums)}")

For context, as part of a programming exercise I'm providing a list of floats that model a measured value over time to ~1000 entry-level programming students. They will sum them in a variety of ways. Provided that their code is correct, I'd like for them all to get the same result to simplify checking their code. I do not want to introduce the complexities of floating point representation to students at this level.

John Cole asked Dec 12 '25 11:12


2 Answers

Not that I know of, but a practical approach is to use math.fsum() instead. While some platforms are perverse nearly beyond repair, on most platforms fsum() returns the infinitely precise result subject to a single rounding error at the end, which means the final result is independent of the order in which the elements are given. For example,

>>> from math import fsum
>>> from itertools import permutations
>>> for p in permutations([0.3, 0.4, 2.8]):
...     print(p, fsum(p))
(0.3, 0.4, 2.8) 3.5
(0.3, 2.8, 0.4) 3.5
(0.4, 0.3, 2.8) 3.5
(0.4, 2.8, 0.3) 3.5
(2.8, 0.3, 0.4) 3.5
(2.8, 0.4, 0.3) 3.5

Python's fsum() docs go on to point to slower ways that are more robust against perverse platform quirks.
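
One slower alternative that does not rely on platform floating-point behaviour is exact rational arithmetic with fractions.Fraction. This is offered as an assumption about what such an approach could look like, not as what the fsum() docs themselves recommend, and exact_sum is a hypothetical helper name.

```python
from fractions import Fraction
from itertools import permutations
from math import fsum

def exact_sum(xs):
    """Hypothetical helper: sum floats exactly, then round once to a float."""
    # Fraction(x) represents the float's binary value exactly, so the
    # running total never loses bits; float() rounds only at the very end.
    return float(sum(map(Fraction, xs)))

# Every ordering produces the same float, because no intermediate rounding occurs.
results = {exact_sum(p) for p in permutations([0.3, 0.4, 2.8])}
print(results)
```

Like fsum(), this only helps if everyone summing the data actually uses it instead of the built-in sum().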

Arguably silly

Here's another approach: fiddle the numbers you generate, clearing enough low-order bits so that no rounding of any kind is ever needed no matter how an addition tree is arranged. I haven't thought hard about this - it's not worth the effort ;-) For a start, I haven't thought about negative inputs at all.

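def crunch(xs):
    from math import floor, ulp, ldexp
    if any(x < 0.0 for x in xs):
        raise ValueError("all elements must be >= 0.0")
    # Common grid spacing: the ulp of the largest input, scaled up by
    # 2**(bit length of len(xs)) so every partial sum stays exactly
    # representable as a multiple of that spacing.
    target_ulp = ldexp(ulp(max(xs)), len(xs).bit_length())
    # Round each input down to a multiple of the spacing.
    return [floor(x / target_ulp) * target_ulp
            for x in xs]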

Then, e.g.,

>>> xs = crunch([0.3, 0.4, 2.8])
>>> for x in xs:
...     print(x, x.hex())
0.29999999999999893 0x1.3333333333320p-2
0.3999999999999986 0x1.9999999999980p-2
2.799999999999999 0x1.6666666666664p+1

The decimal values are "a mess", but the hex values show that the binary values reliably have enough low-order 0 bits to absorb any shifts that may be needed during a sum. The order of summation then makes no difference:

>>> for p in permutations(xs):
...     print(p, sum(p))
(0.29999999999999893, 0.3999999999999986, 2.799999999999999) 3.4999999999999964
(0.29999999999999893, 2.799999999999999, 0.3999999999999986) 3.4999999999999964
(0.3999999999999986, 0.29999999999999893, 2.799999999999999) 3.4999999999999964
(0.3999999999999986, 2.799999999999999, 0.29999999999999893) 3.4999999999999964
(2.799999999999999, 0.29999999999999893, 0.3999999999999986) 3.4999999999999964
(2.799999999999999, 0.3999999999999986, 0.29999999999999893) 3.4999999999999964

And on a larger input:

>>> import random, math
>>> xs = [random.random() * 1e3 for i in range(100_000)]
>>> sum(xs)
49872035.43787267
>>> math.fsum(xs) # different
49872035.43787304
>>> sum(sorted(xs, reverse=True)) # and different again
49872035.43787266
>>> ys = crunch(xs) # now fiddle the numbers
>>> sum(ys)  # and all three ways are the same
49872035.43712826
>>> math.fsum(ys)
49872035.43712826
>>> sum(sorted(ys, reverse=True))
49872035.43712826

The good news is that this is obviously linear-time in the number of inputs. The bad news is that the higher the dynamic range across the inputs and the more inputs there are, the more trailing bits have to be thrown away.
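
To put rough numbers on that cost, here is a small sketch that counts how many low-order bits each input would lose. It reuses the same grid-spacing formula as crunch(); discarded_bits is a hypothetical helper, not part of the code above.

```python
from math import frexp, ldexp, ulp

def discarded_bits(xs):
    """Hypothetical helper: low-order bits crunch() would clear per input."""
    # Same grid spacing as crunch().
    target_ulp = ldexp(ulp(max(xs)), len(xs).bit_length())
    # Both ulps are powers of two, so the difference of their binary
    # exponents is the number of trailing bits lost (never below zero).
    target_exp = frexp(target_ulp)[1]
    return [max(0, target_exp - frexp(ulp(x))[1]) for x in xs]

print(discarded_bits([0.3, 0.4, 2.8]))             # [5, 5, 2]
print(discarded_bits([0.3, 0.4, 2.8] * 1000)[:3])  # [15, 15, 12]
```

The first line matches the hex output shown earlier (five trailing zero bits for 0.3 and 0.4, two for 2.8). Repeating the list 1000 times costs ten more bits per value. A count of 53 or more would mean the input is smaller than the grid spacing and gets floored to zero.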

Tim Peters answered Dec 16 '25 00:12


This is faster (0.3 seconds instead of your 8 seconds for length 10, and 3.4 seconds for length 12) and considers more ways to sum: not just linear summation like ((a+b)+c)+d, but also divide-and-conquer summation like (a+b)+(c+d).

The core part is the sums function, which computes all possible sums. First it enumerates the numbers, so it can use sets without losing duplicate numbers. Then its inner helper sums does the actual work. It tries all possible splits of the given numbers into a left subset and a right subset, computes all possible sums for each, and combines them.

import random
import itertools
import math
import functools

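def sums(nums):
    # Inner helper, memoized on frozensets of (index, value) pairs.
    @functools.cache
    def sums(nums):
        if len(nums) == 1:
            [num] = nums
            return {num[1]}  # value of the single (index, value) pair
        result = set()
        # Try every split into a non-empty left subset and its complement.
        for k in range(1, len(nums)):
            for left in map(frozenset, itertools.combinations(nums, k)):
                right = nums - left
                left_sums = sums(left)
                right_sums = sums(right)
                for L in left_sums:
                    for R in right_sums:
                        result.add(L + R)
        return result
    # Pair each number with its index so duplicates stay distinct in a set.
    return sums(frozenset(enumerate(nums)))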
    
def gen_sum_safe_seq(func, length: int, precision: int) -> list[float]:
    """
    Return a list of floats that has the same sum when rounded to the given
    precision regardless of the order in which its values are summed.
    """
    while True:
        nums = [func() for _ in range(length)]
        rounded_sums = {
            round(s, precision)
            for s in sums(nums)
        }
        if len(rounded_sums) == 1:
            return nums
        print(f"rejected {nums}")

for _ in range(3):
    nums = gen_sum_safe_seq(
        func=lambda: round(random.gauss(3, 0.5), 3),
        length=10,
        precision=2,
    )
    print(f"{nums} sum={sum(nums)}")
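
The enumerate step is what lets sums cope with repeated values. A minimal illustration with made-up inputs (the list and variable names are only for demonstration):

```python
nums = [0.5, 0.5, 1.25]

# A plain frozenset merges the two equal values into one element...
plain = frozenset(nums)
# ...while (index, value) pairs keep every input as a separate element.
indexed = frozenset(enumerate(nums))

print(len(plain), len(indexed))  # 2 3
```

Without the indices, one of the two 0.5 entries would silently drop out of the set and every computed sum would be too small.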

Kelly Bundy answered Dec 15 '25 22:12