
vmap in Jax to loop over arguments

Tags: python, arrays, jax

Let's suppose I have a function that returns the sum of its inputs:

import jax.numpy as jnp
from jax import jit, vmap

@jit
def some_func(a, r1, r2):
    return a + r1 + r2

Now I would like to loop over every combination of values in r1 and r2, compute the result for each pair, and add it to a running total. This is what I mean:

a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Accumulate some_func over every (r1[i], r2[j]) pair
s = 0
for i in range(len(r1)):
    for j in range(len(r2)):
        s += some_func(a, r1[i], r2[j])

print(s)  # 18
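As a hand check on the expected total: the double loop visits all nine (i, j) pairs, so with a = 0 and r1 = r2 = (0, 1, 2), each value of r1 and each value of r2 is added three times:

$$
s = \sum_{i=0}^{2}\sum_{j=0}^{2}\left(a + r_{1,i} + r_{2,j}\right) = 9a + 3\sum_{i=0}^{2} r_{1,i} + 3\sum_{j=0}^{2} r_{2,j} = 0 + 3\cdot 3 + 3\cdot 3 = 18
$$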

My question is, how do I do this with jax.vmap to avoid writing the for loops? I have something like this so far:

vmap(some_func, in_axes=(None, 0,0), out_axes=0)(jnp.arange(0,3), jnp.arange(0,3))

but this gives me the following error:

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0, 0) for value tree PyTreeDef((*, *)).

I suspect the problem is in in_axes, but I am not sure how to get vmap to pick a value of r1, loop over r2, and then repeat this for every value of r1 while saving the intermediate results.

Any help is appreciated.

Zohim asked Oct 25 '25 04:10



1 Answer

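vmap maps over a single axis at a time, and all arguments mapped in the same call advance along that axis together. Because you want every combination of r1 and r2, you'll need two nested vmap calls, the inner one mapping over r1 and the outer one over r2: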

from jax import vmap

func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
func_mapped(a, r1, r2).sum()
# 18
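A quick way to see what the nested call produces is to inspect the intermediate array before summing. The sketch below reuses `some_func`, `a`, `r1` and `r2` from the question; `func_mapped_ij`, `paired` and `flat` are illustrative names, not part of the original answer. It shows that the outer vmap contributes the leading output axis, that swapping the nesting reorders the axes, and that a single vmap with `in_axes=(None, 0, 0)` pairs the two arrays elementwise (like `zip`). That pairing is also what the question's attempt would compute once `a` is passed as the first argument; the error itself comes from calling the mapped function with two arguments while `in_axes` lists three.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# Inner vmap maps over axis 0 of r1; outer vmap maps over axis 0 of r2.
# The outer mapped axis becomes the leading output axis: out[j, i] = a + r1[i] + r2[j].
func_mapped = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))
out = func_mapped(a, r1, r2)
print(out.shape)  # (3, 3)
print(out.sum())  # 18

# Swapping the nesting puts r1 on the leading axis: out_ij[i, j] = a + r1[i] + r2[j].
func_mapped_ij = vmap(vmap(some_func, (None, None, 0)), (None, 0, None))
out_ij = func_mapped_ij(a, r1, r2)

# A single vmap mapping r1 and r2 together pairs r1[k] with r2[k],
# so it only visits the "diagonal" of the combinations.
paired = vmap(some_func, in_axes=(None, 0, 0))(a, r1, r2)
print(paired)        # [0 2 4]
print(paired.sum())  # 6, not 18

# To use a single vmap, build the Cartesian product first and flatten it.
R1, R2 = jnp.meshgrid(r1, r2, indexing="ij")
flat = vmap(some_func, in_axes=(None, 0, 0))(a, R1.ravel(), R2.ravel())
print(flat.sum())    # 18
```

Because `r1` and `r2` are identical here, the two nesting orders give the same values; with arrays of different lengths the difference shows up directly in the output shape.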

Alternatively, for a simple elementwise function like this, you can avoid vmap and use NumPy-style broadcasting to get the same result:

some_func(a, r1[None, :, None], r2[None, None, :]).sum()
# 18
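To make the shapes in the broadcasting version explicit, here is a small sketch, again reusing the question's `some_func`, `a`, `r1` and `r2` (`out2` and `nested` are illustrative names). It shows the broadcast result shape, that the leading singleton axis can be dropped when `a` is a scalar, and that the values match the nested-vmap output up to a transpose.

```python
import jax.numpy as jnp
from jax import jit, vmap


@jit
def some_func(a, r1, r2):
    return a + r1 + r2


a = 0
r1 = jnp.arange(0, 3)
r2 = jnp.arange(0, 3)

# r1[None, :, None] has shape (1, 3, 1) and r2[None, None, :] has shape (1, 1, 3).
# Broadcasting expands both to (1, 3, 3), with out[0, i, j] = a + r1[i] + r2[j].
out = some_func(a, r1[None, :, None], r2[None, None, :])
print(out.shape)  # (1, 3, 3)

# With a scalar `a`, two axes are enough: a column of r1 plus a row of r2.
out2 = some_func(a, r1[:, None], r2[None, :])
print(out2.shape)  # (3, 3)
print(out2.sum())  # 18

# Same values as the nested vmap from the answer, with the axes transposed,
# because the nested version puts r2 on the leading axis.
nested = vmap(vmap(some_func, (None, 0, None)), (None, None, 0))(a, r1, r2)
print(jnp.array_equal(out2, nested.T))  # True
```

Broadcasting works this directly because `some_func` is built only from elementwise operations; for functions that reduce over or branch on their inputs, the nested vmap is likely the more general route.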
jakevdp answered Oct 26 '25 20:10
