TensorFlow Keras Tensor Multiplication with None as First Dimension

I’m using the TensorFlow Keras backend and I have two tensors, a and b, of the same shape: (None, 4, 7), where None represents the batch dimension.

I want to do matrix multiplication, and I’m expecting a result of shape (None, 4, 4),
i.e. for each element of the batch, do one matmul: (4,7)·(7,4) = (4,4)

Here’s my code:

K.dot(a, K.reshape(b, (-1, 7, 4)))

This code gives a tensor of shape (None, 4, None, 4).

How does matrix multiplication work for higher-dimensional tensors, and what is the right way to do this?

Answer

IIUC, you can either call tf.matmul directly as part of your model, passing transpose_b=True so that b is transposed over its last two axes, or explicitly wrap the same operation in a Lambda layer:

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.matmul(a, b, transpose_b=True)
model = tf.keras.Model([a, b], output)
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_15 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 input_16 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 tf.linalg.matmul_2 (TFOpLambda  (None, 4, 4)        0           ['input_15[0][0]',               
 )                                                                'input_16[0][0]']               
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________

Or, using a Lambda layer:

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([a, b])
model = tf.keras.Model([a, b], output)
model.summary()
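
As for why the original K.dot call returned (None, 4, None, 4): the sketch below reproduces the shapes with NumPy instead of TensorFlow. It assumes that np.matmul and np.dot follow the same shape rules as tf.matmul and K.dot for 3-D inputs, and it uses a made-up batch size of 2 in place of None. A dot-style product contracts the last axis of a with the second-to-last axis of b and keeps every other axis, so both batch axes survive; a batched matmul pairs the batch elements up instead. It also shows that reshaping b to (-1, 7, 4) is not the same as transposing it.

```python
import numpy as np

batch = 2  # hypothetical batch size standing in for None

# Two example tensors of shape (batch, 4, 7) with distinct values.
a = np.arange(batch * 4 * 7, dtype=np.float64).reshape(batch, 4, 7)
b = np.arange(batch * 4 * 7, dtype=np.float64).reshape(batch, 4, 7) + 1.0

# Swap the last two axes of b: (batch, 4, 7) -> (batch, 7, 4).
b_t = np.transpose(b, (0, 2, 1))

# Batched matmul: one (4, 7) @ (7, 4) product per batch element.
batched = np.matmul(a, b_t)

# Dot-style product: contracts the last axis of a with the second-to-last
# axis of b_t and keeps all remaining axes, including both batch axes.
all_pairs = np.dot(a, b_t)

# Reshape only regroups the flat data; it does not swap axes.
b_reshaped = b.reshape(-1, 7, 4)

print(batched.shape)                       # (2, 4, 4)
print(all_pairs.shape)                     # (2, 4, 2, 4)
print(np.array_equal(b_reshaped, b_t))     # False
```

In this sketch, all_pairs[i, :, k, :] holds the product of batch element i of a with batch element k of b, so the batched result is the "diagonal" i == k of the 4-D tensor.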
User contributions licensed under: CC BY-SA