I am trying to extract the unique elements from a float tensor. I have tried:
```python
out = torch.unique(my_tensor)
```
However, this method seems to work only for int/long tensors. My tensor is quantized in a non-uniform way, so it is guaranteed to contain only a small set of float values.
Answer
You could use numpy.unique instead:
```python
import torch
import numpy as np

t = torch.tensor([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001])
# Convert the tensor to a NumPy array; np.unique returns the sorted distinct values
print(np.unique(t.numpy()))
```
Output:
```
[1.05 2.01 3.9 3.9001]
```
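np.unique compares values exactly, which is why 3.9 and 3.9001 are reported as two separate values. If nearby values should count as the same quantization level, one option is to round before deduplicating. The sketch below is a minimal illustration under that assumption: `unique_rounded` and its `decimals` parameter are hypothetical names, and the right precision depends on your own quantization grid. It works on a plain NumPy array, such as the one returned by `t.detach().cpu().numpy()` (the `.detach().cpu()` part is only needed for tensors that require grad or live on the GPU, since `.numpy()` refuses those). It also asks np.unique for `return_inverse` and `return_counts`, which map each element to its level and count how often each level occurs.

```python
import numpy as np


def unique_rounded(values: np.ndarray, decimals: int = 2):
    """Return the distinct values of `values` after rounding to `decimals` places.

    Hypothetical helper: `decimals` is an assumed tolerance for deciding when
    two quantization levels should be treated as the same value.
    """
    rounded = np.round(values, decimals)
    # return_inverse: index of the unique level for every input element
    # return_counts: number of input elements that fall into each level
    levels, inverse, counts = np.unique(rounded, return_inverse=True, return_counts=True)
    return levels, inverse, counts


# float32 mirrors the default dtype of torch.tensor for Python floats
values = np.array([1.05, 1.05, 2.01, 2.01, 3.9, 3.9001], dtype=np.float32)
levels, inverse, counts = unique_rounded(values, decimals=2)
print(levels)
print(inverse)
print(counts)
```

Rounding is a blunt tool: two values that straddle a rounding boundary can still end up in different bins, so it only works well when the quantization levels are clearly separated relative to `decimals`.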