How can I create a custom keras optimizer?

I'm comparing the performance of SVRG, SAG, and other optimizers for minimizing deep learning objectives.

How can I implement custom optimizers with Keras? I tried looking at the SGD implementation in the Keras source code, but I couldn't find the source code for tf.raw_ops.ResourceApplyGradientDescent, which makes it difficult to reproduce the pattern for another optimizer.
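If I understand correctly, that raw op just performs the in-place update var -= lr * grad, so the same step can be written with plain TensorFlow ops. A minimal sketch of what I believe the equivalent update looks like (the values are just placeholders):

import tensorflow as tf

# Illustrative only: the effect of tf.raw_ops.ResourceApplyGradientDescent,
# expressed with plain TensorFlow ops (var <- var - alpha * delta).
var = tf.Variable([1.0, 2.0])
alpha = tf.constant(0.1)           # learning rate
delta = tf.constant([0.5, 0.5])    # stand-in for a gradient

var.assign_sub(alpha * delta)      # in-place update: var -= alpha * delta
print(var.numpy())                 # [0.95 1.95]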

asked Dec 01 '25 by Nawel

1 Answer

To customize an optimizer:

  • Extend tf.keras.optimizers.Optimizer.
  • Override _create_slots: this creates the optimizer variables ("slots") associated with each trainable variable. You need this if your optimizer keeps per-variable state, for example a momentum accumulator.
  • Override _resource_apply_dense and/or _resource_apply_sparse: this is where you implement the actual update equation of your optimizer.
  • Override get_config (optional): store the parameters you pass to the optimizer so that you can clone or save your model afterwards.

Here is a quick example of SGD with momentum, taken from here:

import tensorflow as tf
from tensorflow import keras

class MyMomentumOptimizer(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.001, momentum=0.9, name="MyMomentumOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
        self._set_hyper("decay", self._initial_decay) # 
        self._set_hyper("momentum", momentum)
    
    def _create_slots(self, var_list):
        """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots".
        For momentum optimization, we need one momentum slot per model variable.
        """
        for var in var_list:
            self.add_slot(var, "momentum")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Update the slots and perform one optimization step for one model variable
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype) # handle learning rate decay
        momentum_var = self.get_slot(var, "momentum")
        momentum_hyper = self._get_hyper("momentum", var_dtype)
        momentum_var.assign(momentum_var * momentum_hyper - (1. - momentum_hyper) * grad)
        var.assign_add(momentum_var * lr_t)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }
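Here is a hypothetical usage sketch with a tiny model and dummy data; when reloading a saved model, pass the class via custom_objects:

import numpy as np

# Dummy data and model, purely for illustration.
X_train = np.random.rand(32, 10).astype("float32")
y_train = np.random.rand(32, 1).astype("float32")

model = keras.models.Sequential([keras.layers.Dense(1, input_shape=[10])])
model.compile(loss="mse", optimizer=MyMomentumOptimizer(learning_rate=0.01, momentum=0.9))
model.fit(X_train, y_train, epochs=2)

model.save("my_model.h5")
restored = keras.models.load_model(
    "my_model.h5",
    custom_objects={"MyMomentumOptimizer": MyMomentumOptimizer},
)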
answered Dec 04 '25 by Coderji

