When to use prepare_data vs setup in PyTorch Lightning?

PyTorch Lightning's docs on data loading only show, in the code:

def prepare_data(self):
    # download
    ...

and

def setup(self, stage: Optional[str] = None):
    # Assign train/val datasets for use in dataloaders
    ...

Please explain the intended separation between prepare_data and setup, which callbacks may run between them, and why one would put something in one rather than the other.

asked by Gulzar


1 Answer

If you look at the pseudocode for the Trainer.fit function provided on the LightningModule documentation page under § Hooks, you can read:

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()                                 ## <-- prepare_data

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")                                       ## <-- setup
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")

You can see that prepare_data is called only when global_rank == 0, i.e. it is called by a single process. Indeed, the documentation description of prepare_data reads:

LightningModule.prepare_data()
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
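
As a concrete illustration, here is a minimal sketch of a hypothetical LightningDataModule (the MNIST dataset and the class name are assumptions made for the example, not taken from the docs). It keeps only the download in prepare_data and assigns no state there, because anything assigned in this hook would only exist on the single process that ran it:

import pytorch_lightning as pl
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data"):
        super().__init__()
        self.data_dir = data_dir

    def prepare_data(self):
        # Called once, on a single process: safe place to download to disk.
        # Do NOT assign state here (e.g. self.train_set = ...); the other
        # processes would never see it.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)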

setup, on the other hand, is called on all processes, as you can see both in the pseudocode above and in its documentation description:

LightningModule.setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
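
Putting the two together, a sketch of the same hypothetical MNISTDataModule (again, the names and split sizes are assumptions for illustration) assigns per-process state such as the train/val split in setup, because that hook runs on every process:

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # single process: download only, no state assignment
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage: Optional[str] = None):
        # every process: build the datasets the dataloaders will use
        if stage in (None, "fit"):
            full = MNIST(self.data_dir, train=True, transform=ToTensor())
            self.train_set, self.val_set = random_split(full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

In short: work that must happen exactly once (downloads, writing preprocessed files to disk) goes in prepare_data, while anything each process needs in memory (splits, transforms, attributes the dataloaders rely on) goes in setup.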

answered by Ivan