What does Tensor[batch_mask, ...] do?

I saw this line of code in an implementation of BiLSTM:

batch_output = batch_output[batch_mask, ...]

I assume this is some kind of “masking” operation, but I found little information on Google about what the ... means here. Please help :)

Original Code:

import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, pretrained=None):
        # irrelevant code ..........

    def forward(self, batch_input, batch_input_lens, batch_mask):
        batch_size, padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # size: #batch * padding_length * embedding_dim
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input, batch_input_lens, batch_first=True)
        batch_output, self.hidden = self.lstm(batch_input, self.hidden)
        self.repackage_hidden(self.hidden)
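        # total_length=padding_length keeps the time dimension equal to the input's,
        # so the .view() below stays valid even if every sequence is shorter than padding_length
        batch_output, _ = rnn_utils.pad_packed_sequence(
            batch_output, batch_first=True, total_length=padding_length)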
        batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
        
        #######  HERE  ##########
        batch_output = batch_output[batch_mask, ...]
        #########################

        out = self.hidden2tag(batch_output)
        return out


Answer

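I assume that batch_mask is a boolean tensor. In that case, batch_output[batch_mask] performs boolean (mask) indexing: it keeps only the entries of batch_output where batch_mask is True. The mask's shape must match the leading dimension(s) of batch_output, and those dimensions are collapsed into a single one whose size is the number of True values. In the BiLSTM above, batch_output has shape (batch_size * padding_length, features) after the .view(), so if batch_mask has shape (batch_size * padding_length,), the result has shape (number of True entries, features). Presumably the mask marks the real (non-padding) tokens, so this line drops the padded positions before hidden2tag.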
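To make this concrete, here is a small sketch that mirrors the shapes in the forward pass. The sizes, the lengths tensor and the way batch_mask is built (True for real tokens, False for padding, flattened) are assumptions for illustration, since the question does not show how batch_mask is created:

```python
import torch

# Hypothetical sizes: 2 sequences padded to length 3, 4 output features.
batch_size, padding_length, features = 2, 3, 4
lengths = torch.tensor([3, 1])  # assumed true length of each sequence

# Stand-in for the padded LSTM output after .view(batch_size * padding_length, -1).
batch_output = torch.arange(batch_size * padding_length * features,
                            dtype=torch.float32).view(batch_size * padding_length, features)

# Assumed mask construction: position j of sequence i is real if j < lengths[i].
# Shape (2, 3), then flattened to (6,) to match the first dimension of batch_output.
batch_mask = (torch.arange(padding_length)[None, :] < lengths[:, None]).view(-1)

# Boolean indexing keeps only the rows whose mask entry is True (the real tokens).
selected = batch_output[batch_mask]
print(selected.shape)  # torch.Size([4, 4])
```

Here the first sequence contributes all 3 rows and the second only its first row, so the 2 padded rows disappear and hidden2tag only sees real tokens.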

... is usually referred to as an ellipsis. In PyTorch (as in NumPy and other NumPy-like libraries), it is shorthand for "as many full slices (:) as needed to cover the remaining dimensions", so you don't have to repeat the colon multiple times. For example, given a tensor v with v.shape equal to (2, 3, 4), the expression v[1, :, :] can be rewritten as v[1, ...].
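A quick sketch of the equivalences, using the same (2, 3, 4) shape as the example above; the ellipsis can appear at the end, at the start, or in the middle of an index:

```python
import torch

v = torch.arange(24).reshape(2, 3, 4)

# Trailing ...: take all remaining dimensions in full.
assert torch.equal(v[1, :, :], v[1, ...])

# Leading ...: fill in all dimensions before the explicit index.
assert torch.equal(v[:, :, 0], v[..., 0])

# ... in the middle expands to as many ':' as needed (here just one).
assert torch.equal(v[0, :, 2], v[0, ..., 2])

print(v[1, ...].shape)  # torch.Size([3, 4])
print(v[..., 0].shape)  # torch.Size([2, 3])
```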

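In my tests, batch_output[batch_mask, ...] and batch_output[batch_mask] behave identically. That is expected: when an index covers fewer dimensions than the tensor has, the remaining dimensions are taken in full anyway, so a trailing ... adds nothing: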

import torch

t = torch.arange(24).reshape(2, 3, 4)

# mask.shape == (2, 3), matching the first two dimensions of t
mask = torch.tensor([[False, True, True], [True, False, False]])

print(torch.all(t[mask] == t[mask, ...]))  # prints tensor(True)
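The ellipsis does make a difference when the mask is not meant for the leading dimensions. In the sketch below, last_mask is a hypothetical mask over the last dimension (size 4). Written as t[last_mask], it would be matched against dimension 0 (size 2) and raise an error; with a leading ..., it is applied to the last dimension instead:

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# Hypothetical mask over the LAST dimension of t (size 4).
last_mask = torch.tensor([True, False, True, False])

# The leading ... covers dims 0 and 1, so the mask selects along dim -1.
picked = t[..., last_mask]
assert torch.equal(picked, t[:, :, last_mask])
print(picked.shape)  # torch.Size([2, 3, 2])
```

So in the BiLSTM code the ... is harmless but redundant, because the mask is applied to the first dimension.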
User contributions licensed under: CC BY-SA