What is the meaning of "trainable_weights" in Keras?

If I freeze my base_model with trainable=False, I get strange numbers from trainable_weights.

Before freezing, my model has 162 trainable_weights. After freezing, the model only has 2. I attached 2 layers to the pre-trained network. Does trainable_weights show me the layers to train? The number seems weird to me, since the summary shows 2,253,335 Trainable params.
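
Roughly what I am doing looks like this (the base model and the layers I attach here are just placeholders for my actual setup):

from tensorflow import keras

# pre-trained base, loaded without its classification head
base_model = keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet"
)

# freeze the base model
base_model.trainable = False

# the 2 layers attached on top: pooling has no weights, Dense has kernel + bias
model = keras.Sequential([
    base_model,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation="softmax"),
])

print(len(model.trainable_weights))  # 2
model.summary()                      # "Trainable params" shows a much larger number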

asked Oct 29 '25 by glomba

2 Answers

Late to the party, but maybe this answer can be useful to others that might be googling this.

First, it is useful to distinguish between the quantity "Trainable params" one sees at the end of my_model.summary() and the output of len(my_model.trainable_weights).

Maybe an example helps: let's say I have a model with the VGG16 architecture.

my_model = keras.applications.vgg16.VGG16(
    weights="imagenet", 
    include_top=False
)

# take a look at model summary
my_model.summary()

You will see that there are 13 convolutional layers that have trainable parameters (pooling and input layers have none, since nothing is learnt for them). In each of those 13 layers there are "weights" and "biases" that need to be learned; think of them as variables.

What len(my_model.trainable_weights) gives you is the number of trainable layers (if you will) multiplied by 2 (one weight tensor and one bias per layer).

In this case, if you print len(my_model.trainable_weights), you will get 26 as the answer. You can think of 26 as the number of variables handed to the optimizer, variables that can of course differ in shape.
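
If you want to see those 26 variables explicitly, you can, for instance, loop over them; each entry is a tf.Variable (one kernel and one bias per convolutional layer):

for w in my_model.trainable_weights:
    print(w.name, w.shape)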

Now to connect trainable_weights to the total number of trainable parameters, one can try:

trainable_params = 0
for w in my_model.trainable_weights:    # each w is one tf.Variable (a kernel or a bias)
    trainable_params += w.numpy().size  # number of scalar entries in this variable
print(f"{trainable_params = }")

You will get this number: 14714688, which is the "Trainable params" number you see at the end of my_model.summary().
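
To tie this back to the original question: once you freeze the model, both counts drop accordingly. For example:

# freeze every layer of the model
my_model.trainable = False

print(len(my_model.trainable_weights))      # 0 -- no variables left to train
print(len(my_model.non_trainable_weights))  # 26 -- the same variables, now frozen
my_model.summary()                          # now reports "Trainable params: 0"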

answered Nov 01 '25 by Alireza Amani


Trainable weights are the weights that will be learnt during the training process. If you set trainable=False, those weights are kept as they are and are not changed, because they are not learnt. You might see some "strange numbers" either because you are using a pre-trained network whose weights have already been learnt, or because you are using random initialization when defining the model. When doing transfer learning with pre-trained models, a common practice is to freeze the weights of the base (pre-trained) model and only train the extra layers that you add at the end, as sketched below.
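
A minimal sketch of that setup (the base model choice and the head layers here are arbitrary placeholders):

from tensorflow import keras

# pre-trained base model used as a frozen feature extractor
base_model = keras.applications.vgg16.VGG16(
    weights="imagenet", include_top=False, pooling="avg"
)
base_model.trainable = False  # keep the already-learnt ImageNet weights fixed

# only this newly added head will be trained
model = keras.Sequential([
    base_model,
    keras.layers.Dense(256, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])

print(len(model.trainable_weights))  # 4: two Dense kernels + two biases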

answered Nov 01 '25 by techytushar


