PyTorch: Learnable threshold for clipping activations

Tags:

pytorch

What is the proper way to clip ReLU activations with a learnable threshold? Here's how I implemented it; however, I'm not sure whether it is correct:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.act_max = nn.Parameter(torch.Tensor([0]), requires_grad=True)  # learnable clipping threshold

        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(64 * 5 * 5, 10)

    def forward(self, input):
        conv1 = self.conv1(input)
        pool1 = self.pool(conv1)
        relu1 = self.relu(pool1)

        # clip activations above the learnable threshold (in-place assignment)
        relu1[relu1 > self.act_max] = self.act_max

        conv2 = self.conv2(relu1)
        pool2 = self.pool(conv2)
        relu2 = self.relu(pool2)
        relu2 = relu2.view(relu2.size(0), -1)
        linear = self.linear(relu2)
        return linear


model = Net()
# kaiming_normal_ expects a tensor, not the parameters() method;
# initialize each weight matrix individually
for p in model.parameters():
    if p.dim() > 1:
        nn.init.kaiming_normal_(p)
nn.init.constant_(model.act_max, 1.0)
model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for epoch in range(100):
    for i in range(1000):
        output = model(input)
        loss = nn.CrossEntropyLoss()(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.act_max.data = model.act_max.data - 0.001 * model.act_max.grad.data

I had to add the last line because without it the value would not update for some reason.

UPDATE: I am now trying a method to compute the upper bound (act_max) based on the gradients of the activations (a rough sketch of the idea follows the list below):

  1. For all activations above the threshold (relu1[relu1 > self.act_max]), look at their gradients and compute the average direction they point in.
  2. For all positive activations below the threshold, likewise compute the average direction their gradients point in.
  3. The sum of these two average gradients determines the direction and magnitude of the change to act_max.
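
A rough sketch of this idea (my own untested code, not a known recipe): it assumes forward() stores the pre-clip activations as self.relu1 and calls self.relu1.retain_grad(), so that their gradients are available after loss.backward(); it would run right after the backward pass in the training loop above:

# after loss.backward(), update act_max from the activation gradients
with torch.no_grad():
    acts, grads = model.relu1, model.relu1.grad
    above = acts > model.act_max          # clipped activations (step 1)
    below = (acts > 0) & ~above           # positive, unclipped activations (step 2)
    grad_above = grads[above].mean() if above.any() else 0.0
    grad_below = grads[below].mean() if below.any() else 0.0
    # step 3: the summed average gradients drive the act_max update
    model.act_max -= 0.001 * (grad_above + grad_below)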
asked Sep 07 '25 by MichaelSB


1 Answer

There are two problems with that code.

  1. The implementation-level one is that you're using an in-place operation, which generally doesn't play well with autograd. Instead of

relu1[relu1 > self.act_max] = self.act_max

you should use an out-of-place operation like

relu1 = torch.where(relu1 > self.act_max, self.act_max, relu1)
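
A quick way to convince yourself that the out-of-place version behaves (a minimal sketch, with arbitrary shapes):

import torch

act_max = torch.nn.Parameter(torch.tensor([1.0]))
x = torch.randn(4, 8, requires_grad=True)

clipped = torch.where(x > act_max, act_max, x)  # out-of-place clip
clipped.sum().backward()

print(act_max.grad)  # equals the number of clipped elements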

  2. The other problem is more general: neural networks are usually trained with gradient descent methods, and threshold values have no gradient - the loss function is not differentiable with respect to thresholds.

In your model you're using a dirty workaround (whether you write it as above or use torch.where): model.act_max.grad.data is only defined because, for some elements, their value is set to model.act_max. But this gradient knows nothing about why they were set to that value. To make things more concrete, let's define a cutoff operation C(x, t) which indicates whether x is below the threshold t

C(x, t) = 1 if x < t else 0

and write your clipping operation as a product

clip(x, t) = C(x, t) * x + (1 - C(x, t)) * t

You can then see that the threshold t has a twofold meaning: it controls when to cut off (inside C) and it controls the value above the cutoff (the trailing t). We can therefore generalize the operation as

clip(x, t1, t2) = C(x, t1) * x + (1 - C(x, t1)) * t2

The problem with your operation is that it is differentiable with respect to t2 but not t1. Your solution ties the two together so that t1 == t2, but gradient descent will still behave as if the threshold itself never moves - only the above-the-threshold value does.
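
To see this concretely, here is a small sketch with the two thresholds kept separate (the names t1 and t2 follow the text above):

import torch

x = torch.randn(10)
t1 = torch.tensor(0.5, requires_grad=True)  # decides where to cut off
t2 = torch.tensor(0.5, requires_grad=True)  # value assigned above the cutoff

clipped = torch.where(x < t1, x, t2)  # clip(x, t1, t2) from above
clipped.sum().backward()

print(t1.grad)  # None: the comparison x < t1 passes no gradient to t1
print(t2.grad)  # count of clipped elements: the gradient flows only through t2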

For this reason, your thresholding operation may not, in general, learn the value you hope it learns. This is something to keep in mind when designing such operations, but it is not a guarantee of failure - in fact, the standard ReLU applied to the biased output of a linear unit gives a similar picture. We define the cutoff operation H

H(x, t) = 1 if x > t else 0

and ReLU as

ReLU(x + b, t) = (x + b) * H(x + b, t) = (x + b) * H(x, t - b)

where we could again generalize to

ReLU(x, b, t) = (x + b) * H(x, t)

and again we can only learn b, while t implicitly follows b. Yet it seems to work :)
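
A tiny sketch of this analogy (my own illustration, not from the answer): the bias receives a gradient through the active units, and moving b in effect moves the cutoff point, even though the comparison inside H is non-differentiable:

import torch
import torch.nn.functional as F

x = torch.randn(10)
b = torch.tensor(0.2, requires_grad=True)  # learnable bias

out = F.relu(x + b)  # the cutoff sits at x = -b, so it implicitly follows b
out.sum().backward()

print(b.grad)  # number of active units: b is learned only through them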

answered Sep 11 '25 by Jatentaki