Why does this custom function cost so much time during the backward pass in PyTorch?

I’m revising a baseline method in PyTorch, but when I add a custom function to the training phase, the backward pass takes about 4x longer on a single V100. Here is an example of the custom function:

import torch

def batch_function(M, kernel_size=21, sf=2):
    '''
    Input:
        M: b x (h*w) x 2 x 2 torch tensor
        kernel_size: side length k of the output kernel
        sf: scale factor
    Output:
        kernel: b x (h*w) x k x k torch tensor
    '''

    M_t = M.permute(0, 1, 3, 2)  # b x (h*w) x 2 x 2
    INV_SIGMA = torch.matmul(M_t, M).unsqueeze(2).unsqueeze(2)  # b x (h*w) x 1 x 1 x 2 x 2

    X, Y = torch.meshgrid(torch.arange(kernel_size), torch.arange(kernel_size))
    Z = torch.stack((Y, X), dim=2).unsqueeze(3).to(M.device, M.dtype)  # k x k x 2 x 1 (cast to M's dtype so the matmul below works)

    Z = Z.unsqueeze(0).unsqueeze(0)    # 1 x 1 x k x k x 2 x 1
    Z_t = Z.permute(0, 1, 2, 3, 5, 4)  # 1 x 1 x k x k x 1 x 2
    raw_kernel = torch.exp(-0.5 * torch.squeeze(Z_t.matmul(INV_SIGMA).matmul(Z)))  # b x (h*w) x k x k

    # Normalize so each k x k kernel sums to 1
    kernel = raw_kernel / torch.sum(raw_kernel, dim=(2, 3)).unsqueeze(-1).unsqueeze(-1)  # b x (h*w) x k x k

    return kernel

where b is the batch size, 16; h and w are the spatial dimensions, 100 each; and k is equal to 21. I’m not sure whether the large size of M is what makes the backward pass take so long. Why does it take so much longer? And are there other ways to rewrite this code to make it faster? I’m new here, so if the problem is not clearly described, please let me know!
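For reference, here is a minimal driver matching the shapes above (the loss is just a placeholder, not my real training objective):

import time
import torch

M = torch.randn(16, 100 * 100, 2, 2, device='cuda', requires_grad=True)

kernel = batch_function(M, kernel_size=21)   # 16 x 10000 x 21 x 21
loss = kernel.sum()                          # placeholder loss

torch.cuda.synchronize()
start = time.time()
loss.backward()                              # this is the step that became ~4x slower
torch.cuda.synchronize()
print(f'backward time: {time.time() - start:.3f} s')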


Answer

You might be able to get a performance boost on the double tensor multiplication by using torch.einsum:

>>> o = torch.einsum('acdefg,bshigj,kldejm->bsdefm', ZZ_t, INV_SIGMA, ZZ)

The resulting tensor o will be shaped (b, h*w, k, k, 1, 1). Here ZZ_t, INV_SIGMA, and ZZ are your Z_t, INV_SIGMA, and Z (after the unsqueeze calls).
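Since o keeps the two trailing singleton dimensions, you would squeeze them before the exponential and normalization, something along these lines:

raw_kernel = torch.exp(-0.5 * o.squeeze(-1).squeeze(-1))  # b x (h*w) x k x k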


For details on the subscript notation:

  • b: batch dimension.
  • s: ‘s’ for spatial, i.e. the h*w dimension.
  • d and e: the two k dimensions which are paired across ZZ_t and ZZ.

A simple 2D matrix multiplication applied with matmul corresponds to the subscript pattern ij,jk->ik.
Keeping that in mind, in your case we have:

  • A first matrix multiplication: r = ZZ_t@INV_SIGMA,
    which does something like *fg,*gj->*fj,
    where the asterisk * stands for the leading dimensions.

  • A second matrix multiplication: r@ZZ,
    which comes down to *fj,*jm->*fm.

Overall, if we combine both, we get directly: *fg,*gj,*jm->*fm.
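As a quick sanity check of the combined pattern (leaving out the leading dimensions), a three-operand einsum matches the chained matmuls on small random matrices:

import torch

A = torch.randn(1, 2)   # plays the role of *fg
B = torch.randn(2, 2)   # plays the role of *gj
C = torch.randn(2, 1)   # plays the role of *jm
assert torch.allclose(torch.einsum('fg,gj,jm->fm', A, B, C), A @ B @ C)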

Finally, I have assigned the remaining dimensions (all of size 1) random but distinct subscript letters:

a, c, h, i, k, l

Since each of these dimensions has size 1, summing over them is trivial and they simply drop out of the result.

Replacing the asterisk above with those notations, we get the following subscript input:

#  *  fg, *  gj, *  jm-> *  fm
# acdefg,bshigj,kldejm->bsdefm
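
Putting it all together, here is a sketch of your function with the double matmul replaced by the single einsum call (variable names as in your snippet, everything else unchanged):

import torch

def batch_function_einsum(M, kernel_size=21, sf=2):
    M_t = M.permute(0, 1, 3, 2)                                  # b x (h*w) x 2 x 2
    INV_SIGMA = torch.matmul(M_t, M).unsqueeze(2).unsqueeze(2)   # b x (h*w) x 1 x 1 x 2 x 2

    X, Y = torch.meshgrid(torch.arange(kernel_size), torch.arange(kernel_size))
    Z = torch.stack((Y, X), dim=2).unsqueeze(3).to(M.device, M.dtype)  # k x k x 2 x 1

    Z = Z.unsqueeze(0).unsqueeze(0)      # 1 x 1 x k x k x 2 x 1
    Z_t = Z.permute(0, 1, 2, 3, 5, 4)    # 1 x 1 x k x k x 1 x 2

    # single einsum instead of Z_t.matmul(INV_SIGMA).matmul(Z)
    o = torch.einsum('acdefg,bshigj,kldejm->bsdefm', Z_t, INV_SIGMA, Z)  # b x (h*w) x k x k x 1 x 1
    raw_kernel = torch.exp(-0.5 * o.squeeze(-1).squeeze(-1))             # b x (h*w) x k x k

    kernel = raw_kernel / torch.sum(raw_kernel, dim=(2, 3)).unsqueeze(-1).unsqueeze(-1)
    return kernel

# quick check against the original on a small input
M = torch.randn(2, 16, 2, 2)
assert torch.allclose(batch_function_einsum(M), batch_function(M))

Whether this actually speeds up the backward pass is something you would have to benchmark on your setup.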