
Reorder columns in a tensor according to a dictionary

I don’t know how to explain this precisely, so the title might be misleading. I want to move columns (slices along dimension 2) of a 3D tensor t1 into another 3D tensor t2 according to a set of indices. There is a dictionary td in which each (k, v) pair means that the kth column of t1 becomes the vth column of t2.

Currently, I’m doing it this way:

import torch
# Copy column k (along dim 2) of t1 into column v of t2, one pair at a time.
for k, v in td.items():
    t2[:, :, v] = torch.select(t1, 2, k)

But it’s very slow, since td has millions of entries. What is the fastest way to do this?
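To make the mapping concrete, here is a toy reproduction of the loop. The shapes and the mapping `td` are made up for illustration, and `t1[:, :, k]` is used in place of `torch.select(t1, 2, k)`, which returns the same slice.

```python
import torch

# Toy input (hypothetical): shape (2, 3, 4), so dim 2 has 4 "columns".
t1 = torch.arange(24).reshape(2, 3, 4)
t2 = torch.zeros_like(t1)

# Hypothetical mapping: column k of t1 becomes column v of t2.
td = {0: 3, 1: 0, 2: 2, 3: 1}

# The per-pair loop: one Python-level copy per dictionary entry.
for k, v in td.items():
    t2[:, :, v] = t1[:, :, k]

print(t1[0, 0])  # tensor([0, 1, 2, 3])
print(t2[0, 0])  # tensor([1, 3, 2, 0])
```

Each iteration launches a separate copy from Python, which is why the loop gets slow when td has millions of entries.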


Answer

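Assuming td has no repeated values (no two source columns map to the same destination column), you can replace the loop with a single advanced-indexing assignment: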

# dict.keys() and dict.values() iterate in matching order, so the pairs line up.
t2[:, :, list(td.values())] = t1[:, :, list(td.keys())]
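A minimal sketch of this approach, assuming `td` maps source columns (dim 2 of `t1`) to distinct destination columns of `t2`. Converting the keys and values to `torch.long` index tensors once, instead of passing Python lists, is a choice made here rather than part of the answer. `Tensor.index_copy_` is shown as an equivalent alternative. The sizes and the random mapping are hypothetical.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: t1 has 1000 columns along dim 2, t2 has 800.
t1 = torch.randn(4, 5, 1000)
n_out = 800

# Hypothetical mapping with distinct values: source column k -> destination column v.
src_cols = torch.randperm(1000)[:n_out]
dst_cols = torch.randperm(n_out)
td = dict(zip(src_cols.tolist(), dst_cols.tolist()))

# Build the index tensors once; keys[i] pairs with vals[i].
keys = torch.tensor(list(td.keys()), dtype=torch.long)
vals = torch.tensor(list(td.values()), dtype=torch.long)

# 1) Advanced indexing, as in the answer.
t2_a = torch.zeros(4, 5, n_out)
t2_a[:, :, vals] = t1[:, :, keys]

# 2) Equivalent: select the source columns, then copy them into dim 2 at vals.
t2_b = torch.zeros(4, 5, n_out)
t2_b.index_copy_(2, vals, t1.index_select(2, keys))

# Reference: the original per-pair loop.
t2_ref = torch.zeros(4, 5, n_out)
for k, v in td.items():
    t2_ref[:, :, v] = t1[:, :, k]

print(torch.equal(t2_a, t2_ref), torch.equal(t2_b, t2_ref))  # True True
```

If two keys map to the same value, which source column ends up in that destination column is not guaranteed for either form, which is why the answer assumes no repeated values.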
User contributions licensed under: CC BY-SA