 

Is there a TensorFlow equivalent of PyTorch's backward()? Trying to send gradients back to a TF model to backprop

I'm trying to implement a split learning model, where my TF model on a client takes in the data and produces an intermediate output. This intermediate output is sent to a server running a PyTorch model, which takes it as input and minimizes the loss. The server then sends the gradients with respect to the client's output back to the TF model, so that the TF model can update its weights.

How do I get my TF model to update its weights with the gradients sent back from the server?

# PyTorch client: backprop the server's gradient through the client model, then update
client_pred.backward(client_grad)
client_optimizer.step()

With PyTorch, I can just call client_pred.backward(client_grad) followed by client_optimizer.step().

How do I achieve the same with a TensorFlow client? I've tried GradientTape with tape.gradient(client_grad, model.trainable_weights), but it just returns None. I suspect that's because there's no computation in the tape context, and client_grad is just a Tensor holding the gradients that is not connected to the model's layers?
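A minimal sketch, assuming TensorFlow 2.x in eager mode, that reproduces the None result. A tape can only differentiate tensors it computed from the watched variables; a tensor that arrived from elsewhere (here a hypothetical stand-in for client_grad) has no recorded path back to the weights, so even with the forward pass inside the tape, asking for its gradient gives None. The model and shapes are illustrative only.

```python
import tensorflow as tf

# Hypothetical client model and minibatch, only for illustration
model = tf.keras.Sequential([tf.keras.layers.Dense(4, activation="relu"),
                             tf.keras.layers.Dense(3)])
x = tf.random.normal((2, 5))

# Stand-in for the gradient received from the server: a plain tensor
client_grad = tf.ones((2, 3))

with tf.GradientTape() as tape:
    client_pred = model(x)  # the forward pass IS recorded here...

# ...but client_grad was not computed from the weights, so there is no path back
grads = tape.gradient(client_grad, model.trainable_weights)
print(grads)  # a list of None, one entry per weight tensor
```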

Is there some way I can do this with tf's apply_gradients() or compute_gradients()?

I only have the gradients for the client's last layer (sent by the server). I'm trying to compute the gradients for all of the client's layers from them and update the weights.
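The server's gradient is enough because of the chain rule. Writing $a = f(x;\theta)$ for the client output (client_pred), $\theta$ for the client weights and $L$ for the server-side loss, the server supplies $g = \partial L / \partial a$ (client_grad), and the client can recover the weight gradients locally as a vector-Jacobian product, provided it has the Jacobian of its own forward pass:

```latex
\frac{\partial L}{\partial \theta}
  = \left(\frac{\partial a}{\partial \theta}\right)^{\!\top} \frac{\partial L}{\partial a}
  = \left(\frac{\partial a}{\partial \theta}\right)^{\!\top} g
```

This product is what client_pred.backward(client_grad) computes in PyTorch. It needs only the client's own recorded forward pass plus $g$, never the server's weights.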

Thank you.


import tensorflow as tf
from tensorflow.keras.layers import Dense

class TensorflowModel(tf.keras.Model):
    def __init__(self, D_in, H, D_out):
        super(TensorflowModel, self).__init__()
        self.d1 = Dense(H, activation='relu', input_shape=(D_in,))
        self.d2 = Dense(D_out)

    def call(self, x):
        x = self.d1(x)
        return self.d2(x)

tensorflowModel = TensorflowModel(D_in, H, D_out)
tensorflowOptimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)  # `lr` is no longer accepted

import torch

serverModel = torch.nn.Sequential(
    torch.nn.Linear(10, 50),  # expects the client output to have D_out = 10 features
    torch.nn.ReLU(),
    torch.nn.Linear(50, 10)
)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(serverModel.parameters(), lr=1e-4)

for t in range(N):
    # let x be a minibatch
    # let y be the labels of the minibatch

    client_pred = tensorflowModel(x)

    # hand the client activations to PyTorch as a leaf tensor that records its gradient
    client_output = torch.from_numpy(client_pred.numpy())
    client_output.requires_grad = True

    y_pred = serverModel(client_output)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # update server weights

    # now retrieve the client grad for the last layer (d loss / d client_output)
    client_grad = client_output.grad.detach().clone().numpy()
    client_grad = tf.convert_to_tensor(client_grad)  # change to a TF tensor

    # now compute all client gradients and update client weights
    # HOW DO I DO THIS?

How should I update the client weights? If the client were a PyTorch model, I could just call client_pred.backward(client_grad) and client_optimizer.step(). I'm not sure how to use the gradient tape to calculate the gradients, since client_grad was computed on the server as a PyTorch tensor and then converted to a TF tensor.

asked Nov 24 '25 at 18:11 by Flowchart


1 Answer

I don't know if you still need this; I came here looking for the same thing myself. So here goes.

I'm not very good with torch, so please bear with me.

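import tensorflow as tf
from tensorflow import keras

client_model = keras.models.Sequential([keras.layers........., ......])

# Record the client's forward pass so the tape can backpropagate through it later
with tf.GradientTape(persistent=True) as client_tape:
    client_pred = client_model(batch_flat)

# Send client_pred to the server, then get back the gradient of the loss
# with respect to client_pred (it must have the same shape as client_pred)
grad_from_server = yourGradientGetterFunction()

# output_gradients seeds backprop with the server's gradient,
# like client_pred.backward(grad_from_server) in PyTorch
client_gradients = client_tape.gradient(client_pred,
                                        client_model.trainable_weights,
                                        output_gradients=grad_from_server)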
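To see what output_gradients does, here is a small check, assuming TF 2.x: seeding the tape with a server gradient g gives the same weight gradients as differentiating the surrogate scalar sum(client_pred * g), which is exactly the vector-Jacobian product from the chain rule. The model, shapes and random g are hypothetical stand-ins.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
client_model = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu"),
                                    tf.keras.layers.Dense(4)])
batch = tf.random.normal((3, 6))
g = tf.random.normal((3, 4))  # stand-in for the gradient sent by the server

with tf.GradientTape(persistent=True) as tape:
    client_pred = client_model(batch)
    # Surrogate scalar whose gradient w.r.t. the weights equals J^T g
    surrogate = tf.reduce_sum(client_pred * tf.stop_gradient(g))

via_output_grads = tape.gradient(client_pred, client_model.trainable_weights,
                                 output_gradients=g)
via_surrogate = tape.gradient(surrogate, client_model.trainable_weights)
del tape  # release the persistent tape's resources

for a, b in zip(via_output_grads, via_surrogate):
    np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-5, atol=1e-5)
print("output_gradients matches the surrogate-loss gradients")
```

The surrogate form is handy if you ever need a scalar loss (for example with an API that only accepts one), while output_gradients avoids building the extra product.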

Now that you have gradients for every layer on the client side, you can apply them with an optimizer like:

client_opt = tf.keras.optimizers.SGD(learning_rate=0.1)
client_opt.apply_gradients(zip(client_gradients,
                               client_model.trainable_weights))

This applies the computed gradients to all the layers of the client-side model. Alternatively, you can apply the gradients manually with plain gradient descent:

w = w - g*lr
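The manual rule w = w - g*lr can be written with tf.Variable.assign_sub. This is a self-contained sketch, assuming TF 2.x; the one-layer model, the all-ones stand-in for the server gradient and lr are hypothetical. It records the weights before the update and checks that each one moved by exactly -lr * g.

```python
import numpy as np
import tensorflow as tf

lr = 0.1
client_model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
batch = tf.random.normal((2, 5))
grad_from_server = tf.ones((2, 3))  # stand-in for the server's gradient

with tf.GradientTape() as client_tape:
    client_pred = client_model(batch)
client_gradients = client_tape.gradient(client_pred, client_model.trainable_weights,
                                        output_gradients=grad_from_server)

before = [w.numpy().copy() for w in client_model.trainable_weights]

# Plain gradient descent: w = w - g * lr, applied in place to each variable
for w, g in zip(client_model.trainable_weights, client_gradients):
    w.assign_sub(lr * g)

for w, w0, g in zip(client_model.trainable_weights, before, client_gradients):
    np.testing.assert_allclose(w.numpy(), w0 - lr * g.numpy(), rtol=1e-6, atol=1e-6)
```

For a single linear layer the bias gradient is the server gradient summed over the batch, here 2.0 per output, which makes the result easy to check by hand.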

So, there you go, try this out.
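Putting the pieces together, here is a sketch of a split-learning training loop under stated assumptions: TF 2.x, and the PyTorch server replaced by a small Keras "server" model so the example runs without PyTorch. Only NumPy arrays cross the client/server boundary (activations forward, d loss / d client_pred backward), mirroring the question's setup. The model sizes, random data and the server_step helper are all hypothetical.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10)).astype("float32")
y = rng.integers(0, 10, size=32)

client_model = tf.keras.Sequential([tf.keras.layers.Dense(16, activation="relu"),
                                    tf.keras.layers.Dense(10)])
client_opt = tf.keras.optimizers.SGD(learning_rate=0.1)

# Hypothetical server; stands in for the PyTorch serverModel from the question
server_model = tf.keras.Sequential([tf.keras.layers.Dense(50, activation="relu"),
                                    tf.keras.layers.Dense(10)])
server_opt = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def server_step(activations: np.ndarray, labels: np.ndarray):
    """Update the server and return (loss, d loss / d activations) as NumPy."""
    acts = tf.convert_to_tensor(activations)
    with tf.GradientTape() as tape:
        tape.watch(acts)  # plays the role of requires_grad = True on client_output
        loss = loss_fn(labels, server_model(acts))
    grads = tape.gradient(loss, [acts] + server_model.trainable_weights)
    server_opt.apply_gradients(zip(grads[1:], server_model.trainable_weights))
    return float(loss), grads[0].numpy()

losses = []
for step in range(50):
    with tf.GradientTape() as client_tape:
        client_pred = client_model(x)  # client forward pass, recorded on the tape
    loss, grad_from_server = server_step(client_pred.numpy(), y)
    client_gradients = client_tape.gradient(
        client_pred, client_model.trainable_weights,
        output_gradients=tf.convert_to_tensor(grad_from_server))
    client_opt.apply_gradients(zip(client_gradients, client_model.trainable_weights))
    losses.append(loss)

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Swapping server_step for the PyTorch code from the question should leave the client side unchanged: it only needs client_pred.numpy() to go out and a gradient array of the same shape to come back.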

Please check out the TensorFlow GradientTape documentation (Methods → gradient), in particular its output_gradients argument.

I was trying to implement split learning (Split learning for health, Vepakomma et al.), which needed this.

answered Nov 26 '25 at 08:11 by scsanty


