
Python argmax of dot product of weight matrix and vector (MNIST)

What does argmax mean in this context? I am following the tutorial in this colab notebook: https://colab.research.google.com/github/chokkan/deeplearningclass/blob/master/mnist.ipynb

for x, y in zip(Xtrain, Ytrain):
    y_pred = np.argmax(np.dot(W, x))    # argmax over the result of W · x

It looks like this says: for every record x and its true label y in Xtrain and Ytrain, take the max value of the dot product of the weight matrix W and the record x. Does this mean it takes the max of the weight matrix?

It also looks like a 1 was appended to each flattened vector:

def image_to_vector(X):
    X = np.reshape(X, (len(X), -1))     # Flatten: (N x 28 x 28) -> (N x 784)
    return np.c_[X, np.ones(len(X))]    # Append 1: (N x 784) -> (N x 785)

Xtrain = image_to_vector(data['train_x'])

Why would that be?

Thank you!


Answer

For simplicity, you can treat it as y = W * x + bias. The extra column of ones does not depend on the input, so the weights that multiply it (the last column of W) act as the bias.
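
A minimal sketch of that equivalence, assuming random stand-ins for the images, W and b (they are not trained MNIST values). It reuses the question's image_to_vector, stores the bias b as an extra last column of the weights, and checks that the augmented dot product gives the same scores as W * x + b.

```python
import numpy as np

# Random stand-ins (assumptions): not real MNIST images or trained weights.
rng = np.random.default_rng(0)

def image_to_vector(X):
    # Same function as in the question: flatten, then append a column of ones.
    X = np.reshape(X, (len(X), -1))        # (N x 28 x 28) -> (N x 784)
    return np.c_[X, np.ones(len(X))]       # (N x 784) -> (N x 785)

images = rng.random((3, 28, 28))           # three fake 28x28 "images"
W = rng.normal(size=(10, 784))             # one row of pixel weights per digit
b = rng.normal(size=10)                    # one bias per digit

X_aug = image_to_vector(images)            # (3, 785), last column all ones
W_aug = np.c_[W, b]                        # (10, 785), last column is b

for img, x_aug in zip(images, X_aug):
    explicit = np.dot(W, img.reshape(-1)) + b   # y = W * x + bias
    folded = np.dot(W_aug, x_aug)               # bias folded into the weights
    print(np.allclose(explicit, folded))        # True for every image
```

Folding the bias into W this way means a single matrix holds all 785 parameters per class, so the training loop only has to update one array.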

Here, the weight matrix W represents a fully connected layer with 785 (28*28 + 1) inputs and 10 outputs (7850 weights in total). The dot product of W and x is a vector of length 10 containing a score for each possible class (a digit from 0 to 9 in the MNIST case). Applying argmax gives the index of the highest score, not the score itself and not the maximum of W; that index is the prediction.
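
A short sketch of the shapes and of argmax versus max. W and x are random stand-ins (assumptions), and toy is a made-up score vector chosen so the difference between the two functions is easy to see.

```python
import numpy as np

# Random stand-ins (assumptions): shapes match the question, values do not.
rng = np.random.default_rng(0)
W = rng.normal(size=(10, 785))        # 10 outputs x (784 pixels + 1 bias input)
x = np.append(rng.random(784), 1.0)   # one flattened image with the appended 1

scores = np.dot(W, x)                 # one score per digit 0-9
print(W.size, scores.shape)           # 7850 (10,)

# argmax returns the position of the largest score, not the score itself,
# and it looks at the 10 scores, not at the entries of W.
toy = np.array([0.1, -1.2, 0.4, 2.7, 0.0, 1.9, -0.3, 0.8, 1.1, -2.0])
print(np.max(toy))                    # 2.7 (the highest score)
print(np.argmax(toy))                 # 3   (its index = predicted digit)

y_pred = np.argmax(scores)            # prediction for x, an int in 0..9
```

Because the class labels in MNIST are exactly the indices 0 to 9, the index returned by argmax can be compared directly with y.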

User contributions licensed under: CC BY-SA