I have a custom layer in Keras that uses a Masking layer internally.
To make it work with saving and loading, I need to serialize that mask correctly, but this standard code doesn't work:
def get_config(self):
    config =  {'mask': self.mask}
    base_config = super(Mixing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
Here self.mask is a reference to a Masking layer. The config is supposed to hold plain, serializable data, not a Layer object.
I'm not sure how to serialize a Masking layer (or Keras layers in general). Can anyone help?
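You can implement the same serialization methods as Keras's built-in Wrapper class (the base class of TimeDistributed and Bidirectional). Outside the Keras source tree, the relative import has to become an absolute one:
from keras.layers import deserialize as deserialize_layer  # `from . import deserialize` only works inside Keras itself

def get_config(self):
    # Save the inner layer as plain data (class name + its own config), not as a Layer object
    config = {'layer': {'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()}}
    base_config = super(Wrapper, self).get_config()  # use your own class here, e.g. super(Mixing, self)
    return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
    # Rebuild the inner layer from the saved class name + config, then pass it to the constructor
    layer = deserialize_layer(config.pop('layer'),
                              custom_objects=custom_objects)
    return cls(layer, **config)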
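In get_config, the inner layer is stored under config['layer'] as its class name plus its own config, which is plain serializable data instead of a Layer object.
In from_config, that entry is popped and passed to deserialize_layer to rebuild the inner layer, which is then handed to the constructor together with the remaining config. For your layer, apply the same pattern to self.mask.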
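Here is a minimal sketch of the same idea applied to a layer that owns a Masking layer. It assumes Keras 2.x or 3.x (`from tensorflow import keras` should behave similarly, though I haven't checked every version). Instead of building the `{'class_name': ..., 'config': ...}` dict by hand, it uses `keras.layers.serialize` / `keras.layers.deserialize`, which produce and consume a dict of that shape. The layer body and the `masking_layer` parameter name are hypothetical stand-ins for the asker's `Mixing` layer.

```python
import json

import keras  # assumption: Keras 2.x or 3.x


class Mixing(keras.layers.Layer):
    """Hypothetical stand-in for the asker's layer: it owns an inner Masking layer."""

    def __init__(self, masking_layer=None, **kwargs):
        super().__init__(**kwargs)
        # `masking_layer` is an illustrative parameter name, not a Keras API.
        if masking_layer is None:
            masking_layer = keras.layers.Masking(mask_value=0.0)
        self.masking_layer = masking_layer

    def call(self, inputs):
        # Apply the inner Masking layer; a real layer would do its mixing here.
        return self.masking_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # Store the inner layer as plain, JSON-serializable data
        # (a dict with 'class_name' and 'config') instead of the Layer object.
        config["masking_layer"] = keras.layers.serialize(self.masking_layer)
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # avoid mutating the caller's dict
        # Rebuild the inner layer first, then let the base class call cls(**config).
        config["masking_layer"] = keras.layers.deserialize(config["masking_layer"])
        return super().from_config(config)


# Round trip: layer -> config -> JSON text -> config -> new layer.
original = Mixing(masking_layer=keras.layers.Masking(mask_value=-1.0), name="mixing")
config_json = json.dumps(original.get_config())
restored = Mixing.from_config(json.loads(config_json))
```

Delegating to `super().from_config` (rather than calling `cls(**config)` directly) lets the base class handle any version-specific fields such as `dtype`. When loading a whole saved model that contains this layer, you would still pass `custom_objects={"Mixing": Mixing}` to `keras.models.load_model` so Keras can find the custom class.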