I want to load the weights of a pre-trained model into my local model. I don't understand why `state_dict = state_dict.copy()` is necessary if both networks use the same `state_dict` key names.
```python
# `model` and `state_dict` come from the surrounding loading code.
# These lists are filled in by _load_from_state_dict.
missing_keys = []
unexpected_keys = []
error_msgs = []

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
    state_dict._metadata = metadata

def load(module, prefix=''):
    # Each module looks up its own entries under `prefix`, then recurses
    # into its children with the child's name appended to the prefix.
    local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
    module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')

start_prefix = ''
# In my case, hasattr(model, 'bert') is False.
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
    start_prefix = 'bert.'
load(model, prefix=start_prefix)
```
Note: the above code is from Hugging Face.
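The first lines of that snippet can be checked without PyTorch at all. This is a minimal sketch, assuming a plain `OrderedDict` stands in for a real state dict and lists stand in for tensors; the key names and the rename are hypothetical. It shows that `copy()` drops the `_metadata` attribute (which is why the code saves it and reattaches it) and that key changes made on the copy never reach the original dict.

```python
from collections import OrderedDict

# Hypothetical stand-in for a checkpoint's state dict; lists play the role
# of tensors so that object identity can be observed.
original = OrderedDict([("layer.weight", [1.0, 2.0]), ("layer.bias", [0.0])])
original._metadata = {"layer": {"version": 1}}  # an attribute, not a key

# Same pattern as the snippet: save the attribute, copy, reattach.
metadata = getattr(original, "_metadata", None)
state_dict = original.copy()
print(hasattr(state_dict, "_metadata"))  # False: copy() does not keep attributes
if metadata is not None:
    state_dict._metadata = metadata

# The loader may now add, remove or rename keys in its own copy...
state_dict["layer.renamed_bias"] = state_dict.pop("layer.bias")
print(list(original))    # ['layer.weight', 'layer.bias'] (caller's dict untouched)
print(list(state_dict))  # ['layer.weight', 'layer.renamed_bias']

# ...but the copy is shallow: the values are the very same objects.
print(state_dict["layer.renamed_bias"] is original["layer.bias"])  # True
```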
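`state_dict = state_dict.copy()` does exactly what it says, but it does not copy in place: it builds a new dictionary holding the same entries and rebinds the name `state_dict` to it, leaving the caller's dictionary untouched. The state dict maps names to all the parameters (and buffers) of your model. Working on a copy means that any keys the loading code adds, removes or renames (which is what the comment "so _load_from_state_dict can modify it" refers to) only change the copy, not the dictionary that was passed in; whether the two networks share key names does not matter for this. Because `copy()` does not carry over attributes such as `_metadata`, the code saves that attribute first and reattaches it afterwards. The copy is shallow, though: the tensors themselves are shared, not duplicated, so be careful about whether you need a copy or a deepcopy!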
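To see the copy versus deepcopy difference with real tensors, here is a small sketch, assuming a recent PyTorch; the two `nn.Linear(3, 2)` layers are arbitrary stand-ins for the pre-trained and local models.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(3, 2)
sd = model.state_dict()  # OrderedDict of name -> tensor

shallow = sd.copy()
deep = copy.deepcopy(sd)

# Shallow copy: a new dict, but the very same tensor objects.
print(shallow is sd)                      # False
print(shallow["weight"] is sd["weight"])  # True
# Deep copy: every tensor is duplicated into new storage (extra memory).
print(deep["weight"].data_ptr() == sd["weight"].data_ptr())  # False

# Loading copies the values into the target's own parameters, so a shallow
# copy of the source dict already keeps the two models independent.
target = nn.Linear(3, 2)
target.load_state_dict(shallow)
with torch.no_grad():
    target.weight.add_(1.0)
print(torch.equal(model.weight, target.weight))  # False: model is unaffected
```

A shallow copy is what the loading code needs: by default `load_state_dict` copies each value into the target's existing parameters, so duplicating every tensor up front with `deepcopy` would only cost memory.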