I have written two collate functions to read data from h5py files (I tried to create some synthetic data for an MWE here, but it did not go to plan; a rough sketch of what I was aiming for is at the end of the question). The difference between the two when processing my data is about 10x, which is a very large gap. I am unsure why, and I would appreciate some insight for writing future collate functions.
import torch

def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features]
    to:   [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for i in range(len(batch)):
        embeddings.append(batch[i]['embeddings'])
        start_ids.append(batch[i]['start_ids'])
        end_ids.append(batch[i]['end_ids'])
        idxs.append(batch[i]['idx'])

    # package data; swap to expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}

    return sample
I thought the one below, with more loops, would be slower, but that is far from the case.
def fast(batch):
    '''
    This function alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features]
    to:   [layers, batch_sz, tokens, features]
    for the embeddings
    '''
    # turn data to tensors
    embeddings = torch.stack([torch.as_tensor(item['embeddings']) for item in batch])

    # swap to expected [layers, batch_sz, tokens, features]
    embeddings = embeddings.permute(1, 0, 2, 3)

    # get start ids
    start_ids = torch.stack([torch.as_tensor(item['start_ids']) for item in batch])

    # get end ids
    end_ids = torch.stack([torch.as_tensor(item['end_ids']) for item in batch])

    # get idxs
    idxs = torch.stack([torch.as_tensor(item['idx']) for item in batch])

    # repackage
    sample = {'embeddings': embeddings,
              'start_ids': start_ids,
              'end_ids': end_ids,
              'idx': idxs}

    return sample
Edit: I tried swapping to the version below, but it is still about 10x slower than fast.
def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features]
    to:   [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for item in batch:
        embeddings.append(item['embeddings'])
        start_ids.append(item['start_ids'])
        end_ids.append(item['end_ids'])
        idxs.append(item['idx'])

    # package data; swap to expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}

    return sample
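For reference, this is a rough sketch of the synthetic data I was aiming for in my MWE. The shapes, batch size, and the make_item helper are made up for illustration and assume the two collate functions above are defined:

import timeit
import numpy as np

# Hypothetical shapes: 13 layers, 32 tokens, 768 features per item.
def make_item(idx, layers=13, tokens=32, features=768):
    return {
        'embeddings': np.random.rand(layers, tokens, features).astype(np.float32),
        'start_ids': np.array([0]),
        'end_ids': np.array([tokens - 1]),
        'idx': idx,
    }

# Pretend the dataloader handed us a batch of 16 items.
batch = [make_item(i) for i in range(16)]

# Time both collate functions on the same synthetic batch.
print('slow:', timeit.timeit(lambda: slow(batch), number=20))
print('fast:', timeit.timeit(lambda: fast(batch), number=20))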
Answer
See this answer (and give it an upvote): https://stackoverflow.com/a/30245465/10475762
Particularly the line: “In other words and in general, list comprehensions perform faster because suspending and resuming a function’s frame, or multiple functions in other cases, is slower than creating a list on demand.”
So in your case, you are calling append several times in every collate call, and collate itself runs once per batch across all of your training/testing/evaluation steps, so that overhead adds up. IMO, avoid explicit for loops whenever you can; in my experience they somehow invariably lead to slowdowns.
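If it helps, here is a minimal sketch of that effect in isolation (pure Python, no tensors; the exact numbers will vary by machine, and this only measures the loop overhead, not the tensor conversion):

import timeit

data = list(range(10_000))

def with_append():
    # explicit loop: repeated out.append method calls
    out = []
    for x in data:
        out.append(x)
    return out

def with_comprehension():
    # list comprehension builds the same list without the per-item append call
    return [x for x in data]

print('append loop  :', timeit.timeit(with_append, number=1000))
print('comprehension:', timeit.timeit(with_comprehension, number=1000))

The gap per call is small, but a collate function runs once per batch, every epoch, so it compounds.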