Convert PyTorch tensor to Python list

How do I convert a PyTorch Tensor into a Python list?

I want to convert a tensor of size [1, 2048, 1, 1] into a list of 2048 elements. My tensor has floating point values. Is there a solution which also works with other data types such as int?

Answer

Use Tensor.tolist(), e.g.:

>>> import torch
>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803
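The question also asks about data types other than float. Here is a minimal sketch, assuming a reasonably recent PyTorch release. tolist() converts each element to the matching Python scalar type, so integer tensors produce int, bool tensors produce bool, and floating tensors of any precision produce Python float:

```python
import torch

# Integer tensor: elements become Python ints
i = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
int_list = i.tolist()

# Boolean tensor: elements become Python bools
b = torch.tensor([True, False])
bool_list = b.tolist()

# Half-precision tensor: elements become ordinary Python floats
h = torch.tensor([0.5, 1.5], dtype=torch.float16)
half_list = h.tolist()

print(int_list)   # [[1, 2], [3, 4]]
print(bool_list)  # [True, False]
print(half_list)  # [0.5, 1.5]
```

The values 0.5 and 1.5 were chosen because they are exactly representable in float16, so the printed floats match the inputs.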

To remove all dimensions of size 1, use a.squeeze().tolist(). For a tensor of shape [1, 2048, 1, 1], this gives a flat list of 2048 elements.
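The sketch below applies squeeze() to a random tensor with the shape from the question. It also shows one edge case to be aware of. If every dimension has size 1, squeeze() leaves a 0-d tensor, and tolist() then returns a bare number rather than a list. This matches the a[0,0].tolist() behaviour in the session above.

```python
import torch

# Same shape as in the question: [1, 2048, 1, 1]
a = torch.randn(1, 2048, 1, 1)

# squeeze() drops every dimension of size 1, leaving shape [2048]
flat = a.squeeze().tolist()
print(len(flat))  # 2048

# Edge case: when *every* dimension has size 1, squeeze() returns a 0-d
# tensor, so tolist() gives a plain Python float instead of a list
single = torch.randn(1, 1, 1, 1).squeeze().tolist()
print(type(single))  # <class 'float'>
```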

Alternatively, if all but one dimension have size 1 (or you want a single flat list of every element in the tensor), you can use a.flatten().tolist().
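Here is a short sketch of the flatten() route. flatten() always returns a 1-D tensor, with elements in row-major order. Because of that, tolist() still returns a list when the tensor holds a single element, and a tensor with several non-trivial dimensions has its rows concatenated:

```python
import torch

a = torch.randn(1, 2048, 1, 1)

# flatten() collapses all dimensions into one, so tolist() yields 2048 floats
flat = a.flatten().tolist()
print(len(flat))  # 2048

# Unlike squeeze(), flatten() keeps a 1-element list for a single-element tensor
one = torch.randn(1, 1, 1, 1).flatten().tolist()
print(len(one))  # 1

# With several non-trivial dimensions, rows are concatenated in order
m = torch.arange(6).reshape(2, 3)
print(m.flatten().tolist())  # [0, 1, 2, 3, 4, 5]
```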

User contributions licensed under: CC BY-SA