
torch.unique does not work for float tensors

I am trying to extract the unique elements from a float tensor. I have tried:

out = torch.unique(my_tensor)

However, this only works for int/long tensors. My tensor is quantized in a non-uniform way, so it is guaranteed to contain only a small set of distinct float values.


Answer

You could use numpy.unique instead:

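import torch
import numpy as np

t = torch.tensor([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001])
# detach() and cpu() let this also work for tensors that require grad or live on the GPU
print(np.unique(t.detach().cpu().numpy()))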

Output:

[1.05   2.01   3.9    3.9001]
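np.unique compares values exactly, which is why 3.9 and 3.9001 come out as separate entries. If nearby values should count as the same quantization level, one option is to round to a fixed number of decimals first. Note that the 2-decimal tolerance below is a hypothetical choice and is not taken from the question. Passing return_inverse=True then maps every element to the index of its level. Here is a minimal sketch that uses a plain NumPy array in place of t.numpy():

```python
import numpy as np

# Stand-in for t.numpy(); float32 matches torch's default float dtype.
values = np.array([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001], dtype=np.float32)

# Exact comparison: 3.9 and 3.9001 remain two distinct values.
exact = np.unique(values)

# Hypothetical tolerance: round to 2 decimals so values that agree
# to 2 decimals collapse into a single level.
levels, codes = np.unique(np.round(values, decimals=2), return_inverse=True)

print(exact)   # 4 distinct values
print(levels)  # 3 levels after rounding
print(codes)   # level index for each element of `values`
```

With `levels` and `codes`, `levels[codes]` rebuilds the rounded array. This can be handy when each element of a quantized tensor needs to be replaced by the index of its level.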