Trouble implementing “concurrent” softmax function from paper (PyTorch)

I am trying to implement the so-called ‘concurrent’ softmax function given in the paper “Large-Scale Object Detection in the Wild from Imbalanced Multi-Labels”. Below is the definition of the concurrent softmax:

$$\sigma(z)_i = \frac{e^{z_i}}{\sum_j (1 - y_j)(1 - r_{ij})\, e^{z_j} + e^{z_i}}$$

where $y_j$ indicates whether class $j$ is a ground-truth label and $r_{ij}$ is the concurrent rate between classes $i$ and $j$.

NOTE: I have left the $(1 - r_{ij})$ term out for the time being, because I don’t think it applies to my problem given that my training dataset has a different type of labeling compared to the paper.

To keep it simple for myself, I am starting off by implementing it in a very inefficient but easy-to-follow way using for loops. However, the output I get seems wrong to me. Below is the code I am using:

import torch

# here is a one-hot encoded vector for the multi-label classification
# the image thus has 2 correct labels out of a possible 3 classes
y = [0, 1, 1]

# these are some made up logits that might come from the network.
vec = torch.tensor([0.2, 0.9, 0.7])

def concurrent_softmax(vec, y):
    for i in range(len(vec)):
        zi = torch.exp(vec[i])
        # sum over j of (1 - y[j]) * e^{z_j}: the (1 - y[j]) factor
        # zeroes out every class that is a ground-truth label
        sum_over_j = 0
        for j in range(len(y)):
            sum_over_j += (1 - y[j]) * torch.exp(vec[j])

        out = zi / (sum_over_j + zi)
        yield out

for result in concurrent_softmax(vec, y):
    print(result)

From this implementation I have realized that, no matter what value I give to the first logit in ‘vec’, I will always get an output of 0.5, because it essentially always calculates zi / (zi + zi). This seems like a major problem, because I would expect the value of the logits to have some influence on the resulting concurrent-softmax value. Is there a problem in my implementation, or is this behaviour of the function correct and there is something theoretical that I am not understanding?
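
To make this concrete, here is the calculation for i = 0 with y = [0, 1, 1]: the sum reduces to (1 − 0)·e^{z_0} + 0 + 0 = e^{z_0}, so the output is e^{z_0} / (e^{z_0} + e^{z_0}) = 0.5 for any value of z_0. For the logits above, the three printed values come out to roughly 0.5000, 0.6682 and 0.6225.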


Answer

This is the expected behaviour: for i = 0, every other class j has y[j] = 1, so the (1 - y[j]) factor removes their terms from the sum. The denominator collapses to e^{z_0} + e^{z_0}, and the ratio is 0.5 no matter what z_0 is.

Note you can simplify the summation with a dot product:

y = torch.tensor(y, dtype=torch.float)

def concurrent_softmax(z, y):
    # one dot product replaces the inner loop: (1 - y) zeroes out
    # the terms belonging to the ground-truth classes
    sum_over_j = torch.dot(1 - y, torch.exp(z))

    for zi in z:
        numerator = torch.exp(zi)
        denominator = sum_over_j + numerator
        yield numerator / denominator
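
For what it’s worth, the per-element loop can be vectorized away entirely, since the denominators only differ by each class’s own e^{z_i}, which broadcasting can add in one step. A minimal sketch (the name concurrent_softmax_vec is just illustrative):

import torch

def concurrent_softmax_vec(z, y):
    # e^{z_i} for every class at once
    numerator = torch.exp(z)
    # scalar: sum over j of (1 - y_j) * e^{z_j}
    sum_over_j = torch.dot(1 - y, numerator)
    # broadcasting adds each class's own e^{z_i} to the shared sum
    return numerator / (sum_over_j + numerator)

y = torch.tensor([0., 1., 1.])
vec = torch.tensor([0.2, 0.9, 0.7])
print(concurrent_softmax_vec(vec, y))
# prints roughly tensor([0.5000, 0.6682, 0.6225])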