How do I convert a PyTorch Tensor into a Python list?
I want to convert a tensor of size [1, 2048, 1, 1] into a list of 2048 elements. My tensor has floating-point values. Is there a solution that also works with other data types, such as int?
Answer
Use Tensor.tolist(), e.g.:
```python
>>> import torch
>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803
```
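On the data-type part of the question: tolist() is not limited to floating-point tensors. It converts each element to the corresponding built-in Python type, so integer tensors give Python ints and boolean tensors give Python bools. A small sketch (the tensors below are made-up examples):

```python
import torch

# tolist() converts each element to the matching built-in Python type
ints = torch.tensor([[1, 2], [3, 4]])                    # dtype torch.int64
bools = torch.tensor([True, False])                      # dtype torch.bool
halves = torch.tensor([0.5, 1.5], dtype=torch.float16)   # low-precision floats

print(ints.tolist())    # [[1, 2], [3, 4]]
print(bools.tolist())   # [True, False]
print(halves.tolist())  # [0.5, 1.5]
```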
To remove all dimensions of size 1, use a.squeeze().tolist().
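Applied to the shape from the question, squeeze() drops the three size-1 dimensions and leaves a 1-D tensor of 2048 elements, so tolist() returns a flat list. This is a minimal sketch that assumes a random tensor as a stand-in for the real data (the name features is just illustrative):

```python
import torch

# Hypothetical stand-in for the question's tensor: shape [1, 2048, 1, 1]
features = torch.randn(1, 2048, 1, 1)

# squeeze() removes every size-1 dim, giving shape [2048]
values = features.squeeze().tolist()

print(len(values))       # 2048
print(type(values[0]))   # <class 'float'>
```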
Alternatively, if all but one dimension has size 1 (or if you want a flat list of every element of the tensor), you can use a.flatten().tolist().
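One practical difference between the two: if every dimension has size 1, squeeze() leaves a 0-d tensor and tolist() returns a bare Python number instead of a list, whereas flatten() always produces at least a 1-D tensor. flatten() also merges dimensions larger than 1 into a single list. A short sketch with made-up tensors:

```python
import torch

# A tensor whose dimensions are ALL of size 1
single = torch.tensor([[[[3.5]]]])  # shape [1, 1, 1, 1]

# squeeze() leaves a 0-d tensor, so tolist() returns a scalar, not a list
print(single.squeeze().tolist())   # 3.5

# flatten() keeps one dimension, so tolist() still returns a list
print(single.flatten().tolist())   # [3.5]

# flatten() also collapses dims larger than 1 into one flat list
grid = torch.arange(6).reshape(2, 3)
print(grid.flatten().tolist())     # [0, 1, 2, 3, 4, 5]
```

If your code always expects a list, flatten() is the safer choice.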