I’m trying to implement a calculation, but I can’t figure out how to vectorize my code and avoid loops.
Let me explain: I have a matrix M[N,C] of either 0 or 1, another matrix Y[N,1] containing values in [0, C-1] (my classes), and another matrix ds[N,M], which is my dataset.
My output matrix is of size grad[M,C] and should be calculated as follows (I’ll explain for grad[:,0]; the same logic applies to any other column). For each row (sample) in ds: if Y[that sample] != 0 (0 being the current column of the output matrix) and M[that sample, 0] > 0, then grad[:,0] += ds[that sample]; if Y[that sample] == 0, then grad[:,0] -= ds[that sample] * <num of non-zeros in M[that sample,:]>.
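Equivalently (my own paraphrase of the rule above, using the fact that M is 0/1, so M[j, i] doubles as the indicator for M[j, i] > 0), each column i is

grad[:, i] += \sum_{j \,:\, Y[j] \neq i} M[j, i] \cdot ds[j, :] \;-\; \sum_{j \,:\, Y[j] = i} \Big(\sum_{k} M[j, k]\Big) \cdot ds[j, :]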
Here is my iterative approach:
for i in range(M.size(dim=1)):
    for j in range(ds.size(dim=0)):
        if y[j] == i:
            grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
        else:
            if M[j,i] > 0:
                grad[:,i] = grad[:,i] + ds[j,:].T
Answer
Since you are dealing with three dimensions n, m, and c (in lowercase to avoid ambiguity), it can be useful to change the shape of all your tensors to (n, m, c) by replicating their values over the missing dimension (e.g. M of shape (n, c) becomes shape (n, m, c)).
However, you can skip the explicit replication and use broadcasting, so it is sufficient to unsqueeze the missing dimension (e.g. M of shape (n, c) becomes shape (n, 1, c)), and the size-1 dimension is expanded on the fly.
Given these considerations, your code can be vectorized as follows:
cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)       # (n, 1, c): cond[j, 0, i] is True iff y[j] == i
pos = ds.unsqueeze(2) * M.unsqueeze(1) * ~cond                          # samples with y[j] != i and M[j, i] > 0
neg = ds.unsqueeze(2) * M.unsqueeze(1).sum(dim=2, keepdim=True) * cond  # samples with y[j] == i, scaled by sum(M[j, :])
grad += (pos - neg).sum(dim=0)                                          # collapse the sample dimension to (m, c)
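To see why this matches your loop: ~cond selects the samples handled by the else branch (y[j] != i), and since M contains only zeros and ones, multiplying by M.unsqueeze(1) is equivalent to the M[j, i] > 0 check. On the other side, M.unsqueeze(1).sum(dim=2, keepdim=True) reproduces sum(M[j, :]) for each sample j, and cond restricts that term to the samples with y[j] == i.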
Here is a small test to check the validity of the solution:
import torch

n, m, c = 11, 5, 7
y = torch.randint(c, size=(n, 1))
ds = torch.rand(n, m)
M = torch.randint(2, size=(n, c))
grad = torch.rand(m, c)

def slow_grad(y, ds, M, grad):
    for i in range(M.size(dim=1)):
        for j in range(ds.size(dim=0)):
            if y[j] == i:
                grad[:,i] = grad[:,i] - (ds[j,:].T * sum(M[j,:]))
            else:
                if M[j,i] > 0:
                    grad[:,i] = grad[:,i] + ds[j,:].T
    return grad

def fast_grad(y, ds, M, grad):
    cond = y.unsqueeze(2) == torch.arange(M.size(dim=1)).unsqueeze(0)
    pos = ds.unsqueeze(2) * M.unsqueeze(1) * ~cond
    neg = ds.unsqueeze(2) * M.unsqueeze(1).sum(dim=2, keepdim=True) * cond
    grad += (pos - neg).sum(dim=0)
    return grad

# Both functions modify grad in place, so each one gets its own clone.
# torch.allclose tolerates the tiny rounding differences that the two
# summation orders can produce; the assert throws if the results differ.
assert torch.allclose(slow_grad(y, ds, M, grad.clone()), fast_grad(y, ds, M, grad.clone()))
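If you also want a rough sense of the speedup, here is a minimal timing sketch (it reuses slow_grad and fast_grad from the test above; the sizes are made up for illustration):

import time
import torch

n, m, c = 2000, 128, 10
y = torch.randint(c, size=(n, 1))
ds = torch.rand(n, m)
M = torch.randint(2, size=(n, c))
grad = torch.rand(m, c)

t0 = time.perf_counter()
slow_grad(y, ds, M, grad.clone())
t1 = time.perf_counter()
fast_grad(y, ds, M, grad.clone())
t2 = time.perf_counter()
print(f"loops: {t1 - t0:.3f}s  vectorized: {t2 - t1:.3f}s")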
Feel free to test on other cases as well!