Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Set the next i tensor values to 0 every time a value of 1 appears in a tensor

I have two tensors with the same size:

a = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b = [0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1]

Tensor a has three regions, each a run of consecutive values: region 1 is [1, 2, 3, 4, 5], region 2 is [10, 11, 12, 13], and region 3 is [20, 21, 22, 23, 24, 25, 26, 27, 28].

For each of those regions, I want to apply the following logic: if a value of b is 1, the following i values are set to 0 (values that are already 0 stay 0). After those i values have been changed, nothing happens until another value of b is 1, at which point the next i values are forced to 0, and so on. The zeroing never crosses into the next region: the first value of a new region always keeps its original b value.

Some examples:

# i = 1

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 1, 0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  0,  0,  1]


# i = 2

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 1,  1,  0,  0,  0,  1,  0,  0,  1,  0,  0,  0,  0,  1]


# i = 4

a     = [1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28]
b_new = [0, 1, 0, 0, 0,  1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  1]

Not sure if this would help, but I was able to separate the regions into segments by doing:

# Assumes `a` is a 1-D integer tensor (`a - 1` would fail on a plain Python list).
# After rolling, a_shifted[k] == a[k + 1] - 1, so a_shifted[k] != a[k] exactly
# when a region ends at position k. tf.roll wraps a[0] - 1 around to the last
# position, but the exclusive cumsum never reads the last flag, so that
# wrap-around does not affect the result.
a_shifted = tf.roll(a - 1, shift=-1, axis=0)
# Exclusive cumulative sum of the "region ends here" flags = region id per position.
a_shifted_segs = tf.math.cumsum(tf.cast(a_shifted != a, dtype=tf.int64), exclusive=True)

# a_shifted_segs =
#   [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]

Do you know any way of doing this efficiently?

like image 557
Kunis Avatar asked Oct 16 '25 03:10

Kunis


2 Answers

Here is a TensorFlow solution based on tf.scan. I know the conditionals are a bit complicated; if you have suggestions on how to simplify them, I'm open to them. However, once you know how to read the conditionals, it should be quite clear what the code does.

Here, the variable i tells us, for each position in the array, how many more b values have to be overwritten with 0.

import tensorflow as tf 

a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

# Extract switches inside a: switches[k] is True when a[k] starts a new region,
# i.e. when a[k] != a[k - 1] + 1. This is one vectorized comparison instead of
# a sequential tf.scan over every element.
# `previous` is `a` shifted right by one, with a[0] - 2 prepended so that
# position 0 is always flagged (a[0] != (a[0] - 2) + 1). For an empty `a` both
# slices are empty, so `switches` is empty too.
previous = tf.concat([a[:1] - 2, a[:-1]], axis=0)
switches = tf.not_equal(a, previous + 1)

# Define inputs for the scan iterations
initializer = {'b': tf.constant(False), 'i': tf.constant(0)}
elems = {'switches': switches, 'b': tf.cast(b, dtype=tf.bool)}

@tf.function
def step(last_out, new_in, max_i):
    """Run one tf.scan step: decide whether this position's b value survives.

    Args:
        last_out: previous step's output, {'b': bool scalar, 'i': integer
            scalar}. 'i' counts how many more positions must still be forced
            to 0 (the module-level initializer makes it an int32 scalar).
        new_in: current element, {'switches': bool scalar, True when this
            position starts a new region; 'b': bool scalar, the input b}.
        max_i: number of positions to force to 0 after each kept 1.

    Returns:
        {'b': bool scalar, True when this position's 1 is kept;
         'i': the updated countdown, same dtype as last_out['i']}.
    """
    i_dtype = last_out['i'].dtype
    # Still inside a zeroing window that no region boundary has cut off:
    # force this position to 0 and count down.
    keep_zeroing = tf.logical_and(last_out['i'] > 0,
                                  tf.logical_not(new_in['switches']))
    # Otherwise the input b decides: a 1 is kept and (re)starts the window.
    keep_b = tf.logical_and(tf.logical_not(keep_zeroing), new_in['b'])
    new_i = tf.where(
        keep_zeroing,
        last_out['i'] - 1,
        tf.where(keep_b,
                 tf.constant(max_i, dtype=i_dtype),
                 tf.constant(0, dtype=i_dtype)),
    )
    # Emit keep_b directly instead of testing new_i == max_i; that test
    # wrongly produced 1 at every position when max_i == 0.
    return {'b': keep_b, 'i': new_i}

