I am trying to extract the unique elements from a float tensor. I have tried:
out = torch.unique(my_tensor)
However, this method only works for int/long tensors. My tensor is quantized in a non-uniform way, so it is guaranteed to contain only a small set of float values.
Answer
You could use numpy.unique instead:
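import torch
import numpy as np
t = torch.tensor([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001])
# .detach().cpu() lets .numpy() also work on GPU tensors or tensors that require grad
print(np.unique(t.detach().cpu().numpy()))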
Output:
[1.05 2.01 3.9 3.9001]
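Note that np.unique compares values exactly, which is why 3.9 and 3.9001 show up as separate entries. If the quantization levels can carry small floating-point noise, one option is to round before calling np.unique. The sketch below assumes that rounding to a fixed number of decimals is acceptable for your levels; the helper name `quantization_levels` and the `decimals=3` default are hypothetical choices for illustration. It also uses the real `return_inverse` and `return_counts` parameters of np.unique, which give each element's index into the level table and how often each level occurs. The NumPy array here stands in for `t.detach().cpu().numpy()`, and the results can be turned back into tensors with `torch.from_numpy` if needed.

```python
import numpy as np

def quantization_levels(values, decimals=3):
    """Hypothetical helper: return the distinct levels of a quantized float
    array, the level index of each element, and the count of each level.

    Rounding to `decimals` places first is an assumption: it merges values
    that differ only by float noise (e.g. 3.9 vs 3.9001 at decimals=3).
    """
    rounded = np.round(values, decimals)
    levels, inverse, counts = np.unique(
        rounded, return_inverse=True, return_counts=True
    )
    return levels, inverse, counts

# Stand-in for t.detach().cpu().numpy(), using the same values as above
arr = np.array([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001], dtype=np.float32)

levels, inverse, counts = quantization_levels(arr, decimals=3)
print(levels)   # the distinct (rounded) levels
print(inverse)  # index into `levels` for every element of `arr`
print(counts)   # number of elements at each level

# levels[inverse] rebuilds the rounded array element by element
rebuilt = levels[inverse]
```

With the example values, exact comparison finds four distinct values, while rounding to three decimals collapses 3.9 and 3.9001 into a single level. Whether that is desirable depends on how far apart your real quantization levels are.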