Is there an efficient way to create a random bit mask in Pytorch?

I want to have a random bit mask that has some specified percent of 0s. The function I devised is:

import operator
from functools import reduce

import torch


def create_mask(shape, rate):
    """
    The idea is, you take a random permutation of numbers. You then mod it
    by [number of entries in the bitmask] / [percent of 0s you want]. The
    number of zeros will be exactly the rate of zeros needed. You can then
    clamp the values to get a bitmask.
    """
    numel = reduce(operator.mul, shape, 1)
    mask = torch.randperm(numel).float().cuda()
    # Mod it by numel / rate to get an even spread of 0s.
    mask = torch.fmod(mask, numel / rate)
    # Anything not zero should be put to 1
    mask = torch.clamp(mask, 0, 1)
    return mask.view(shape)

To illustrate:

>>> x = create_mask((10, 10), 10)
>>> x

    1     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     0     1     1     1
    0     1     1     1     1     0     1     1     1     1
    0     1     1     1     1     1     1     1     1     1
    1     1     1     1     1     1     1     1     1     0
    1     1     1     1     1     1     1     1     1     1
    1     1     1     0     1     1     1     0     1     1
    0     1     1     1     1     1     1     1     1     1
    1     1     1     0     1     1     0     1     1     1
    1     1     1     1     1     1     1     1     1     1
[torch.cuda.FloatTensor of size 10x10 (GPU 0)]
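The 10x10 illustration has exactly 10 zeros. A quick way to see what the permutation-and-`fmod` trick actually counts is a CPU-only copy of the function (hypothetically named `create_mask_cpu`, with `.cuda()` dropped so it runs without a GPU). The zeros are the permuted values that are exact multiples of `numel / rate`. When that quotient is a whole number there are exactly `rate` of them, so the fraction of zeros is `rate / numel`. That equals `rate` percent only for 100-entry masks like the 10x10 example.

```python
import operator
from functools import reduce

import torch


def create_mask_cpu(shape, rate):
    """Hypothetical CPU-only copy of create_mask (no .cuda())."""
    numel = reduce(operator.mul, shape, 1)
    # Random permutation of 0 .. numel - 1, as floats.
    mask = torch.randperm(numel).float()
    # Entries that are exact multiples of numel / rate become 0.
    mask = torch.fmod(mask, numel / rate)
    # When numel / rate is a whole number, every nonzero entry is >= 1,
    # so clamping maps it to 1.
    mask = torch.clamp(mask, 0, 1)
    return mask.view(shape)


for shape, rate in [((10, 10), 10), ((20, 20), 10)]:
    m = create_mask_cpu(shape, rate)
    zeros = int((m == 0).sum())
    print(shape, rate, zeros, zeros / m.numel())
```

With `((20, 20), 10)` the sketch still produces 10 zeros, which is 2.5% of the 400 entries rather than 10%.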

The main issue I have with this method is that it requires the rate to divide the number of entries in the mask. I want a function that accepts an arbitrary decimal and gives approximately rate percent of 0s in the bitmask. Furthermore, I am trying to find a relatively efficient way of doing so, so I would rather not build a NumPy array on the CPU and move it to the GPU. Is there an efficient way of doing this that allows for a decimal rate?
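One way to keep the exact zero count of the permutation idea while accepting a decimal rate is to round `rate * numel` to a whole number of zeros and use the first entries of a random permutation as the zero positions. This is a sketch under the assumption that `rate` is a fraction in [0, 1] (e.g. 0.1 for 10% zeros); the name `exact_rate_mask` and its `device` parameter are hypothetical. Everything is created on the device that is passed in, so no NumPy round trip is involved.

```python
import torch


def exact_rate_mask(shape, rate, device="cpu"):
    """Hypothetical helper: float mask with round(rate * numel) zeros.

    `rate` is assumed to be a fraction in [0, 1], e.g. 0.1 for 10% zeros.
    """
    if not 0.0 <= rate <= 1.0:
        raise ValueError("rate must be in [0, 1]")
    numel = 1
    for dim in shape:
        numel *= dim
    num_zeros = int(round(rate * numel))
    mask = torch.ones(numel, device=device)
    # The first num_zeros indices of a random permutation pick the zero slots.
    zero_idx = torch.randperm(numel, device=device)[:num_zeros]
    mask[zero_idx] = 0
    return mask.view(shape)


m = exact_rate_mask((10, 10), 0.37)
print(int((m == 0).sum()))  # prints 37
```

On a CUDA machine, passing `device="cuda"` builds both the permutation and the mask on the GPU.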

Answer

For anyone running into this, the following creates a bitmask with approximately 80% zeros directly on the GPU (PyTorch 0.3):

torch.cuda.FloatTensor(10, 10).uniform_() > 0.8
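The same thresholding idea can be written against the current tensor API. `torch.rand` samples uniform values in [0, 1) directly on the requested device, and `u >= rate` holds with probability `1 - rate`, so roughly a fraction `rate` of the entries end up 0. In recent PyTorch versions the comparison returns a bool tensor (a ByteTensor in 0.3), hence the `.float()`. The helper name `random_mask` is hypothetical. Unlike the permutation approach, the number of zeros here is random, not exact.

```python
import torch


def random_mask(shape, rate, device="cpu"):
    """Hypothetical helper: each entry is 0 with probability `rate`, else 1."""
    # Uniform samples in [0, 1), created directly on `device`.
    u = torch.rand(shape, device=device)
    # u >= rate holds with probability 1 - rate, so roughly `rate` of entries are 0.
    return (u >= rate).float()


torch.manual_seed(0)
m = random_mask((1000, 1000), 0.8)
print((m == 0).float().mean().item())  # roughly 0.8
```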
User contributions licensed under: CC BY-SA