Examples:

def _run_scan(max_i):
    """Scan `step` over `elems`, zeroing `max_i` positions after each kept 1."""
    return tf.scan(lambda last, inp: step(last, inp, max_i=max_i),
                   elems=elems, initializer=initializer)


outp_1 = _run_scan(1)
print(tf.cast(outp_1['b'], tf.int32))
# tf.Tensor([0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 0 0 1], shape=(18,), dtype=int32)

outp_2 = _run_scan(2)
print(tf.cast(outp_2['b'], tf.int32))
# tf.Tensor([0 1 0 0 1 1 0 0 0 1 0 0 1 0 0 0 0 1], shape=(18,), dtype=int32)

outp_4 = _run_scan(4)
print(tf.cast(outp_4['b'], tf.int32))
# tf.Tensor([0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1], shape=(18,), dtype=int32)

This answer is sponsored by lambda.

like image 58
André Avatar answered Oct 17 '25 16:10

André


Here is a pure Tensorflow approach, which will work in Eager Execution and Graph mode:

# copy, paste, acknowledge

import tensorflow as tf

def split_regions_and_modify(a, b, i):
  """Zero out the `i` values that follow every kept 1 in `b`, region by region.

  A region is a maximal run of consecutive values in `a` (a[k + 1] == a[k] + 1).
  Walking `b` left to right, each position that is still 1 forces the next `i`
  positions of the same region to 0. Zeroing never crosses into the next region.

  Args:
    a: 1-D integer tensor whose runs of consecutive values define the regions.
    b: 1-D tensor of 0/1 values, same length as `a`. The zeros written use
      `b.dtype`, so the dtype does not have to be int32.
    i: number of positions to zero after each kept 1 (Python int or int32 scalar).

  Returns:
    A tensor with the same shape and dtype as `b`.
  """
  # Start index of every region except the first (tf.where yields int64).
  starts = tf.squeeze(tf.where(a[:-1] != a[1:] - 1), axis=-1) + 1
  # Exclusive end index of each region. The last region always ends at len(a),
  # so it is appended unconditionally. This also covers the single-region case
  # (empty `starts`) without a tf.cond whose branches returned different dtypes
  # (int64 vs int32), which fails when traced in graph mode.
  row_splits = tf.concat([tf.cast(starts, tf.int32), tf.shape(a)[:1]], axis=0)

  def body(i, j, k, tensor, row_splits):
    # Advance to the next region when j reaches the current region's end.
    # row_splits is strictly increasing and starts above 0, so afterwards
    # j < row_splits[k] always holds.
    k = tf.cond(tf.equal(row_splits[k], j), lambda: tf.add(k, 1), lambda: k)
    # Positions to force to 0: up to i after j, clipped at the region end.
    current_indices = tf.range(j + 1, tf.minimum(j + 1 + i, row_splits[k]), dtype=tf.int32)
    # tensor[j] is read from the partially modified tensor, so values already
    # zeroed by an earlier 1 do not open a new window.
    tensor = tf.cond(
        tf.equal(tensor[j], 1),
        lambda: tf.tensor_scatter_nd_update(
            tensor, current_indices[..., None],
            tf.zeros(tf.shape(current_indices), dtype=tensor.dtype)),
        lambda: tensor)
    return i, tf.add(j, 1), k, tensor, row_splits

  j0 = tf.constant(0)
  k0 = tf.constant(0)
  c = lambda i, j0, k0, b, row_splits: tf.logical_and(tf.less(j0, tf.shape(b)[0]), tf.less(k0, tf.shape(row_splits)[0]))
  _, _, _, output, _ = tf.while_loop(c, body, loop_vars=[i, j0, k0, b, row_splits])
  return output

Usage:

# Interactive usage: each bare call below is echoed by the REPL, and the
# comment under it shows the displayed result.
a = tf.constant([1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28])
b = tf.constant([0, 1, 1, 1, 1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0,  0,  0,  1])

# i = 1: one position is zeroed after each kept 1.
split_regions_and_modify(a, b, 1)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], dtype=int32)>

# i = 2: two positions are zeroed after each kept 1, never past a region end.
split_regions_and_modify(a, b, 2)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int32)>

# i = 4: the window is clipped at each region boundary.
split_regions_and_modify(a, b, 4)
# <tf.Tensor: shape=(18,), dtype=int32, numpy=array([0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)>
like image 35
AloneTogether Avatar answered Oct 17 '25 17:10

AloneTogether