
Custom optimizer in PyTorch or TensorFlow 2.12.0

I am trying to implement my custom optimizer in PyTorch or TensorFlow 2.12.0. With ChatGPT's help I always get code that has errors, and what's more, I can't find any useful examples.

I would like to implement a custom optimizer where:

d1 contains the sign of the current derivatives
d2 contains the sign of the previous derivatives
step_size starts at 1.0
step_size is divided by 2.0 if the sign of d1 differs from the sign of d2

In PyTorch, I know the code has to look something like this:

import torch.optim as optim

class MyOpt(optim.Optimizer):
    def __init__(self, params, lr=1.0):
        defaults = dict(lr=lr, d1=None, d2=None)
        super(MyOpt, self).__init__(params, defaults)

    def step(self):
        ???

Can anyone help me code it?

Michal asked Oct 14 '25 08:10

1 Answer

The PyTorch optimizers are complicated because they are general-purpose and optimized for performance. Implementing your own is relatively straightforward if you don't need it to be particularly robust. For example, a very simple implementation of gradient descent with momentum might look something like this:

import torch
from torch.optim import Optimizer

class SimpleGD(Optimizer):
    def __init__(self, params, lr, momentum=0.0):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                # accumulate the momentum buffer: buf = grad + momentum * buf
                buf = self.state[param].get('momentum_buffer', 0)
                grad = param.grad + group['momentum'] * buf
                self.state[param]['momentum_buffer'] = grad
                # in-place gradient descent update
                param -= group['lr'] * grad
        return loss

This should match the results of PyTorch's SGD when given the same lr and momentum (and everything else left at its defaults).
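
As a quick sanity check, you could run a few steps on a toy problem and compare the trajectory against torch.optim.SGD with the same settings; the linear model, data, and hyperparameters below are just illustrative.

import torch

# hypothetical toy model and data, purely for trying the optimizer out
model = torch.nn.Linear(10, 1)
opt = SimpleGD(model.parameters(), lr=0.1, momentum=0.9)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()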

For your case, you would need to compute the sign of the gradients and cache it at each step, similar to the way momentum_buffer is cached in the example above.
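
For instance, one way to read your description might look like the sketch below. The class name SignStepGD is made up here, and treating step_size as a per-element tensor (the way Rprop does) rather than a single shared scalar is my own assumption; adjust it if that is not what you want.

import torch
from torch.optim import Optimizer

class SignStepGD(Optimizer):
    def __init__(self, params, lr=1.0):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                state = self.state[param]
                d1 = torch.sign(param.grad)       # sign of the current derivatives
                d2 = state.get('prev_sign')       # sign of the previous derivatives (cached)
                # per-element step size, starting at lr (1.0 by default); this is an assumption
                step_size = state.get('step_size', torch.full_like(param, group['lr']))
                if d2 is not None:
                    # halve the step size wherever the sign flipped since the last step
                    step_size = torch.where(d1 != d2, step_size / 2.0, step_size)
                # take a step in the direction of the sign of the gradient
                param -= step_size * d1
                state['prev_sign'] = d1
                state['step_size'] = step_size
        return loss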

jodag answered Oct 16 '25 21:10