I'm just starting to use JAX, and I wonder what the right way would be to implement if-then-elif-then-else in JAX/Python. For example, given the input arrays n = [5, 4, 3, 2] and k = [3, 3, 3, 3], I need to implement the following pseudo-code:
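```python
import jax.numpy as jnp

# Pseudo-code: JAX arrays are immutable, so r[i] = ... would raise an error in real JAX code.
def n_choose_k_safe(n, k):
    r = jnp.empty(4)
    for i in range(4):
        if n[i] < k[i]:
            r[i] = 0
        elif n[i] == k[i]:
            r[i] = 1
        else:
            r[i] = func_nchoosek(n[i], k[i])  # func_nchoosek: my own scalar n-choose-k function
    return r
```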
There are so many choices, such as vmap, lax.select, jnp.where, lax.cond, lax.fori_loop, etc., that it's hard to decide which combination of these utilities to use. By the way, k can be a scalar, if that makes it simpler.
There's a slightly more compact way to express the solution in Valentin's answer, using jax.numpy.select:
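```python
import jax.numpy as jnp

def n_choose_k_safe(n, k):
    # Conditions are checked in order; elements where n < k fall through to default=0.
    return jnp.select(condlist=[n > k, n == k],
                      choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                      default=0)
```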
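To check the select-based version end to end, here is a minimal runnable sketch. The question doesn't show func_nchoosek, so the version below is a hypothetical stand-in that computes C(n, k) through log-gamma (jax.scipy.special.gammaln); swap in your own function. Note that jnp.select does not short-circuit: func_nchoosek is evaluated for every element, including those where n < k, and the unwanted values are simply discarded. For comparison, the same branching is also written with nested jnp.where.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Hypothetical stand-in for the asker's func_nchoosek (not shown in the question):
# C(n, k) = n! / (k! (n - k)!) computed via log-gamma, rounded to the nearest integer.
def func_nchoosek(n, k):
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))

@jax.jit
def n_choose_k_safe(n, k):
    # Branches: n > k -> C(n, k); n == k -> 1; otherwise (n < k) -> 0.
    return jnp.select(condlist=[n > k, n == k],
                      choicelist=[jnp.vectorize(func_nchoosek)(n, k), 1],
                      default=0)

@jax.jit
def n_choose_k_where(n, k):
    # Same logic with nested jnp.where; the innermost third argument is the "else" branch.
    return jnp.where(n < k, 0,
                     jnp.where(n == k, 1, jnp.vectorize(func_nchoosek)(n, k)))

n = jnp.array([5, 4, 3, 2])
k = jnp.array([3, 3, 3, 3])
print(n_choose_k_safe(n, k))   # expected values: 10, 4, 1, 0
print(n_choose_k_safe(n, 3))   # scalar k is broadcast by jnp.vectorize
print(n_choose_k_where(n, k))  # same values as the select version
```

Both versions compile under jax.jit because they avoid Python-level branching on traced values; which one to prefer is mostly a matter of readability as the number of branches grows.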
For the length-4 input arrays in the question, this should return the same result as your original code, assuming func_nchoosek is compatible with jax.vmap. Using jnp.vectorize here instead of vmap also makes the function work with a scalar k, without having to set the in_axes argument manually.
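The difference between vmap and jnp.vectorize for a scalar k can be sketched as follows. The helper is again a hypothetical log-gamma stand-in for func_nchoosek, and the inputs are chosen with n > k so only the C(n, k) branch matters. With jax.vmap, the scalar k has to be marked as unmapped via in_axes=(0, None); with the default in_axes=0, vmap would try to map over axis 0 of a rank-0 value and raise an error. jnp.vectorize instead broadcasts its arguments against each other, NumPy-style.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

# Hypothetical stand-in for func_nchoosek, written for one (n, k) pair at a time.
def func_nchoosek(n, k):
    n = jnp.asarray(n, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))

n = jnp.array([5, 4, 6, 7])

# vmap: map over axis 0 of n, and pass the scalar k unchanged (in_axes=None for k).
via_vmap = jax.vmap(func_nchoosek, in_axes=(0, None))(n, 3)

# jnp.vectorize: n (shape (4,)) and k (shape ()) are broadcast automatically.
via_vectorize = jnp.vectorize(func_nchoosek)(n, 3)

print(via_vmap)       # expected values: 10, 4, 20, 35
print(via_vectorize)  # same values
```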