I'm trying to follow the fine-tuning steps described in https://www.tensorflow.org/tutorials/images/transfer_learning#create_the_base_model_from_the_pre-trained_convnets to get a trained model for binary segmentation.
I create an encoder-decoder in which the encoder uses the MobileNetV2 weights and is frozen with encoder.trainable = False. Then I define my decoder as described in the tutorial and train the network for 300 epochs with a learning rate of 0.005. I get the following loss values and Jaccard indices during the last epochs:
Epoch 297/300
55/55 [==============================] - 85s 2s/step - loss: 0.2443 - jaccard_sparse3D: 0.5556 - accuracy: 0.9923 - val_loss: 0.0440 - val_jaccard_sparse3D: 0.3172 - val_accuracy: 0.9768
Epoch 298/300
55/55 [==============================] - 75s 1s/step - loss: 0.2437 - jaccard_sparse3D: 0.5190 - accuracy: 0.9932 - val_loss: 0.0422 - val_jaccard_sparse3D: 0.3281 - val_accuracy: 0.9776
Epoch 299/300
55/55 [==============================] - 78s 1s/step - loss: 0.2465 - jaccard_sparse3D: 0.4557 - accuracy: 0.9936 - val_loss: 0.0431 - val_jaccard_sparse3D: 0.3327 - val_accuracy: 0.9769
Epoch 300/300
55/55 [==============================] - 85s 2s/step - loss: 0.2467 - jaccard_sparse3D: 0.5030 - accuracy: 0.9923 - val_loss: 0.0463 - val_jaccard_sparse3D: 0.3315 - val_accuracy: 0.9740
I save all the weights of this model and then run the fine-tuning with the following steps:
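from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

# Restore the weights from the decoder-only training run
model.load_weights('my_pretrained_weights.h5')
# Unfreeze every layer, including the MobileNetV2 encoder
model.trainable = True
# Recompile with a much lower learning rate for fine-tuning
model.compile(optimizer=Adam(learning_rate=0.00001, name='adam'),
              loss=SparseCategoricalCrossentropy(from_logits=True),
              metrics=[jaccard, "accuracy"])
model.fit(training_generator, validation_data=(val_x, val_y), epochs=5,
          validation_batch_size=2, callbacks=callbacks)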
Suddenly the performance of my model is much worse than during the training of the decoder:
Epoch 1/5
55/55 [==============================] - 89s 2s/step - loss: 0.2417 - jaccard_sparse3D: 0.0843 - accuracy: 0.9946 - val_loss: 0.0079 - val_jaccard_sparse3D: 0.0312 - val_accuracy: 0.9992
Epoch 2/5
55/55 [==============================] - 90s 2s/step - loss: 0.1920 - jaccard_sparse3D: 0.1179 - accuracy: 0.9927 - val_loss: 0.0138 - val_jaccard_sparse3D: 7.1138e-05 - val_accuracy: 0.9998
Epoch 3/5
55/55 [==============================] - 95s 2s/step - loss: 0.2173 - jaccard_sparse3D: 0.1227 - accuracy: 0.9932 - val_loss: 0.0171 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 0.9999
Epoch 4/5
55/55 [==============================] - 94s 2s/step - loss: 0.2428 - jaccard_sparse3D: 0.1319 - accuracy: 0.9927 - val_loss: 0.0190 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 1.0000
Epoch 5/5
55/55 [==============================] - 97s 2s/step - loss: 0.1920 - jaccard_sparse3D: 0.1107 - accuracy: 0.9926 - val_loss: 0.0215 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 1.0000
Is there any known reason why this is happening? Is it normal? Thank you in advance!
OK, I found out what I do differently that makes it unnecessary to recompile. I do not set encoder.trainable = False. Instead, I use the equivalent loop below:
for layer in encoder.layers:
    layer.trainable = False
Then train your model. Afterwards, you can unfreeze the encoder weights with:
for layer in encoder.layers:
    layer.trainable = True
You do not need to recompile the model. I tested this and it works as expected. You can verify it by printing the model summary before and after and comparing the number of trainable parameters.
As for changing the learning rate, I find it is best to use the Keras callback ReduceLROnPlateau to adjust the learning rate automatically based on the validation loss. I also recommend using the EarlyStopping callback, which monitors the validation loss and halts training if it fails to decrease for 'patience' consecutive epochs. Setting restore_best_weights=True loads the weights from the epoch with the lowest validation loss, so you don't have to save and then reload the weights. Set epochs to a large number to ensure this callback activates. The code I use is shown below:
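import tensorflow as tf

# Stop when val_loss has not improved for 3 epochs and restore the best weights
es = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                      verbose=1, restore_best_weights=True)
# Halve the learning rate when val_loss has not improved for 1 epoch
rlronp = tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5,
                                              patience=1, verbose=1)
callbacks = [es, rlronp]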
In model.fit, set callbacks=callbacks.
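The freeze/unfreeze check described above can be sketched as follows. This is only an illustration under stated assumptions: the tiny Dense "encoder" and "decoder" are hypothetical stand-ins for MobileNetV2 and the real decoder, and count_trainable is a helper introduced here, not a Keras function. The parameter counts confirm that the flags changed. Whether training picks up the change without a new compile() call may depend on the Keras version, and the Keras transfer-learning guide advises calling compile() again after changing trainable, so the sketch does that as a precaution.

```python
import math

import tensorflow as tf


def count_trainable(model):
    """Hypothetical helper: total number of scalars in the trainable weights."""
    return sum(math.prod(tuple(w.shape)) for w in model.trainable_weights)


# Hypothetical stand-ins for the MobileNetV2 encoder and the decoder.
inputs = tf.keras.Input(shape=(16,))
encoder = tf.keras.Sequential(
    [tf.keras.layers.Dense(8, activation="relu"),
     tf.keras.layers.Dense(4, activation="relu")],
    name="encoder",
)
decoder = tf.keras.Sequential([tf.keras.layers.Dense(2)], name="decoder")
model = tf.keras.Model(inputs, decoder(encoder(inputs)))

# Freeze the encoder layer by layer, as in the answer.
for layer in encoder.layers:
    layer.trainable = False
frozen_count = count_trainable(model)      # only the decoder: 4*2 + 2 = 10

# Unfreeze the encoder again for fine-tuning.
for layer in encoder.layers:
    layer.trainable = True
unfrozen_count = count_trainable(model)    # encoder (172) + decoder (10) = 182

# Precaution (assumption, following the Keras guide): recompile after
# changing `trainable`, here with a low fine-tuning learning rate.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

print(frozen_count, unfrozen_count)
```

If the two counts are identical, the encoder layers were never frozen or unfrozen in the first place, which is worth ruling out before looking at the learning rate or the callbacks.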