Torch: Why is this collate function so much faster than this other one?

I have developed two collate functions to read in data from h5py files (I tried to create some synthetic data for an MWE here, but it did not go to plan; a rough placeholder is sketched below, after the ‘fast’ function).

The difference between the two on my data is about 10x, which is a very large gap. I am unsure why, and I would like some insight for writing future collate functions.

import torch


def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features], to:
    [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for i in range(len(batch)):
        embeddings.append(batch[i]['embeddings'])
        start_ids.append(batch[i]['start_ids'])
        end_ids.append(batch[i]['end_ids'])
        idxs.append(batch[i]['idx'])
    # package data; swap to the expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}
    return sample

I thought the one below, with more loops, would be slower, but that is far from the case.

def fast(batch):
    ''' This function alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features]
    to: [layers, batch_sz, tokens, features] for the embeddings
    '''
    # turn data to tensors
    embeddings = torch.stack([torch.as_tensor(item['embeddings']) for item in batch])
    # swap to expected [layers, batch_sz, tokens, features]
    embeddings = embeddings.permute(1, 0, 2, 3)
    # get start ids
    start_ids = torch.stack([torch.as_tensor(item['start_ids']) for item in batch])
    # get end ids
    end_ids = torch.stack([torch.as_tensor(item['end_ids']) for item in batch])
    # get idxs
    idxs = torch.stack([torch.as_tensor(item['idx']) for item in batch])
    # repackage
    sample = {'embeddings': embeddings,
              'start_ids': start_ids,
              'end_ids': end_ids,
              'idx': idxs}
    return sample
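
For reference, a minimal synthetic setup along these lines (the shapes and the make_batch helper are placeholders, not my real h5py data, and it assumes torch plus the two functions above are already defined) can be used to time the two collate functions side by side:

import timeit

import numpy as np

def make_batch(batch_sz=32, layers=13, tokens=128, features=768):
    # Stand-in for what the H5 dataset emits: one dict per sample with an
    # embeddings array shaped [layers, tokens, features] plus a few ids.
    return [{'embeddings': np.random.rand(layers, tokens, features).astype(np.float32),
             'start_ids': np.array([0]),
             'end_ids': np.array([tokens - 1]),
             'idx': i}
            for i in range(batch_sz)]

batch = make_batch()
print('slow:', timeit.timeit(lambda: slow(batch), number=10))
print('fast:', timeit.timeit(lambda: fast(batch), number=10))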

Edit: I tried swapping to the version below; it is still about 10x slower than ‘fast’.

def slow(batch):
    '''
    This function retrieves the data emitted from the H5 torch data set.
    It alters the emitted dimensions from the dataloader
    from: [batch_sz, layers, tokens, features], to:
    [layers, batch_sz, tokens, features]
    '''
    embeddings = []
    start_ids = []
    end_ids = []
    idxs = []
    for item in batch:
        embeddings.append(item['embeddings'])
        start_ids.append(item['start_ids'])
        end_ids.append(item['end_ids'])
        idxs.append(item['idx'])
    # package data; swap to the expected [layers, batch_sz, tokens, features]
    sample = {'embeddings': torch.as_tensor(embeddings).permute(1, 0, 2, 3),
              'start_ids': torch.as_tensor(start_ids),
              'end_ids': torch.as_tensor(end_ids),
              'idx': torch.as_tensor(idxs)}
    return sample



Answer

See this answer (and give it an upvote): https://stackoverflow.com/a/30245465/10475762

Particularly the line: “In other words and in general, list comprehensions perform faster because suspending and resuming a function’s frame, or multiple functions in other cases, is slower than creating a list on demand.”

So in your case, you are calling append several times in every collate, and the collate function itself is called many times across your training/testing/evaluation steps, so it all adds up. IMO, avoid explicit for loops whenever you can, as they tend to lead to slowdowns like this.
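
As a rough illustration of that point (exact numbers will vary by machine), you can time an append loop against a list comprehension with something like this:

import timeit

data = list(range(10_000))

def with_append():
    out = []
    for x in data:
        out.append(x)
    return out

def with_comprehension():
    return [x for x in data]

# The comprehension builds the list in one expression, avoiding the repeated
# attribute lookup and method call that the append loop pays on every element.
print('append loop:  ', timeit.timeit(with_append, number=1_000))
print('comprehension:', timeit.timeit(with_comprehension, number=1_000))

In your collate function that per-element cost is paid across four lists for every sample in every batch, which is where it adds up.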
