I am using PyTorch Lightning with a ModelCheckpoint callback that saves models under a formatted filename such as filename="model_{epoch}-{val_acc:.2f}".
Afterwards I want to access these checkpoints again, so I need their paths. For simplicity, I want the best ones kept by save_top_k=N.
As the filename is dynamic, I wonder how I can retrieve the checkpoint files easily.
Is there a built-in attribute in the ModelCheckpoint or the trainer that gives me the saved checkpoints, e.g. something like
paths = checkpoint_callback.get_top_k_paths() # ?
I know I can do it with glob and model_dir. Since the callback has to keep track of the checkpoints anyway, I wonder what the built-in way is to get the model paths.
You can retrieve the best model path from the checkpoint callback after training:
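# retrieve the best checkpoint after training
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
# path of the best checkpoint written by this callback
print(checkpoint_callback.best_model_path)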
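For the top-k case from the question, the callback also keeps a best_k_models dictionary that maps each saved checkpoint path to its monitored score. As far as I know it is only filled when a monitor metric is set and save_top_k is non-zero. Below is a minimal sketch for ordering those paths from best to worst; rank_checkpoints is a hypothetical helper name, not part of Lightning, and the mode argument is assumed to match the mode passed to ModelCheckpoint.

```python
from typing import Dict, List


def rank_checkpoints(best_k_models: Dict[str, float], mode: str = "max") -> List[str]:
    """Return checkpoint paths ordered from best to worst monitored score.

    `best_k_models` is assumed to look like `checkpoint_callback.best_k_models`:
    a mapping from checkpoint path to score (floats or scalar tensors).
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # float() also accepts single-element tensors, so tensor scores sort fine.
    return sorted(
        best_k_models,
        key=lambda path: float(best_k_models[path]),
        reverse=(mode == "max"),
    )


# Hypothetical usage after trainer.fit(model):
# paths = rank_checkpoints(checkpoint_callback.best_k_models, mode="max")
```

With mode="max" the first entry should be the same path as best_model_path when the monitored metric is something like val_acc.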
To find all the checkpoints, you can list the files in the dirpath where the checkpoints are saved.
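Listing the directory could look roughly like this. It is only a sketch: list_checkpoints is a hypothetical helper, and the "*.ckpt" pattern assumes the default checkpoint file extension.

```python
from pathlib import Path
from typing import List


def list_checkpoints(dirpath: str, pattern: str = "*.ckpt") -> List[Path]:
    """Return the checkpoint files found directly in `dirpath`, sorted by name."""
    # glob() does not recurse here; only files matching the pattern are kept.
    files = [p for p in Path(dirpath).glob(pattern) if p.is_file()]
    return sorted(files, key=lambda p: p.name)


# Hypothetical usage, reusing the dirpath given to ModelCheckpoint:
# for ckpt in list_checkpoints('my/path/'):
#     print(ckpt)
```

Unlike the callback attributes, this also picks up checkpoints left over from earlier runs in the same directory.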