My Task:
I’m trying to calculate the pairwise distance between every two samples in two big tensors (for k-nearest neighbours). That is, given a tensor test with shape (b1,c,h,w) and a tensor train with shape (b2,c,h,w), I need || test[i]-train[j] || for every i,j, where both test[i] and train[j] have shape (c,h,w), as those are samples in the batch.
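For tensors small enough to fit in memory, the quantity described above is exactly what torch.cdist computes once each sample is flattened into a vector. A minimal check on random toy data (the sizes are arbitrary assumptions, not the real data):

```python
import torch

# Toy sizes, assumed only for illustration; the real tensors are much larger.
b1, b2, c, h, w = 4, 6, 3, 8, 8
test = torch.randn(b1, c, h, w)
train = torch.randn(b2, c, h, w)

# flatten(1) keeps the batch dimension and merges (c, h, w) into one axis,
# so every sample becomes a row vector of length c*h*w.
dists = torch.cdist(test.flatten(1), train.flatten(1))  # shape (b1, b2)

# Explicit double loop: torch.linalg.norm with no dim/ord flattens its input
# and returns the 2-norm, i.e. || test[i] - train[j] ||.
explicit = torch.stack([
    torch.stack([torch.linalg.norm(test[i] - train[j]) for j in range(b2)])
    for i in range(b1)
])

print(dists.shape)  # torch.Size([4, 6])
print(torch.allclose(dists, explicit, atol=1e-4))
```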
The Problem
Both train and test are very big, so I can’t fit them into RAM.
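A quick back-of-the-envelope estimate makes the constraint concrete. It assumes float32 storage (4 bytes per element), and the sizes below are purely hypothetical; it also shows that the full (b1, b2) distance matrix can be as large as the inputs themselves:

```python
import math


def float32_bytes(shape):
    """Bytes needed for a dense float32 tensor of the given shape (4 bytes per element)."""
    return 4 * math.prod(shape)


def gib(n_bytes):
    """Convert a byte count to GiB (2**30 bytes)."""
    return n_bytes / 2**30


# Hypothetical sizes, chosen only to illustrate the arithmetic.
b1, b2, c, h, w = 10_000, 50_000, 3, 64, 64

print(f"test : {gib(float32_bytes((b1, c, h, w))):.2f} GiB")
print(f"train: {gib(float32_bytes((b2, c, h, w))):.2f} GiB")
# Keeping every pairwise distance needs a (b1, b2) matrix on top of that.
print(f"dists: {gib(float32_bytes((b1, b2))):.2f} GiB")
```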
My current solution
For a start, I did not construct these tensors in one go. As I build them, I split the data tensor and save the pieces separately to disk, so I end up with files {Testtest_1,...,Testtest_n} and {Traintrain_1,...,Traintrain_m}.
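The chunking step can be sketched with torch.split and torch.save. save_in_chunks is a hypothetical helper and the .pt suffix is an assumption. Each chunk is cloned before saving because torch.save on a view serializes the view's entire underlying storage, which would defeat the purpose of splitting:

```python
import os
import tempfile

import torch


def save_in_chunks(tensor, prefix, chunk_size, out_dir):
    """Hypothetical helper: split `tensor` along dim 0 and save each piece to its own file.

    Returns the list of file paths, in order."""
    paths = []
    for idx, chunk in enumerate(torch.split(tensor, chunk_size, dim=0)):
        path = os.path.join(out_dir, f"{prefix}_{idx}.pt")
        # clone() so only this chunk's data is written, not the whole parent storage
        torch.save(chunk.clone(), path)
        paths.append(path)
    return paths


out_dir = tempfile.mkdtemp()
test = torch.randn(10, 3, 4, 4)
test_files = save_in_chunks(test, "Testtest", chunk_size=4, out_dir=out_dir)
print(len(test_files))  # 3 chunks with 4, 4 and 2 samples

# Reloading and concatenating the chunks recovers the original tensor.
restored = torch.cat([torch.load(p) for p in test_files])
print(torch.equal(restored, test))
```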
Then, in a nested for loop, I load every Testtest_i and Traintrain_j, calculate the current distance, and save it.
This semi-pseudo-code might explain it:

```python
import torch

test_files = [f'Testtest_{i}' for i in range(n)]
train_files = [f'Traintrain_{j}' for j in range(m)]

# Distance between every flattened sample of t1 and every flattened sample of t2
dist = lambda t1, t2: torch.cdist(t1.flatten(1), t2.flatten(1))

all_distances = []
for test_file in test_files:
    test_i = torch.load(test_file)  # shape (b_i, c, h, w): one chunk of test samples
    dist_of_i_from_all_j = []
    for train_file in train_files:
        train_j = torch.load(train_file)  # shape (b_j, c, h, w): one chunk of train samples
        dist_of_i_from_all_j.append(dist(test_i, train_j))  # shape (b_i, b_j)
    all_distances.append(torch.cat(dist_of_i_from_all_j, dim=1))  # shape (b_i, b2)

# and now I can take the k smallest from all_distances
```
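One way to avoid keeping every distance (all_distances in the loop above eventually holds b1 × b2 entries) is to keep, for each test sample, only the k best candidates seen so far and merge them with each new train chunk using torch.topk(..., largest=False). This is a swapped-in streaming top-k technique, not something from the question, and chunked_knn is a hypothetical name. The sketch assumes one test chunk and one train chunk fit in memory at a time; in the real setting the chunks would come from torch.load:

```python
import torch


def chunked_knn(test_chunks, train_chunks, k):
    """Sketch: k nearest train samples for every test sample, keeping only a running top-k.

    test_chunks / train_chunks hold tensors shaped (b, c, h, w). train_chunks must be
    re-iterable (e.g. a list) because it is scanned once per test chunk.
    Returns (distances, indices), each shaped (total_test, k); indices are global train rows.
    """
    all_d, all_i = [], []
    for test_i in test_chunks:
        q = test_i.flatten(1)
        best_d = torch.full((q.shape[0], 0), float("inf"))
        best_i = torch.empty((q.shape[0], 0), dtype=torch.long)
        offset = 0  # global row index of the first sample in the current train chunk
        for train_j in train_chunks:
            x = train_j.flatten(1)
            d = torch.cdist(q, x)  # (b_test, b_train) distances for this pair of chunks
            idx = torch.arange(offset, offset + x.shape[0]).expand(q.shape[0], -1)
            # Merge new candidates with the current best and keep the k smallest.
            cand_d = torch.cat([best_d, d], dim=1)
            cand_i = torch.cat([best_i, idx], dim=1)
            kk = min(k, cand_d.shape[1])
            best_d, pos = torch.topk(cand_d, kk, dim=1, largest=False)
            best_i = torch.gather(cand_i, 1, pos)
            offset += x.shape[0]
        all_d.append(best_d)
        all_i.append(best_i)
    return torch.cat(all_d), torch.cat(all_i)


# Compare against a brute-force top-k on toy data that does fit in memory.
torch.manual_seed(0)
test = torch.randn(7, 3, 4, 4)
train = torch.randn(20, 3, 4, 4)
D, I = chunked_knn(list(torch.split(test, 3)), list(torch.split(train, 6)), k=5)
ref_D, ref_I = torch.topk(torch.cdist(test.flatten(1), train.flatten(1)), 5, dim=1, largest=False)
print(torch.allclose(D, ref_D), torch.equal(I, ref_I))
```

Memory then grows with b1 × k instead of b1 × b2, at the cost of re-reading the train files once per test chunk.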
What I thought might work
I came across the FAISS repository, which explains that this process can be sped up (maybe?) using their solutions, though I’m not quite sure how. Regardless, any approach would help!
Answer
Did you check the FAISS documentation?
If what you need is the L2 norm (torch.cdist uses p=2 as its default parameter), then it is quite straightforward. The code below is an adaptation of the FAISS docs to your example:
```python
import faiss
import numpy as np

d = 64       # dimension (in your case c*h*w, with each sample flattened to a row)
nb = 100000  # database size: the train samples
nq = 10000   # number of queries: the test samples
np.random.seed(1234)  # make reproducible

x_train = np.random.random((nb, d)).astype('float32')
x_train[:, 0] += np.arange(nb) / 1000.
x_test = np.random.random((nq, d)).astype('float32')
x_test[:, 0] += np.arange(nq) / 1000.

index = faiss.IndexFlatL2(d)  # build the index
print(index.is_trained)
index.add(x_train)            # add the train vectors to the index
print(index.ntotal)

k = 100                         # take the 100 closest neighbors
D, I = index.search(x_test, k)  # actual search: neighbors of each test vector
print(I[:5])                    # neighbors of the 5 first queries
print(I[-5:])                   # neighbors of the 5 last queries
```
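One detail worth knowing when comparing results: IndexFlatL2 reports squared Euclidean distances in D, whereas torch.cdist returns the distance itself, so take the square root of D before comparing the two. The brute-force NumPy sketch below (knn_squared_l2 is a hypothetical name, not part of FAISS, and it does not reproduce FAISS internals) returns the same kind of (D, I) pair using the expansion ||q - x||² = ||q||² - 2 q·x + ||x||², just to make that convention concrete:

```python
import numpy as np


def knn_squared_l2(x_train, x_test, k):
    """Hypothetical brute-force stand-in for index.add(x_train); index.search(x_test, k).

    Returns (D, I): squared L2 distances and row indices into x_train, nearest first."""
    # ||q - x||^2 = ||q||^2 - 2 q.x + ||x||^2, computed for all pairs at once
    sq = (
        (x_test ** 2).sum(axis=1)[:, None]
        - 2.0 * x_test @ x_train.T
        + (x_train ** 2).sum(axis=1)[None, :]
    )
    sq = np.maximum(sq, 0.0)  # clip tiny negatives caused by floating-point rounding
    I = np.argsort(sq, axis=1)[:, :k]
    D = np.take_along_axis(sq, I, axis=1)
    return D, I


rng = np.random.default_rng(0)
x_train = rng.random((200, 16), dtype=np.float32)
x_test = rng.random((5, 16), dtype=np.float32)
D, I = knn_squared_l2(x_train, x_test, k=3)

# sqrt(D) matches the plain (unsquared) distances of the 3 nearest train rows.
explicit = np.linalg.norm(x_test[:, None, :] - x_train[None, :, :], axis=2)  # (5, 200)
print(np.allclose(np.sqrt(D), np.sort(explicit, axis=1)[:, :3], atol=1e-4))
```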