I saw this line of code in an implementation of a BiLSTM:

```python
batch_output = batch_output[batch_mask, ...]
```

I assume this is some kind of "masking" operation, but I found little information on Google about the meaning of `...`. Please help :).
Original code:

```python
class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, pretrained=None):
        # irrelevant code ..........

    def forward(self, batch_input, batch_input_lens, batch_mask):
        batch_size, padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # size: batch * padding_length * embedding_dim
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input, batch_input_lens, batch_first=True)
        batch_output, self.hidden = self.lstm(batch_input, self.hidden)
        self.repackage_hidden(self.hidden)
        batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
        batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
        ####### HERE ##########
        batch_output = batch_output[batch_mask, ...]
        #########################
        out = self.hidden2tag(batch_output)
        return out
```
Answer
I assume that `batch_mask` is a boolean tensor. In that case, `batch_output[batch_mask]` performs boolean indexing, which selects the elements corresponding to `True` in `batch_mask`.
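A minimal sketch of boolean indexing on its own (the tensor and mask here are made up for illustration): a 1-D boolean mask over the first dimension keeps only the rows where the mask is `True`.

```python
import torch

# A 2-D tensor of shape (4, 3): four "rows" of three features each.
x = torch.arange(12).reshape(4, 3)

# Boolean mask over the first dimension: keep rows 0 and 2.
mask = torch.tensor([True, False, True, False])

selected = x[mask]  # shape (2, 3): only the rows where mask is True
print(selected.shape)  # torch.Size([2, 3])
```

Note that the result's first dimension equals the number of `True` entries in the mask, so the output shape is data-dependent.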
`...` is usually referred to as Ellipsis, and in PyTorch (as in other NumPy-like libraries) it is a shorthand that avoids repeating the colon operator (`:`) multiple times. For example, given a tensor `v` with `v.shape` equal to `(2, 3, 4)`, the expression `v[1, :, :]` can be rewritten as `v[1, ...]`.
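To make the shorthand concrete, here is a quick check of that equivalence (using the same `(2, 3, 4)` shape as above):

```python
import torch

v = torch.arange(24).reshape(2, 3, 4)

# The ellipsis expands to as many full slices (:) as needed,
# so v[1, ...] is the same as v[1, :, :].
print(torch.equal(v[1, :, :], v[1, ...]))  # True

# It also works from the other side: v[..., 0] is v[:, :, 0].
print(torch.equal(v[..., 0], v[:, :, 0]))  # True
```

This is why `...` is handy for code that should work regardless of how many trailing (or leading) dimensions a tensor has.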
I performed some tests, and using either `batch_output[batch_mask, ...]` or `batch_output[batch_mask]` seems to work identically:

```python
t = torch.arange(24).reshape(2, 3, 4)

# mask.shape == (2, 3)
mask = torch.tensor([[False, True, True],
                     [True, False, False]])

print(torch.all(t[mask] == t[mask, ...]))  # returns True
```
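In the BiLSTM above, `batch_mask` is presumably a flattened boolean mask marking the non-padding positions of each sequence. The snippet below is a hypothetical sketch of how such a mask could be built from per-sequence lengths (the names `lengths`, `padding_length`, and the shapes are assumptions, not taken from the original code):

```python
import torch

batch_size, padding_length, hidden = 2, 4, 3
lengths = torch.tensor([4, 2])  # hypothetical per-sequence lengths

# True where the position index is < the sequence length, i.e. real tokens.
mask_2d = torch.arange(padding_length).unsqueeze(0) < lengths.unsqueeze(1)
batch_mask = mask_2d.reshape(-1)  # flatten to (batch_size * padding_length,)

# Stand-in for the flattened LSTM output of shape (batch * padding, hidden).
batch_output = torch.randn(batch_size * padding_length, hidden)
kept = batch_output[batch_mask, ...]  # only non-padding rows survive
print(kept.shape)  # torch.Size([6, 3])
```

Filtering out the padding rows this way means the final `hidden2tag` projection (and any loss computed on its output) only ever sees real tokens.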