
How to require gradient only for some tensor elements in a pytorch tensor?

I would like to use a tensor in which only a few elements are trainable, i.e. only those elements are taken into account during backpropagation. Consider for example:

import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
mask = torch.zeros(conv1.weight.data.shape, requires_grad=False)
conv1.weight.data[0, 0, 0, 0] += mask[0, 0, 0, 0]
print(conv1.weight.data[0, 0, 0, 0].requires_grad)

This prints False.

ZeroorOne asked Oct 23 '25 03:10


1 Answer

Gradient computation can only be switched on or off at the tensor level: requires_grad is a property of the whole tensor, not of its individual elements. What you observe is a different issue: you accessed the requires_grad attribute of conv1.weight.data, which is not the same object as its wrapper, the parameter conv1.weight.

Notice the difference:

>>> import torch
>>> import torch.nn as nn
>>> conv1 = nn.Conv2d(3, 16, 3)  # parameters have requires_grad=True by default

>>> conv1.weight.requires_grad
True

>>> conv1.weight.data.requires_grad
False

conv1.weight is an nn.Parameter, i.e. the tensor that autograd tracks, whereas conv1.weight.data is the underlying tensor: it shares the same storage but is detached from the autograd graph, so it never requires gradient.
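A short, self-contained check of the distinction above. This is a sketch; the indexing check and the data_ptr() storage comparison are additions for illustration, not part of the original answer.

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 16, 3)

# The parameter itself is tracked by autograd.
param_flag = conv1.weight.requires_grad              # True
# .data is a tensor over the same storage that autograd does not track.
data_flag = conv1.weight.data.requires_grad          # False
# Indexing the parameter (not .data) gives a tensor that still requires grad:
# the flag belongs to the whole tensor and propagates through the indexing op.
elem_flag = conv1.weight[0, 0, 0, 0].requires_grad   # True
# Both views point at the same memory, so writes through .data change the weights.
shares_storage = conv1.weight.data.data_ptr() == conv1.weight.data_ptr()

print(param_flag, data_flag, elem_flag, shares_storage)
```

So asking for requires_grad on a single element of conv1.weight only reports the flag of the whole parameter; it cannot be set per element.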


Now, on to making only part of a tensor trainable. Instead of framing it as "require gradient only for some elements of the tensor", think of it as "discard the gradient for the other elements". You can do this by zeroing the gradient values at the desired positions after the backward pass and before the optimizer step:

>>> conv1 = nn.Conv2d(3, 1, 2)
>>> mask = torch.ones_like(conv1.weight)

For example, to prevent the first weight element of the convolutional layer from being updated:

>>> mask[0,0,0,0] = 0 

After the backward pass, you can mask the gradient on conv1.weight:

>>> conv1(torch.rand(1,3,10,10)).mean().backward()
>>> conv1.weight.grad *= mask
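The same masking can be sketched end to end. This version swaps the manual multiplication of conv1.weight.grad for a tensor hook (Tensor.register_hook), which applies the mask automatically on every backward pass, and then checks that plain SGD leaves the masked element unchanged. The seed, learning rate and number of steps are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv1 = nn.Conv2d(3, 1, 2)
mask = torch.ones_like(conv1.weight)  # same shape as the weight, no grad
mask[0, 0, 0, 0] = 0                  # freeze the first weight element

# Instead of masking .grad by hand after each backward(), register a hook that
# multiplies the incoming gradient by the mask before it is accumulated.
conv1.weight.register_hook(lambda grad: grad * mask)

# Assumption: plain SGD, no weight decay and no momentum (see the note below).
optimizer = torch.optim.SGD(conv1.parameters(), lr=0.1)

frozen_before = conv1.weight[0, 0, 0, 0].item()
other_before = conv1.weight[0, 0, 0, 1].item()

for _ in range(3):
    optimizer.zero_grad()
    conv1(torch.rand(1, 3, 10, 10)).mean().backward()
    optimizer.step()

frozen_after = conv1.weight[0, 0, 0, 0].item()
other_after = conv1.weight[0, 0, 0, 1].item()
print(frozen_before == frozen_after, other_before != other_after)
```

Zeroing the gradient only keeps the element fixed when the update depends on .grad alone. If the optimizer applies weight decay (for example SGD with weight_decay > 0, or AdamW's decoupled decay), the masked element can still change, because the decay term is computed from the parameter value inside optimizer.step(); the same holds for momentum accumulated before the mask was applied.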
Ivan answered Oct 24 '25 18:10


