 

Generator that yields True a fixed number of times at random intervals before exhausting

I want to iterate through a list of size n containing A 1s (the rest are 0s), with the 1s randomly distributed.
n and A are large, so I'm trying to make a generator instead of a list.
If they weren't so large, I might do indices=set(random.sample(range(n),A)) and seq=[(i in indices) for i in range(n)].

My first thought was to do

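from random import random
from math import log

n = 355504839929
A = int(n * log(2, 3))  # log base 3 of 2, so A is about 63% of n

def seq():
    for i in range(n):
        # each position is independently True with probability A/n
        yield random() < A / n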

On average this generator yields True A out of n times, but the actual count varies from run to run. I need it to be exactly A.
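To put a number on "varies from run to run": with independent positions the count is binomially distributed, so the chance of landing on exactly A can be computed in log space. This is a minimal sketch; `prob_exactly` is a hypothetical helper, and it assumes 0 < p < 1 and 0 <= A <= n.

```python
from math import exp, lgamma, log, log1p

def prob_exactly(n, A, p):
    """P(Binomial(n, p) == A), computed in log space so huge n doesn't overflow."""
    log_comb = lgamma(n + 1) - lgamma(A + 1) - lgamma(n - A + 1)
    log_prob = log_comb + A * log(p) + (n - A) * log1p(-p)
    return exp(log_prob)

# Small case that can be checked by hand: C(20, 10) / 2**20
print(prob_exactly(20, 10, 0.5))

# The question's parameters
n = 355504839929
A = int(n * log(2, 3))
print(prob_exactly(n, A, A / n))
```

For the question's n this should come out on the order of 1e-6 (the normal approximation 1/sqrt(2πnp(1−p)) gives about 1.4e-6), so the first generator almost never hits A exactly.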


Next I tried

def seq():
    ones = 0
    for i in range(n):
        if ones < A:
            val = random() < A / n
            ones += val  # True counts as 1
            yield val
        else:
            # quota reached: fill the rest with False
            yield False

This way, once it has produced A 1s, it just yields False. But it can still produce fewer than A 1s.
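A quick way to see the undershoot is to run the capped version many times on a small case. This sketch assumes hypothetical names (`capped_seq`, and an `rng` parameter added so the run can be seeded); otherwise it follows the second attempt above.

```python
import random

def capped_seq(n, A, rng):
    """The second attempt, parameterized: stop emitting True after A of them."""
    ones = 0
    for _ in range(n):
        if ones < A:
            val = rng.random() < A / n
            ones += val
            yield val
        else:
            yield False

rng = random.Random(0)  # fixed seed so the run is reproducible
n, A = 10, 6
counts = [sum(capped_seq(n, A, rng)) for _ in range(1000)]

print(max(counts) <= A)            # the cap is never exceeded
print(sum(c < A for c in counts))  # number of runs that came up short
```

The cap only removes the overshoot: the count becomes min(X, A) for a binomial X, so any run where X falls below A still comes up short.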


Another thought: there are "n choose A" ways to arrange A 1s in the list, so I could pick a random value from range(math.comb(n,A)) and then write a generator that yields True according to that combination. But if n is huge, "n choose A" is astronomically large and may not be feasible to compute.
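The size of "n choose A" can be estimated without building it, using log-gamma. This is a rough sketch; `comb_digits` is a hypothetical helper and the result can be off by one near a power of ten.

```python
from math import comb, lgamma, log

def comb_digits(n, k):
    """Approximate number of decimal digits in C(n, k), without computing C(n, k)."""
    log10_comb = (lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)) / log(10)
    return int(log10_comb) + 1

# Sanity check against the exact value for a small case
print(comb_digits(100, 63), len(str(comb(100, 63))))

n = 355504839929
A = int(n * log(2, 3))
print(comb_digits(n, A))  # roughly 1e11 digits
```

For the question's parameters that is on the order of 10^11 decimal digits, i.e. tens of gigabytes just to store the integer, so drawing a uniform index into it is impractical.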


Is there a way to do this?

Asked by WeCanDoItGuys, Oct 24 '25 03:10



1 Answer

You're close, but you forgot to track how many slots are left.

  • When ones left = slots left, the chance is 100% (every remaining slot must be a 1).

  • When ones left = 0, the chance is 0%.

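So the rule is: fill the n slots one at a time, making each slot a 1 with probability (ones left) ÷ (slots left). This guarantees exactly A ones, and every arrangement of them is equally likely. (This is the selection sampling technique, Knuth's Algorithm S.)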
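The claim that every arrangement is equally likely can be checked exactly for small n. This sketch (`pattern_probability` and `all_patterns` are hypothetical helper names) multiplies the rule's per-slot probabilities for each of the C(n, A) possible outputs, using `Fraction` to avoid rounding.

```python
from fractions import Fraction
from itertools import combinations
from math import comb

def pattern_probability(pattern, A):
    """Exact probability that the ones-left / slots-left rule emits `pattern`."""
    ones_left, slots_left = A, len(pattern)
    prob = Fraction(1)
    for bit in pattern:
        p_one = Fraction(ones_left, slots_left)
        if bit:
            prob *= p_one
            ones_left -= 1
        else:
            prob *= 1 - p_one
        slots_left -= 1
    return prob

def all_patterns(n, A):
    """Every length-n boolean tuple with exactly A Trues."""
    for positions in combinations(range(n), A):
        chosen = set(positions)
        yield tuple(i in chosen for i in range(n))

n, A = 6, 3
probs = {pattern_probability(p, A) for p in all_patterns(n, A)}
print(probs, Fraction(1, comb(n, A)))
```

Every pattern gets the same probability because the product telescopes: the numerators multiply to A!·(n−A)! and the denominators to n!, giving 1/C(n, A) regardless of where the ones fall.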

My test:

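import random
from math import log

def seq(n, A):
    """Yield n booleans, exactly A of them True (0 <= A <= n), at random positions."""
    remaining_ones = A
    remaining_slots = n
    for _ in range(n):
        # P(True) = ones still to place / slots still to fill
        if random.random() < remaining_ones / remaining_slots:
            yield True
            remaining_ones -= 1
        else:
            yield False
        remaining_slots -= 1

n = 10000
A = int(n * log(2, 3))

print(A == sum(seq(n, A)))  # True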
Answered by Viet Dinh, Oct 27 '25 01:10



