I have developed two collate functions to read data from h5py files (I tried to create some synthetic data for an MWE here, but it did not go to plan).
The difference in processing time between the two on my data is about 10x, which is a very large gap; I am unsure why, and I would welcome insights for writing future collate functions.
import torch

def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features], to:
    [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for i in range(len(batch)):
        embeddings.append(batch[i]['embeddings'])
        start_ids.append(batch[i]['start_ids'])
        end_ids.append(batch[i]['end_ids'])
        idxs.append(batch[i]['idx'])
    # package data; swap to expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}
    return sample
I thought the one below, with more loops, would be slower, but that is far from the case.
def fast(batch):
    ''' This function alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features]
    to: [layers, batch_sz, tokens, features] for the embeddings
    '''
    # turn data to tensors
    embeddings = torch.stack([torch.as_tensor(item['embeddings']) for item in batch])
    # swap to expected [layers, batch_sz, tokens, features]
    embeddings = embeddings.permute(1, 0, 2, 3)
    # get start ids
    start_ids = torch.stack([torch.as_tensor(item['start_ids']) for item in batch])
    # get end ids
    end_ids = torch.stack([torch.as_tensor(item['end_ids']) for item in batch])
    # get idxs
    idxs = torch.stack([torch.as_tensor(item['idx']) for item in batch])
    # repackage
    sample = {'embeddings': embeddings,
              'start_ids': start_ids,
              'end_ids': end_ids,
              'idx': idxs}
    return sample
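For timing purposes, a minimal synthetic-data sketch along these lines could stand in for the MWE; the shapes (layers, tokens, features), dtype, and batch size below are placeholders, not taken from the real H5 data:

import timeit

import numpy as np
import torch

def make_item(idx, layers=4, tokens=128, features=768):
    # placeholder item mimicking one sample from the H5 dataset
    return {'embeddings': np.random.rand(layers, tokens, features).astype(np.float32),
            'start_ids': np.array([0]),
            'end_ids': np.array([tokens - 1]),
            'idx': idx}

batch = [make_item(i) for i in range(32)]

print('slow:', timeit.timeit(lambda: slow(batch), number=20))
print('fast:', timeit.timeit(lambda: fast(batch), number=20))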
Edit: I tried swapping to the version below; it is still about 10x slower than fast.
def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features], to:
    [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for item in batch:
        embeddings.append(item['embeddings'])
        start_ids.append(item['start_ids'])
        end_ids.append(item['end_ids'])
        idxs.append(item['idx'])
    # package data; swap to expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}
    return sample
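To see where the time actually goes inside slow, here is a rough profiling sketch (assuming the batch items hold NumPy arrays, as in the synthetic data above) that times the gather loop and the tensor conversion separately:

import time

import torch

def slow_profiled(batch):
    # time the Python loop that gathers the per-item arrays
    t0 = time.perf_counter()
    embeddings = []
    for item in batch:
        embeddings.append(item['embeddings'])
    t1 = time.perf_counter()
    # time the conversion of the list into a single tensor plus the permute
    stacked = torch.as_tensor(embeddings).permute(1, 0, 2, 3)
    t2 = time.perf_counter()
    print(f'append loop: {t1 - t0:.5f}s  as_tensor + permute: {t2 - t1:.5f}s')
    return stacked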
Answer
See this answer (and give it an upvote): https://stackoverflow.com/a/30245465/10475762
Particularly the line: “In other words and in general, list comprehensions perform faster because suspending and resuming a function’s frame, or multiple functions in other cases, is slower than creating a list on demand.”
So in your case, you are calling append multiple times per collate, and the collate function itself is called many times across your training/testing/evaluation steps, so the overhead adds up. In my opinion, avoid explicit for loops wherever a comprehension will do, as they tend to lead to slowdowns like this.
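To illustrate that point in isolation (purely synthetic numbers, nothing to do with tensors), you can compare the two patterns directly:

import timeit

data = list(range(100_000))

def with_append():
    # explicit loop calling append on every element
    out = []
    for x in data:
        out.append(x * 2)
    return out

def with_comprehension():
    # same work expressed as a list comprehension
    return [x * 2 for x in data]

print('append loop:   ', timeit.timeit(with_append, number=100))
print('comprehension: ', timeit.timeit(with_comprehension, number=100))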