 

How to vectorize a function over a list of unequal length arrays in JAX

This is a minimal example of the real larger problem I am facing. Consider the function below:

import jax
import jax.numpy as jnp
def test(x):
    return jnp.sum(x)

I tried to vectorize it with:

v_test = jax.vmap(test)

My inputs to test look like:

x1 = jnp.array([1,2,3])
x2 = jnp.array([4,5,6,7])
x3 = jnp.array([8,9])
x4 = jnp.array([10])

and my input to v_test is:

x = [x1, x2, x3, x4]

If I try:

v_test(x)

I get the error below:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)

Is there a way to vectorize test over a list of unequal-length arrays? I could avoid this by padding the arrays to the same length; however, padding is not desired.

asked Nov 02 '25 07:11 by MOON


1 Answer

JAX does not support ragged arrays (i.e. arrays in which each row has a different number of elements), so there is currently no way to use vmap on this kind of data. Your best bet is probably to use a Python for loop:

y = [test(xi) for xi in x]

Alternatively, you might be able to express the operation you have in mind in terms of segment_sum or similar operations. For example:

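import jax
import jax.numpy as jnp

# Segment id i for every element of x[i]; jnp.full keeps the ids integer even if x holds floats
segments = jnp.concatenate([jnp.full(xi.shape[0], i) for i, xi in enumerate(x)])
result = jax.ops.segment_sum(jnp.concatenate(x), segments, num_segments=len(x))
print(result)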
# [ 6 22 17 10]
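The same flatten-plus-segment-ids trick extends to other per-array reductions. As a sketch, the hypothetical segment_mean helper below (not part of JAX) divides per-segment sums by per-segment counts. Passing num_segments as a static argument fixes the output shape at trace time, which is what lets the function be jitted.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6, 7]), jnp.array([8, 9]), jnp.array([10])]

# Flatten the ragged list into one array plus an integer segment id per element.
data = jnp.concatenate(x)
segments = jnp.concatenate([jnp.full(xi.shape[0], i) for i, xi in enumerate(x)])


# Hypothetical helper: mean of each segment.
# num_segments is static so the output shape is known when tracing.
@partial(jax.jit, static_argnames="num_segments")
def segment_mean(data, segment_ids, num_segments):
    sums = jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)
    counts = jax.ops.segment_sum(jnp.ones_like(data), segment_ids, num_segments=num_segments)
    return sums / counts


means = segment_mean(data, segments, num_segments=len(x))
print(means)
```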

Another possibility is to pad the input arrays so that they can fit into a standard, non-ragged 2D array.

answered Nov 04 '25 23:11 by jakevdp