 

Pretrained NN Finetuning with Keras. How to freeze Batch Normalization?

I didn't write my code in tf.keras, but according to this tutorial on fine-tuning with a pretrained NN: https://keras.io/guides/transfer_learning/#freezing-layers-understanding-the-trainable-attribute,

I have to set the parameter training=False when calling the pretrained model, so that Batch Normalization doesn't destroy my model when I later unfreeze it for fine-tuning. But how do I do that in standalone Keras (remember: I didn't write it in tf.keras)? Is it even necessary in Keras to do that?

The code:

from keras.applications import Xception
from keras.layers import Input, GlobalAveragePooling2D

def baseline_model():
    # Load the ImageNet-pretrained Xception base without its classification head
    pretrained_model = Xception(include_top=False, weights="imagenet")

    # Freeze every layer of the pretrained base
    for layer in pretrained_model.layers:
        layer.trainable = False

    general_input = Input(shape=(256, 256, 3))

    # This call raises the TypeError below in standalone Keras
    x = pretrained_model(general_input, training=False)
    x = GlobalAveragePooling2D()(x)
...

This gives me the following error when calling model = baseline_model():

TypeError: call() got an unexpected keyword argument 'training'

What is the best way to do this? I tried rewriting everything in tf.keras, but errors kept popping up everywhere when I did.

EDIT: My Keras version is 2.3.1 and my TensorFlow version is 2.2.0.

asked Dec 01 '25 by jackbauer

1 Answer

EDITED my previous answer after doing some additional research:
I did some reading, and it seems there is some trickery in how the BatchNormalization layer behaves when frozen. This is a good thread discussing it: github.com/keras-team/keras/issues/7085. It appears the training=False argument is necessary to correctly freeze the BatchNormalization layer, and it was added in Keras 2.1.3, so my advice is to make sure your Keras/TF version is at least that recent.
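
For completeness, here is a minimal sketch of that pattern written against tf.keras on TF 2.x, mirroring the question's baseline_model. The Dense head, optimizer, and loss are illustrative assumptions and not taken from the original code:

import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

def baseline_model():
    # Freeze the whole pretrained base in one step
    pretrained_model = Xception(include_top=False, weights="imagenet")
    pretrained_model.trainable = False

    general_input = Input(shape=(256, 256, 3))
    # training=False keeps BatchNormalization in inference mode,
    # even after the base is unfrozen for fine-tuning later
    x = pretrained_model(general_input, training=False)
    x = GlobalAveragePooling2D()(x)
    output = Dense(1, activation="sigmoid")(x)  # illustrative head, adapt to your task
    return Model(general_input, output)

model = baseline_model()
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss="binary_crossentropy")

# Fine-tuning step: unfreeze the base and recompile with a low learning rate
model.layers[1].trainable = True  # index 1 is the wrapped Xception model
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5), loss="binary_crossentropy")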

answered Dec 03 '25 by Karol Żak


