
Why do we need state_dict = state_dict.copy()?

I want to load the weights of a pre-trained model into my local model. I don't understand why state_dict = state_dict.copy() is necessary if the two networks have the same state_dict key names.

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
    state_dict._metadata = metadata

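# lists that _load_from_state_dict fills in while loading
missing_keys, unexpected_keys, error_msgs = [], [], []

def load(module, prefix=''):
    # recursively load each submodule, extending the key prefix as we descend
    local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
    module._load_from_state_dict(
        state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + '.')

start_prefix = ''
# In my run, hasattr(model, 'bert') is False.
# If the model has no 'bert' attribute but the checkpoint keys start with 'bert.',
# look the keys up under that prefix.
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
    start_prefix = 'bert.'
load(model, prefix=start_prefix)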

Note: the above code is from Hugging Face.

dan asked Oct 19 '25 19:10


1 Answer

state_dict = state_dict.copy()

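creates a shallow copy of state_dict: a new dictionary object whose entries point to the same parameter tensors. It does not copy anything in place. As the comment in the snippet says, _load_from_state_dict is allowed to modify the dictionary it receives (for example by adding, removing, or renaming keys), so working on a copy leaves the caller's state_dict untouched. The copy does not carry over the _metadata attribute, which is why the code saves it first and re-attaches it afterwards. Because the copy is shallow, the tensors themselves are still shared; if you need fully independent parameters, use copy.deepcopy instead.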
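
The difference between a shallow and a deep copy of a state dict can be sketched with the standard library alone. This is only an illustration, not the Hugging Face code: plain lists stand in for tensors, and the key names and metadata contents are made up for the example.

```python
import copy
from collections import OrderedDict

# Stand-in for a checkpoint: plain lists play the role of tensors (assumption).
original = OrderedDict([("bert.weight", [1.0, 2.0]), ("bert.bias", [0.0])])
original._metadata = {"": {"version": 1}}  # extra attribute set on the instance

shallow = original.copy()

# Key-level changes to the copy do not reach the original dictionary.
shallow["weight"] = shallow.pop("bert.weight")
keys_untouched = list(original.keys()) == ["bert.weight", "bert.bias"]

# The values are the same objects, so in-place edits show up through both dicts.
shallow["bert.bias"][0] = 5.0
values_shared = original["bert.bias"][0] == 5.0

# copy() builds a new OrderedDict from the items only; the attribute is not carried over.
metadata_lost = not hasattr(shallow, "_metadata")

# deepcopy duplicates the values as well, so edits no longer leak back.
deep = copy.deepcopy(original)
deep["bert.bias"][0] = -1.0
deep_independent = original["bert.bias"][0] == 5.0

print(keys_untouched, values_shared, metadata_lost, deep_independent)
```

With real tensors the same reasoning applies: the shallow copy protects the dictionary's keys from being changed by the loader, while the tensor data stays shared unless you deep-copy it.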

Marine Galantin answered Oct 21 '25 12:10