
Modify an array from indexes contained in another array

I have an array of shape (2, 10), such as:

arr = jnp.ones(shape=(2,10)) * 2

which prints as

[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]

and another array with one index per row, for example [2, 4].

I want the second array to give, for each row of arr, the index from which the elements should be masked (replaced with -1). Here the result would be:

[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [2. 2. 2. 2.  -1. -1. -1. -1. -1. -1.]]

I need to use jax.numpy, and the answer should be vectorized and fast if possible, i.e. without loops.

Valentin Macé asked Dec 28 '25 21:12



1 Answer

You can do this with a vmapped three-argument jnp.where call, which keeps each row's values before its index and fills the rest with -1. For example:

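import jax.numpy as jnp
import jax

arr = jnp.ones(shape=(2,10)) * 2
idx = jnp.array([2, 4])

# vmap maps f over the leading axis of both arguments:
# each call sees one row of arr and its matching index from idx.
@jax.vmap
def f(row, ind):
  # Keep entries at positions < ind; replace the rest with -1.
  return jnp.where(jnp.arange(len(row)) < ind, row, -1)

f(arr, idx)
# DeviceArray([[ 2.,  2., -1., -1., -1., -1., -1., -1., -1., -1.],
#              [ 2.,  2.,  2.,  2., -1., -1., -1., -1., -1., -1.]], dtype=float32)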
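
The same mask can also be built with plain broadcasting instead of vmap. The following is a minimal sketch under that assumption, not part of the answer above; the function name mask_from_index is hypothetical, and wrapping it in jax.jit is optional:

```python
import jax
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])

# Hypothetical helper name, chosen for illustration only.
@jax.jit
def mask_from_index(arr, idx):
    # Column positions 0..n-1, shape (n,).
    cols = jnp.arange(arr.shape[1])
    # idx[:, None] has shape (rows, 1), so the comparison
    # broadcasts to a boolean mask of shape (rows, n).
    keep = cols < idx[:, None]
    # Keep values where the mask is True, fill with -1 elsewhere.
    return jnp.where(keep, arr, -1)

masked = mask_from_index(arr, idx)
print(masked)
```

Here idx[:, None] adds a trailing axis so that each row is compared against its own cut-off index, which should produce the same values as the vmapped version.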
jakevdp answered Dec 30 '25 11:12



