I have a custom PyTorch dataset that returns a dictionary containing a custom class object under the key "query".
class QueryDataset(torch.utils.data.Dataset):
    """Dataset that serves a query object together with its values and target.

    Args:
        queries: Indexable collection of query objects.
        values: Indexable collection exposing ``.shape``; its first dimension
            is used as the dataset length.
        targets: Indexable collection of targets.

    NOTE(review): the three collections are presumably aligned by index and
    of equal length -- confirm against the caller.
    """

    def __init__(self, queries, values, targets):
        # Zero-argument super(): the one-argument form ``super(QueryDataset)``
        # creates an unbound proxy, so the base-class initializer never ran.
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        """Return the number of samples (first dimension of ``values``)."""
        return self.values.shape[0]

    def __getitem__(self, idx):
        """Return the sample at ``idx`` wrapped in a ``DeviceDict``."""
        return DeviceDict({
            'query': self.queries[idx],
            "values": self.values[idx],
            "targets": self.targets[idx],
        })
The problem is that when I put the queries in a data loader I get the error "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>". Is there a way to have a class object in my data loader? It blows up at next(iterator) in the code below.
# Build the dataset and loader, then run one optimisation step per batch.
# NOTE(review): QueryDataset.__init__ takes (queries, values, targets); only
# the queries are passed here, as in the original post -- confirm the real call.
train_queries = QueryDataset(train_queries)
train_loader = torch.utils.data.DataLoader(
    train_queries,
    batch_size=10,  # stray "]" after the value removed (was a syntax error)
    shuffle=True,
    drop_last=False,
)
for epoch in range(epochs):
    # Iterate the loader directly: it yields each batch once per epoch, and
    # the batch loop no longer reuses the epoch loop's index variable.
    for batch in train_loader:
        out = model(batch)
        loss = criterion(out["pred"], batch["targets"])
        self.optimizer.zero_grad()
        loss.sum().backward()
        self.optimizer.step()
You need to define your own collate_fn in order to do this. A sloppy approach, just to show you how it works, would be something like this:
import torch
class DeviceDict:
    """Thin wrapper that stores whatever object it is given as ``data``."""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        """Show the wrapped payload, which makes batches easy to inspect."""
        return f"{type(self).__name__}({self.data!r})"

    def print_data(self):
        """Print the wrapped object to standard output."""
        print(self.data)
class QueryDataset(torch.utils.data.Dataset):
    """Toy dataset returning a plain ``dict`` per sample.

    Args:
        queries: Indexable collection of query objects.
        values: Sized, indexable collection; its length is the dataset length.
        targets: Indexable collection of targets.
    """

    def __init__(self, queries, values, targets):
        # Zero-argument super(): the one-argument form ``super(QueryDataset)``
        # creates an unbound proxy, so the base-class initializer never ran.
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        """Return the number of samples.

        Derived from ``values`` instead of a hard-coded 5, so iterating the
        whole dataset cannot index past the end of the stored data.
        """
        return len(self.values)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a ``dict``."""
        return {
            'query': self.queries[idx],
            "values": self.values[idx],
            "targets": self.targets[idx],
        }
def custom_collate(dict):
    """Wrap the object handed to ``collate_fn`` in a ``DeviceDict``.

    The argument is passed through unchanged. The parameter keeps its
    original name (which shadows the builtin inside this function) so that
    existing callers are unaffected.
    """
    wrapped = DeviceDict(dict)
    return wrapped
# Minimal demo: fetch one batch through the custom collate function.
dt = QueryDataset("q", "v", "t")
# The original passed the undefined name ``dtt`` here (NameError).
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()
Basically, collate_fn allows you to achieve custom batching or to add support for custom data types, as explained in the link I previously provided.
As you see it just shows the concept, you need to change it based on your own needs.
For those curious, this is the DeviceDict and custom collate function that I used to get things to work.
class DeviceDict(dict):
    """``dict`` subclass whose tensor values can be moved between devices."""

    def __init__(self, *args, **kwargs):
        # Accept keyword arguments as well, so construction matches ``dict``.
        super().__init__(*args, **kwargs)

    def to(self, device):
        """Return a new ``DeviceDict`` with tensor values moved to ``device``.

        Values for which ``torch.is_tensor`` is false (e.g. custom query
        objects) are carried over unchanged. The receiver is not modified.

        Args:
            device: Target passed to each tensor's ``.to`` call.

        Returns:
            A new ``DeviceDict`` with the same keys.
        """
        return DeviceDict({
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in self.items()
        })
def collate_helper(elems, key, passthrough_keys=("query",)):
    """Collate the per-sample values collected for a single key.

    Args:
        elems: List holding one value per sample in the batch.
        key: Name of the field the values came from.
        passthrough_keys: Keys whose values are returned as the plain list
            instead of being collated. Defaults to ``("query",)``, which
            matches the previously hard-coded behaviour.

    Returns:
        ``elems`` itself for a passthrough key, otherwise the result of
        PyTorch's ``default_collate`` applied to ``elems``.
    """
    if key in passthrough_keys:
        # default_collate rejects arbitrary class instances, so these values
        # stay an ordinary Python list.
        return elems
    return torch.utils.data.dataloader.default_collate(elems)
def custom_collate(batch):
    """Merge a list of sample dicts into a single ``DeviceDict``.

    The keys are taken from the first sample. For every key, the values are
    gathered across the batch and handed to ``collate_helper``.
    """
    first_sample = batch[0]
    collated = {}
    for field in first_sample:
        per_sample_values = [sample[field] for sample in batch]
        collated[field] = collate_helper(per_sample_values, field)
    return DeviceDict(collated)
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With