Is it possible to compute per-example gradients efficiently in TensorFlow, in just one graph run?

TL;DR: is there a way to evaluate f'(x1), f'(x2), ..., f'(xn) in just one graph run, in a vectorized form, where f'(x) is the derivative of f(x)?

Something like:

x = tf.placeholder(tf.float32, shape=[100])
f = tf.square(x)
f_grad = tf.multiple_gradients(x)  # hypothetical function (not a real TF API); f_grad would contain f'(x[0]), f'(x[1]), ...

More specifically, I'm trying to implement Black Box Stochastic Variational Inference (BBSVI) manually (I know I could use a library like Edward, but I'm trying to implement it myself). At one point, I need to compute the mean of f'(x)g(x) across many different values of x (x1, x2, ..., xn), where f(x) and g(x) are two functions, and f'(x) is the derivative of f(x).

Using TensorFlow's autodiff feature, I can compute f'(x1), f'(x2), ..., f'(xn), by simply calling f_prime.eval(feed_dict={x: xi}) once for each value xi in (x1, x2, ..., xn). This is not efficient at all: I would like to use a vectorized form instead, but I'm not sure how to do this.
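For concreteness, the per-sample loop described above might look like the following sketch. It assumes TensorFlow 2.x eager execution with tf.GradientTape instead of the TF 1.x placeholder/eval API used in the question, and the choices f(x) = x² and g(x) = sin(x), as well as the helper name mean_fprime_g_loop, are hypothetical:

```python
import tensorflow as tf

# Hypothetical example functions; the question leaves f and g unspecified.
def f(x):
    return tf.square(x)

def g(x):
    return tf.sin(x)

def mean_fprime_g_loop(xs):
    """Naive baseline: one gradient computation per sample, analogous to
    calling f_prime.eval(feed_dict={x: xi}) once per xi in TF 1.x."""
    terms = []
    for xi in tf.unstack(xs):          # one scalar tensor per sample
        with tf.GradientTape() as tape:
            tape.watch(xi)             # xi is a constant, so watch it explicitly
            y = f(xi)
        fprime = tape.gradient(y, xi)  # f'(xi)
        terms.append(fprime * g(xi))
    return tf.reduce_mean(tf.stack(terms))

xs = tf.constant([0.5, 1.0, 2.0])
print(mean_fprime_g_loop(xs).numpy())
```

Each iteration launches its own backward pass, which is exactly the per-example overhead the question wants to avoid.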

Perhaps using tf.stop_gradient() somehow? Or using the grad_ys argument in tf.gradients()?

Asked by MiniQuark, Jan 29 '26 00:01


2 Answers

After a bit of digging, it seems that computing per-example gradients in TensorFlow is not trivial. Like other deep learning libraries (PyTorch, Theano and so on), TensorFlow performs standard back-propagation, which never actually computes the per-example gradients: it directly obtains their sum. Check out this discussion for more details.
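A small illustration of this point, as a sketch assuming TensorFlow 2.x and a toy model with one shared scalar weight w and per-example loss l_i = (w·x_i − y_i)² (the model and data are made up for the example): one backward pass over the total loss yields only the summed gradient, while recovering the individual terms takes one backward pass per example.

```python
import tensorflow as tf

# Toy model (an assumption for illustration): one shared scalar weight w,
# per-example loss l_i = (w * x_i - y_i)^2.
w = tf.Variable(1.5)
x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([0.0, 1.0, 2.0])

# Standard back-propagation on the total loss: a single backward pass.
with tf.GradientTape() as tape:
    total_loss = tf.reduce_sum(tf.square(w * x - y))
summed_grad = tape.gradient(total_loss, w)

# Naive per-example gradients: one backward pass per example.
per_example = []
for xi, yi in zip(tf.unstack(x), tf.unstack(y)):
    with tf.GradientTape() as tape:
        loss_i = tf.square(w * xi - yi)
    per_example.append(tape.gradient(loss_i, w))
per_example = tf.stack(per_example)

# The single backward pass only exposes the sum of the per-example gradients.
print(summed_grad.numpy(), tf.reduce_sum(per_example).numpy())  # both are 26.0 for these inputs
```

Here the per-example gradients are 2(w·x_i − y_i)·x_i = 3, 8 and 15, but the single pass over the total loss only ever produces their sum.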

However, there are some techniques to work around this issue, at least for some use cases. For example, the paper Efficient per-example gradient computation by Ian Goodfellow explains how to efficiently compute, for each example, the sum of the squared derivatives (that is, the squared norm of that example's gradient). Here is an excerpt from the paper showing the computation (but I highly encourage you to read the paper; it is very short):

[Image: excerpt from the paper showing the per-example computation]

This algorithm is O(mnp) instead of O(mnp²), where m is the number of examples, n is the number of layers in the neural net, and p is the number of neurons per layer. So it is much faster than the naive approach (i.e., performing back-prop once per example), especially when p is large, and even more so on a GPU, which speeds up vectorized computations by a large factor.
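To make the idea concrete, here is a minimal NumPy sketch of the trick for a single fully connected layer's weight matrix with a squared-error loss. The layer sizes, random data and loss are assumptions for illustration; the paper itself covers whole networks. Each example's weight gradient is the outer product of its layer input h_i and its output gradient dz_i, so the squared norm of that gradient is ||h_i||² · ||dz_i||², which only needs quantities that ordinary back-propagation already computes:

```python
import numpy as np

rng = np.random.default_rng(0)
m, p_in, p_out = 4, 5, 3          # examples, input units, output units (toy sizes)
H = rng.normal(size=(m, p_in))    # layer inputs, one row per example
W = rng.normal(size=(p_in, p_out))
T = rng.normal(size=(m, p_out))   # targets

# Forward pass and gradient of the summed loss 0.5 * ||H W - T||^2 with respect
# to the layer output Z. Row i depends only on example i, so it equals dL_i/dz_i.
Z = H @ W
dZ = Z - T

# Ordinary backprop: only the summed weight gradient is formed.
grad_W_sum = H.T @ dZ

# Trick: per-example weight gradient is the outer product h_i dz_i^T,
# so its squared Frobenius norm is ||h_i||^2 * ||dz_i||^2.
sq_norms_fast = np.sum(H**2, axis=1) * np.sum(dZ**2, axis=1)

# Naive check: materialize every per-example gradient explicitly.
per_example = np.einsum("ij,ik->ijk", H, dZ)   # shape (m, p_in, p_out)
sq_norms_naive = np.sum(per_example**2, axis=(1, 2))

assert np.allclose(sq_norms_fast, sq_norms_naive)
assert np.allclose(per_example.sum(axis=0), grad_W_sum)
```

The naive check materializes an m × p_in × p_out array, which is the cost the trick avoids; the fast path only needs two row-wise sums of squares.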

Answered by MiniQuark, Jan 30 '26 15:01


The official TensorFlow documentation demonstrates exactly this use case, computing per-example gradients with tf.vectorized_map:

# Computing per-example gradients
import tensorflow as tf

batch_size = 10
num_features = 32
layer = tf.keras.layers.Dense(1)

def model_fn(arg):
  # Called on a single (input, label) pair; vectorized_map batches these calls.
  with tf.GradientTape() as g:
    inp, label = arg
    inp = tf.expand_dims(inp, 0)      # add a batch dimension of size 1
    label = tf.expand_dims(label, 0)
    prediction = layer(inp)
    loss = tf.nn.l2_loss(label - prediction)
  return g.gradient(loss, (layer.kernel, layer.bias))

inputs = tf.random.uniform([batch_size, num_features])
labels = tf.random.uniform([batch_size, 1])
per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
# One gradient per example, stacked along a new leading axis.
assert per_example_gradients[0].shape == (batch_size, num_features, 1)
assert per_example_gradients[1].shape == (batch_size, 1)

See the official tf.vectorized_map documentation for further details.
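Applied to the scalar setting from the question, the same pattern might look like the sketch below. It assumes TensorFlow 2.x, and the choices f(x) = x² and g(x) = sin(x) and the helper name f_prime are hypothetical stand-ins for the question's unspecified functions. vectorized_map evaluates f'(x_i) for every sample and the mean of f'(x)g(x) is then a single reduction:

```python
import tensorflow as tf

# Hypothetical f and g; the question does not specify them.
def f(x):
    return tf.square(x)

def g(x):
    return tf.sin(x)

def f_prime(xi):
    # Gradient of f at one scalar sample; vectorized_map applies this
    # across the leading axis of xs in vectorized form.
    with tf.GradientTape() as tape:
        tape.watch(xi)
        y = f(xi)
    return tape.gradient(y, xi)

xs = tf.constant([0.5, 1.0, 2.0])
fprimes = tf.vectorized_map(f_prime, xs)     # f'(x_1), ..., f'(x_n)
estimate = tf.reduce_mean(fprimes * g(xs))   # mean of f'(x) g(x)
```

If some op inside f has no vectorized conversion, tf.vectorized_map falls back to a while loop by default, so the result stays correct but may lose the speedup.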

Answered by amirsina torfi, Jan 30 '26 13:01