 

Multiple `vmap` in JAX?

This may be a very simple thing, but I was wondering how to perform the mapping in the following example.

Suppose we have a function whose derivative we want to evaluate with respect to xt, yt and zt, but which also takes the additional parameters xs, ys and zs.

import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)
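As a quick check of what grad(fn, argnums=(0, 1, 2)) returns (a sketch added for illustration, not part of the original question): for scalar inputs it gives a tuple of three partial derivatives, which for a Euclidean distance r should equal ((xt - xs) / r, (yt - ys) / r, (zt - zs) / r). The helper analytic_grad is hypothetical and exists only for this comparison.

```python
import jax.numpy as jnp
from jax import grad

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

# Differentiate only with respect to the first three arguments (the target point);
# xs, ys, zs are treated as constants.
fn_grad = grad(fn, argnums=(0, 1, 2))

def analytic_grad(xt, yt, zt, xs, ys, zs):
    # Hypothetical helper: closed-form gradient of the distance w.r.t. the target point.
    r = fn(xt, yt, zt, xs, ys, zs)
    return ((xt - xs) / r, (yt - ys) / r, (zt - zs) / r)

g = fn_grad(1., 1., 1., 1., 3., 1.)
print([float(v) for v in g])  # approximately [0.0, -1.0, 0.0]
```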

Now, let us define the input data:

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

To evaluate the gradient at every combination of points in xt, yt and zt, I currently have to do the following:

fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

a = []
for _xt in xt:
    for _yt in yt:
        for _zt in zt:
            a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))

This results in a list of tuples. Once the list is converted to a jnp.array, it has the following shape:

a = jnp.array(a)
print(f'shape = {a.shape}')
shape = (64, 3, 3)
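To make the layout of that (64, 3, 3) array concrete, here is a sketch (illustrative, not from the original post) that rebuilds it and checks the index mapping. It reuses the question's data and fn_prime; the indices i, j, k are arbitrary choices. Each call to fn_prime returns three arrays of shape (3,), one entry per source point, and because xt is the outermost loop, row 16*i + 4*j + k corresponds to (xt[i], yt[j], zt[k]). In other words, a[n, c, s] is the derivative with respect to coordinate c (x, y, z) of the distance to source s.

```python
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

# Map over the source arrays (axis 0) while holding the target point fixed (None).
fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

# One target point -> tuple of 3 arrays (d/dxt, d/dyt, d/dzt), one entry per source.
single = fn_prime(xt[0], yt[0], zt[0], xs, ys, zs)
print([g.shape for g in single])  # [(3,), (3,), (3,)]

# Same loop order as the question: xt outermost, zt innermost.
a = jnp.array([fn_prime(_xt, _yt, _zt, xs, ys, zs)
               for _xt in xt for _yt in yt for _zt in zt])

# Undo the flattening: axes become (xt index, yt index, zt index, component, source).
grid = a.reshape(len(xt), len(yt), len(zt), 3, len(xs))
i, j, k = 2, 1, 3
dx, dy, dz = fn_prime(xt[i], yt[j], zt[k], xs, ys, zs)
print(bool(jnp.allclose(grid[i, j, k, 0], dx)))  # True
```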

My question is: is there a way to avoid these nested for loops and evaluate all the gradients in a single pass?

asked Jan 21 '26 at 07:01 by antelk


1 Answer

A good rule of thumb for cases like this is that each nested for loop translates to a nested vmap over an appropriate in_axis. With this in mind, you can re-express your computation this way:
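The rule of thumb can be seen on a smaller case. The following sketch uses a hypothetical two-argument function pair_product with two nested loops: the inner loop (over y) becomes the inner vmap with in_axes=(None, 0), and the outer loop (over x) wraps it with in_axes=(0, None). The axes of the result come out in the same order as the loops.

```python
import jax.numpy as jnp
from jax import vmap

def pair_product(x, y):
    # Hypothetical scalar function of two scalars.
    return x * y

x = jnp.array([1., 2., 3.])
y = jnp.array([10., 20.])

# Loop version: outer loop over x, inner loop over y.
loops = jnp.array([[pair_product(_x, _y) for _y in y] for _x in x])

# vmap version: the innermost loop becomes the innermost vmap.
inner = vmap(pair_product, in_axes=(None, 0))  # maps over y
outer = vmap(inner, in_axes=(0, None))         # maps over x
mapped = outer(x, y)

print(mapped.shape)  # (3, 2)
```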

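import numpy as np
import jax.numpy as jnp
from jax import vmap

# Reference implementation: three explicit Python loops (xt outermost, zt innermost).
def f_loops(xt, yt, zt, xs, ys, zs):
  a = []
  for _xt in xt:
    for _yt in yt:
      for _zt in zt:
        a.append(fn_prime(_xt, _yt, _zt, xs, ys, zs))
  return jnp.array(a)

# Each loop becomes one vmap; the innermost loop (zt) is the innermost vmap.
def f_vmap(xt, yt, zt, xs, ys, zs):
  f_z = vmap(fn_prime, in_axes=(None, None, 0, None, None, None))
  f_yz = vmap(f_z, in_axes=(None, 0, None, None, None, None))
  f_xyz = vmap(f_yz, in_axes=(0, None, None, None, None, None))
  # f_xyz returns a tuple of 3 arrays of shape (len(xt), len(yt), len(zt), len(xs));
  # stack the gradient components on axis 3, then flatten the three target axes.
  return jnp.stack(f_xyz(xt, yt, zt, xs, ys, zs), axis=3).reshape(-1, 3, len(xs))

out_loops = f_loops(xt, yt, zt, xs, ys, zs)
out_vmap = f_vmap(xt, yt, zt, xs, ys, zs)

np.testing.assert_allclose(out_loops, out_vmap)  # passes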
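An alternative (a sketch, not part of the original answer) is to build the Cartesian product of the target coordinates explicitly with jnp.meshgrid(..., indexing='ij'), flatten it, and apply a single vmap over all 64 target points. With indexing='ij' and row-major ravel, the flattened order matches the loop order (xt outermost, zt innermost). The function name f_grid is hypothetical. The trade-off is that the grid of target coordinates is materialized as arrays before mapping.

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap

def fn(xt, yt, zt, xs, ys, zs):
    return jnp.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + (zt - zs) ** 2)

fn_prime = vmap(grad(fn, argnums=(0, 1, 2)), in_axes=(None, None, None, 0, 0, 0))

def f_grid(xt, yt, zt, xs, ys, zs):
    # All (xt, yt, zt) combinations; 'ij' indexing keeps xt as the slowest-varying axis.
    gx, gy, gz = jnp.meshgrid(xt, yt, zt, indexing='ij')
    # One vmap over the flattened target points; sources are passed through unchanged.
    f_points = vmap(fn_prime, in_axes=(0, 0, 0, None, None, None))
    dx, dy, dz = f_points(gx.ravel(), gy.ravel(), gz.ravel(), xs, ys, zs)
    # Each component has shape (num_points, num_sources); stack them on axis 1.
    return jnp.stack([dx, dy, dz], axis=1)

xt = jnp.array([1., 2., 3., 4.])
yt = jnp.array([1., 2., 3., 4.])
zt = jnp.array([1., 2., 3., 4.])
xs = jnp.array([1., 2., 3.])
ys = jnp.array([3., 3., 3.])
zs = jnp.array([1., 1., 1.])

out_grid = f_grid(xt, yt, zt, xs, ys, zs)
print(out_grid.shape)  # (64, 3, 3)

# Reference: the nested-loop version from the question.
out_loops = jnp.array([fn_prime(_xt, _yt, _zt, xs, ys, zs)
                       for _xt in xt for _yt in yt for _zt in zt])
# assert_allclose treats NaN == NaN by default; the target (1, 3, 1) coincides with
# the first source point, where the distance gradient is undefined (NaN).
np.testing.assert_allclose(out_loops, out_grid)
```

Because the target (1, 3, 1) sits exactly on a source point, some entries of the result are NaN in every variant; this is a property of the distance function at zero, not of the mapping strategy.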
answered Jan 22 '26 at 21:01 by jakevdp


