Pytorch: How to get all data and targets for subsets

Tags:

python

pytorch

I used the following code to read a dataset from a folder and divide it into train and test subsets. I can get all data and targets for each subset using list comprehensions, but this is very slow for large datasets. Is there a faster approach?

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

def train_test_dataset(dataset, test_split=0.20):
    # Stratified split over the sample indices, using the dataset's targets
    train_idx, test_idx = train_test_split(
        list(range(len(dataset))), test_size=test_split, stratify=dataset.targets
    )
    train_dataset = Subset(dataset, train_idx)
    test_dataset = Subset(dataset, test_idx)

    return train_dataset, test_dataset


dataset = dset.ImageFolder("/path_to_folder", transform=transform)

train_set, test_set = train_test_dataset(dataset)

train_data = [data for data, _ in train_set]
train_labels = [label for _, label in train_set]

I've tried the approach using a DataLoader from this question; it is better, but it still takes some time: PyTorch Datasets: Converting entire Dataset to NumPy

Thank you.

Angelus asked Sep 05 '25

2 Answers

The answer in the link you provided basically defeats the purpose of having a data loader: a data loader is meant to load your data into memory chunk by chunk. This has the clear advantage of not having to hold the entire dataset in memory at once.

From your ImageFolder dataset you can split your data with the torch.utils.data.random_split function (note that, unlike train_test_split with stratify, random_split does not stratify by class):

>>> def train_test_dataset(dataset, test_split=.2):
...     test_len = int(len(dataset) * test_split)
...     train_len = len(dataset) - test_len
...     return random_split(dataset, [train_len, test_len])

Then you can plug those datasets in separate DataLoaders:

>>> train_set, test_set = train_test_dataset(dataset)

>>> train_dl = DataLoader(train_set, batch_size=16, shuffle=True)
>>> test_dl  = DataLoader(test_set, batch_size=32, shuffle=False)
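For reference, here is a minimal end-to-end sketch of the approach above, using a synthetic TensorDataset as a stand-in for the ImageFolder dataset (the tensor shapes and batch sizes are illustrative assumptions):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split

def train_test_dataset(dataset, test_split=0.2):
    test_len = int(len(dataset) * test_split)
    train_len = len(dataset) - test_len
    return random_split(dataset, [train_len, test_len])

# Synthetic stand-in for the ImageFolder dataset: 100 "images" of shape 3x8x8
# with binary labels.
dataset = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 2, (100,)))

train_set, test_set = train_test_dataset(dataset)
train_dl = DataLoader(train_set, batch_size=16, shuffle=True)
test_dl = DataLoader(test_set, batch_size=32, shuffle=False)

# Data is loaded chunk by chunk: each iteration yields one batch.
images, labels = next(iter(train_dl))
print(images.shape)  # torch.Size([16, 3, 8, 8])
```

Iterating the loaders batch by batch like this keeps memory usage bounded, which is the point of the DataLoader abstraction.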
Ivan answered Sep 07 '25


There is a different way to get the data out of a torch.utils.data.dataset.Subset. A Subset is a view into the parent dataset: it stores the selected indices and points back to the underlying data. So we first take the indices from the subset, then use them to index the parent dataset directly:

all_dataset.labels[train_dataset.indices]

example:

class LesionDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, csv_file):
        df = pd.read_csv(csv_file)
        self.labels = df.values[..., 1:].argmax(axis=1)
        self.images = df.values[..., 0]
        # ...

all_dataset = Datasets.LesionDataset("data/img", "data/img/ALL.csv")
train_dataset, val_dataset = random_split(all_dataset, [train_size, val_size])
train_labels = all_dataset.labels[train_dataset.indices]
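A self-contained sketch of the same idea, with a toy dataset standing in for LesionDataset (the assumption here is only that the dataset stores its labels as a NumPy array attribute, as in the example above):

```python
import numpy as np
import torch
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Toy stand-in for LesionDataset: labels held in a NumPy array."""
    def __init__(self, n=10):
        self.labels = np.arange(n) % 3      # one class label per sample
        self.data = np.random.rand(n, 4)    # dummy features

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

all_dataset = ToyDataset(10)
train_dataset, val_dataset = random_split(all_dataset, [8, 2])

# Subset.indices maps back into the parent dataset, so all labels can be
# gathered in one vectorized NumPy lookup instead of a Python loop.
train_labels = all_dataset.labels[train_dataset.indices]
print(train_labels.shape)  # (8,)
```

This avoids calling `__getitem__` once per sample, which is what makes the list-comprehension approach slow.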
Ashiq A answered Sep 07 '25