I’m trying to port a transformer written in PyTorch to TensorFlow. Everything was going pretty well until the two versions of multi-head attention started giving extremely different outputs. Both are implementations of multi-head attention as described in the paper “Attention Is All You Need”, so they should be able to produce the same output.
I’m converting
self_attn = nn.MultiheadAttention(dModel, nheads, dropout=dropout)
to
self_attn = MultiHeadAttention(num_heads=nheads, key_dim=dModel, dropout=dropout)
For my tests, dropout is 0.
I’m calling them with:
self_attn(x,x,x)
where x is a tensor of shape (10, 128, 50).
As expected from the documentation, the PyTorch version returns a tuple (the target sequence length, embedding dimension), both with dimensions [10, 128, 50].
I’m having trouble getting the TensorFlow version to do the same thing. With TensorFlow I only get one tensor back (size [10, 128, 50]), and it looks like neither the target sequence length tensor nor the embedding dimension tensor from PyTorch. Based on the TensorFlow documentation, I should be getting something comparable.
How can I get them to behave the same way? I’m guessing I’m doing something wrong in TensorFlow, but I can’t figure out what.
Answer
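nn.MultiheadAttention outputs, by default, a tuple of two tensors:
- attn_output: the result of the self-attention operation, of shape (L, N, E) = (target length, batch size, embedding dimension) in the default layout
- attn_output_weights: the attention weights averaged(!) over heads, of shape (N, L, S), where S is the source length
So for the x in the question these two tensors have shapes (10, 128, 50) and (128, 10, 10).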
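To make those shapes concrete, here is a minimal NumPy sketch of what PyTorch-style multi-head self-attention computes in the default sequence-first layout. It is not PyTorch's actual implementation: biases, dropout and masking are omitted, the weight matrices are random stand-ins, and the head count H = 5 is a hypothetical choice (the question does not give nheads). Each head works on E // H features. The Keras counterpart of that per-head size is key_dim, so key_dim=dModel // nheads is a closer match than key_dim=dModel. Even with matching settings, the two layers initialize their weights independently, so their outputs will only agree numerically if the weights are copied from one to the other.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_seq_first(x, w_q, w_k, w_v, w_o, num_heads):
    """Sketch of multi-head self-attention on x of shape (L, N, E).

    Returns (attn_output, attn_weights) with shapes (L, N, E) and (N, L, L);
    attn_weights are averaged over heads. No biases, dropout or masking.
    """
    L, N, E = x.shape
    head_dim = E // num_heads          # each head sees E / num_heads features
    xb = x.transpose(1, 0, 2)          # (N, L, E): batch first for the matmuls

    def split_heads(t):
        # (N, L, E) -> (N, H, L, head_dim)
        return t.reshape(N, L, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(xb @ w_q)
    k = split_heads(xb @ w_k)
    v = split_heads(xb @ w_v)

    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (N, H, L, L)
    weights = softmax(scores, axis=-1)                         # rows sum to 1
    heads = weights @ v                                        # (N, H, L, head_dim)
    merged = heads.transpose(0, 2, 1, 3).reshape(N, L, E)      # concatenate heads
    out = (merged @ w_o).transpose(1, 0, 2)                    # back to (L, N, E)
    return out, weights.mean(axis=1)                           # average heads -> (N, L, L)

rng = np.random.default_rng(0)
L, N, E, H = 10, 128, 50, 5   # H = 5 is hypothetical; it must divide E
x = rng.standard_normal((L, N, E))
w_q, w_k, w_v, w_o = (rng.standard_normal((E, E)) * 0.1 for _ in range(4))
out, attn = mha_seq_first(x, w_q, w_k, w_v, w_o, num_heads=H)
print(out.shape, attn.shape)  # (10, 128, 50) (128, 10, 10)
```

The second tensor is a set of attention weights over the 10 sequence positions, not a second copy of the output, which is why it cannot look like anything the Keras layer returns by default.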
At the same time, tf.keras.layers.MultiHeadAttention outputs by default only one tensor, attention_output (which corresponds to PyTorch’s attn_output). The attention weights of all heads are also returned if the parameter return_attention_scores is set to True, like:
output, scores = self_attn(x, x, x, return_attention_scores=True)
The scores tensor, of shape (batch_size, num_heads, target_length, source_length), should also be averaged over the heads axis (axis 1) to fully match PyTorch:
scores = tf.math.reduce_mean(scores, 1)
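The averaging step can be checked in isolation. This NumPy sketch builds random stand-in scores in the layout Keras returns, (batch, num_heads, target_len, source_len). A batch of 10 and a sequence length of 128 mirror how Keras reads the question's x, and 5 heads is a hypothetical choice. Here scores.mean(axis=1) plays the role of tf.math.reduce_mean(scores, 1).

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, T, S = 10, 5, 128, 128  # batch, heads (hypothetical), target len, source len

# Stand-in for the per-head scores Keras returns: a softmax over the last axis,
# so every (batch, head, target) row is a probability distribution.
logits = rng.standard_normal((B, H, T, S))
scores = np.exp(logits - logits.max(axis=-1, keepdims=True))
scores /= scores.sum(axis=-1, keepdims=True)

# Equivalent of tf.math.reduce_mean(scores, 1): average over the head axis.
avg = scores.mean(axis=1)
print(scores.shape, "->", avg.shape)  # (10, 5, 128, 128) -> (10, 128, 128)

# A mean of distributions is still a distribution, so rows still sum to 1.
print(np.allclose(avg.sum(axis=-1), 1.0))  # True
```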
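While rewriting, keep in mind that by default (as in the snippet in the question) nn.MultiheadAttention expects input of the form (seq_length, batch_size, embed_dim), but tf.keras.layers.MultiHeadAttention expects (batch_size, seq_length, embed_dim). With x of shape (10, 128, 50), PyTorch therefore attends over sequences of length 10 in a batch of 128, whereas Keras attends over sequences of length 128 in a batch of 10, which on its own makes the outputs incomparable.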
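The layout difference is a single axis swap. Below is a minimal NumPy sketch of it. In TensorFlow the same permutation is tf.transpose(x, perm=[1, 0, 2]). Alternatively, recent PyTorch versions accept nn.MultiheadAttention(..., batch_first=True) so that both layers take batch-first input. The helper names to_batch_first and to_seq_first are hypothetical.

```python
import numpy as np

def to_batch_first(x):
    """(seq_len, batch, embed) -> (batch, seq_len, embed)."""
    return np.transpose(x, (1, 0, 2))

def to_seq_first(x):
    """(batch, seq_len, embed) -> (seq_len, batch, embed); inverse of to_batch_first."""
    return np.transpose(x, (1, 0, 2))

# The question's tensor read PyTorch-style: 10 time steps, batch of 128, 50 features.
x = np.arange(10 * 128 * 50, dtype=float).reshape(10, 128, 50)
x_tf = to_batch_first(x)   # what the Keras layer should receive
print(x_tf.shape)          # (128, 10, 50)

# Element (batch b, step t) lives at x[t, b] in PyTorch layout and x_tf[b, t] in Keras layout.
assert np.array_equal(x_tf[3, 7], x[7, 3])
assert np.array_equal(to_seq_first(x_tf), x)  # the round trip restores PyTorch's layout
```

Transposing x before calling the Keras layer, and transposing its output back afterwards, makes both layers attend over the same 10-step axis.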