
Using tf.batch_scatter_add in Keras

I'm looking for a way to use tf.scatter_add with Keras batches. outputs has shape (?, 1000); indices and updates each have shape (?, 100).

Try1: Using Keras tensors

vals = tf.scatter_add(outputs, indices, updates)

This throws an error:

'Tensor' object has no attribute '_lazy_read'

Try2: Using K.variable, which should be updatable

vals = K.variable(outputs)
vals = tf.scatter_add(vals, inputs[1], inputs[2]) 

ValueError: initial_value must have a shape specified:
Tensor("scatter_add_43/zeros_like:0", shape=(?, 1000), dtype=float32))

Any clues? tf.scatter_add and tf.batch_scatter_add result in the same errors. Will I need to write a custom layer for this? It seems that even a custom layer would run into one of the errors above.

N. Sawant asked Jan 30 '26 08:01


1 Answer

You don't need a Layer for this. You can use tf.tensor_scatter_add (named tf.tensor_scatter_nd_add in TensorFlow 2.x) to get an updated copy of a Tensor, or scatter_nd_add (a method of tf.Variable) to update a Variable in place.

References:

https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_add
https://www.tensorflow.org/api_docs/python/tf/Variable#scatter_nd_add

Unfortunately, the TensorFlow API does not have a tf.batched_scatter_nd_add yet, so we have to mimic this behaviour. Both methods work with individual elements or slices. To work with batches, we need to transform the indices into the multi-dimensional form these methods expect: one [row, column] pair per update.

import tensorflow as tf
tf.enable_eager_execution()  # TensorFlow 1.x only; eager execution is the default in 2.x

batch_size = 2
vocab_size = 6
seq_len = 3

predictions = tf.Variable(tf.zeros([batch_size, vocab_size], dtype=tf.float32))

indices = tf.constant([[0, 2, 4], [1, 3, 5]], dtype=tf.int32)
updates = tf.ones([batch_size, seq_len], dtype=tf.float32)

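# Pair every column index with its batch row:
# [[0, 0], [0, 2], [0, 4], [1, 1], [1, 3], [1, 5]]
batched_indices = [[i, j.numpy()] for i, indexes in enumerate(indices) for j in indexes]
batched_updates = tf.reshape(updates, [-1])  # flatten to one update per index pair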

predictions.scatter_nd_add(indices=batched_indices, updates=batched_updates)

print(predictions)

Output:

<tf.Variable 'Variable:0' shape=(2, 6) dtype=float32, numpy=
array([[1., 0., 1., 0., 1., 0.],
       [0., 1., 0., 1., 0., 1.]], dtype=float32)>
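
The list comprehension above calls .numpy() and iterates over the batch in Python, so it needs eager execution and a known batch size. For symbolic Keras tensors with shape (?, ...), the same [row, column] pairs can probably be built with TensorFlow ops instead. The following is only a sketch, assuming TensorFlow 2.x and tf.tensor_scatter_nd_add; the helper name batch_scatter_add is made up for illustration and is not part of the TensorFlow API.

```python
import tensorflow as tf


def batch_scatter_add(outputs, indices, updates):
    """Hypothetical helper: adds updates[b, k] to outputs[b, indices[b, k]].

    outputs: (batch, n) tensor; indices and updates: (batch, k) tensors.
    Returns a new tensor; `outputs` itself is not modified.
    """
    indices = tf.cast(indices, tf.int32)
    # Row number of every entry in `indices`, broadcast to shape (batch, k).
    batch_size = tf.shape(indices)[0]
    rows = tf.range(batch_size)[:, tf.newaxis] + tf.zeros_like(indices)
    # Pair each row number with its column index: shape (batch, k, 2).
    nd_indices = tf.stack([rows, indices], axis=-1)
    # Repeated index pairs are accumulated, as with scatter_nd_add.
    return tf.tensor_scatter_nd_add(outputs, nd_indices, updates)


outputs = tf.zeros([2, 6], dtype=tf.float32)
indices = tf.constant([[0, 2, 4], [1, 3, 5]], dtype=tf.int32)
updates = tf.ones([2, 3], dtype=tf.float32)

result = batch_scatter_add(outputs, indices, updates)
print(result.numpy())
```

With the same inputs as above, this should produce the same 2 x 6 array. Because it uses tf.shape rather than Python iteration, it could plausibly be wrapped in a tf.keras.layers.Lambda inside a model, though that use is untested here.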
user2936263 answered Jan 31 '26 22:01