I'm not sure how to describe this precisely, so the title may be misleading. I want to move columns of a 3D tensor t1 into another 3D tensor t2 according to an index mapping. There is a dictionary td, and each (k, v) pair in td means that the kth column of t1 (index k along dimension 2) becomes the vth column of t2.
Currently, I’m doing it this way:
for k, v in td.items():
    t2[:, :, v] = torch.select(t1, 2, k)
This works, but it is extremely slow because td contains millions of (k, v) pairs, so the loop performs millions of separate indexing operations. What is the best way to do this?
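To make the question reproducible, here is a minimal sketch of the loop above with made-up data. The sizes B, R, N and the random permutation used to build td are hypothetical stand-ins for the real tensors, which are far wider.

```python
import torch

# Hypothetical sizes; the real tensors have millions of columns along dim 2.
B, R, N = 2, 3, 1000
t1 = torch.randn(B, R, N)
t2 = torch.zeros(B, R, N)

# Hypothetical one-to-one mapping: column k of t1 -> column perm[k] of t2.
perm = torch.randperm(N)
td = {k: int(perm[k]) for k in range(N)}

# The original approach: one Python-level indexing op per (k, v) pair.
for k, v in td.items():
    t2[:, :, v] = torch.select(t1, 2, k)
```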
Answer
Assuming the values in td are unique (no two keys map to the same column of t2), you can replace the loop with a single indexed assignment:
t2[:,:,list(td.values())] = t1[:,:,list(td.keys())]
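Here is a minimal sketch of the same idea, assuming td's keys and values are ints and the values are unique. It converts the dict into long index tensors, which works because keys() and values() of the same dict iterate in the same order (the one-liner above relies on this too), and checks the result against the original loop. It also shows torch.Tensor.index_copy_ combined with index_select as an equivalent alternative. The helper name as_index_tensors and the test sizes are hypothetical.

```python
import torch

def as_index_tensors(td, device):
    # keys() and values() of the same dict iterate in the same order,
    # so src[i] -> dst[i] is exactly the (k, v) pairing stored in td.
    src = torch.tensor(list(td.keys()), dtype=torch.long, device=device)
    dst = torch.tensor(list(td.values()), dtype=torch.long, device=device)
    return src, dst

# Hypothetical test data: a random one-to-one column mapping.
B, R, N = 2, 3, 1000
t1 = torch.randn(B, R, N)
perm = torch.randperm(N)
td = {k: int(perm[k]) for k in range(N)}

# Reference result from the original per-column loop.
t2_loop = torch.zeros(B, R, N)
for k, v in td.items():
    t2_loop[:, :, v] = torch.select(t1, 2, k)

src, dst = as_index_tensors(td, t1.device)

# Variant 1: advanced indexing on dim 2 (same as the one-liner above).
t2_vec = torch.zeros(B, R, N)
t2_vec[:, :, dst] = t1[:, :, src]

# Variant 2: gather the source columns, then copy them into dim 2 of t2.
t2_idx = torch.zeros(B, R, N).index_copy_(2, dst, t1.index_select(2, src))

print(torch.equal(t2_vec, t2_loop), torch.equal(t2_idx, t2_loop))  # True True
```

Both variants replace millions of Python-level operations with one vectorized copy. If td does contain repeated values, neither variant guarantees which source column ends up in a shared destination slot, whereas the loop always keeps the last pair in iteration order.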