I have been trying to stack a single LSTM layer on top of BERT embeddings, but whilst my model starts to train, it fails on the last batch and throws the following error message:
```
Node: 'model/tf.reshape/Reshape'
Input to reshape is a tensor with 59136 values, but the requested shape has 98304
	 [[{{node model/tf.reshape/Reshape}}]] [Op:__inference_train_function_70500]
```
This is how I build the model, and I honestly cannot figure out what is going wrong here:
```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # noqa: F401  registers the ops used by the BERT preprocessing model

batch_size = 128

bert_preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
bert_encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4', trainable=True)

text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)
bert_output = outputs['pooled_output']  # shape=(None, 768)

l = tf.reshape(bert_output, [batch_size, 1, 768])
l = tf.keras.layers.LSTM(32, activation='relu')(l)
l = tf.keras.layers.Dropout(0.1, name='dropout')(l)
l = tf.keras.layers.Dense(8, activation='softmax', name="output")(l)

model = tf.keras.Model(inputs=[text_input], outputs=[l])
print(model.summary())

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size=batch_size)
```
This is the full output:
```
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape                     Param #     Connected to
==================================================================================================
 text (InputLayer)              [(None,)]                        0           []

 keras_layer (KerasLayer)       {'input_mask': (None, 128),      0           ['text[0][0]']
                                 'input_word_ids': (None, 128),
                                 'input_type_ids': (None, 128)}

 keras_layer_1 (KerasLayer)     {'sequence_output':              109482241   ['keras_layer[0][0]',
                                 (None, 128, 768),                            'keras_layer[0][1]',
                                 'default': (None, 768),                      'keras_layer[0][2]']
                                 'encoder_outputs': [(None, 128, 768),
                                 (None, 128, 768), (None, 128, 768),
                                 (None, 128, 768), (None, 128, 768),
                                 (None, 128, 768), (None, 128, 768),
                                 (None, 128, 768), (None, 128, 768),
                                 (None, 128, 768), (None, 128, 768),
                                 (None, 128, 768)],
                                 'pooled_output': (None, 768)}

 tf.reshape (TFOpLambda)        (128, 1, 768)                    0           ['keras_layer_1[0][13]']

 lstm (LSTM)                    (128, 32)                        102528      ['tf.reshape[0][0]']

 dropout (Dropout)              (128, 32)                        0           ['lstm[0][0]']

 output (Dense)                 (128, 8)                         264         ['dropout[0][0]']

==================================================================================================
Total params: 109,585,033
Trainable params: 102,792
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
None
WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is. Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
18/19 [===========================>..] - ETA: 25s - loss: 1.5747 - accuracy: 0.5456
Traceback (most recent call last):
  File "bert-test-lstm.py", line 62, in <module>
    model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 55, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'model/tf.reshape/Reshape' defined at (most recent call last):
    File "bert-test-lstm.py", line 62, in <module>
      model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 859, in train_step
      y_pred = self(x, training=True)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 452, in call
      inputs, training=training, mask=mask)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 226, in _call_wrapper
      return self._call_wrapper(*args, **kwargs)
    File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 261, in _call_wrapper
      result = self.function(*args, **kwargs)
Node: 'model/tf.reshape/Reshape'
Input to reshape is a tensor with 59136 values, but the requested shape has 98304
```
The code runs perfectly fine if I just drop the LSTM and reshape layers. Any help is appreciated.
Answer
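You should use tf.keras.layers.Reshape to reshape bert_output into a 3D tensor, because the layer's target shape excludes the batch dimension and therefore works for any batch size. tf.reshape(bert_output, [batch_size, 1, 768]) hard-codes a batch of 128, but model.fit makes the last batch smaller whenever the number of training samples is not a multiple of batch_size. In your case the last batch holds 77 samples, i.e. 77 × 768 = 59136 values, while the requested shape needs 128 × 768 = 98304 values.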
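The mismatch in the error message (59136 vs. 98304 values) can be reproduced with NumPy alone. This is a sketch under stated assumptions: the training set size of 18 × 128 + 77 = 2381 samples is inferred from the 18/19 progress bar and the 59136-value final batch, not stated in the question; the batch-slicing list comprehension is a simplified, hypothetical imitation of how model.fit splits the data; and NumPy's reshape stands in for tf.reshape.

```python
import numpy as np

batch_size = 128
hidden = 768  # size of BERT's pooled_output
# Assumption (inferred, not from the question): 18 full batches plus a final batch of 77.
n_train = 18 * batch_size + 77

# Simplified imitation of how model.fit slices the data: the last batch is shorter.
batch_sizes = [min(batch_size, n_train - start) for start in range(0, n_train, batch_size)]
last = np.zeros((batch_sizes[-1], hidden), dtype=np.float32)
print(len(batch_sizes), batch_sizes[-1], last.size)

# Equivalent of tf.reshape(bert_output, [batch_size, 1, 768]) applied to the last batch.
try:
    last.reshape(batch_size, 1, hidden)
    fixed_ok = True
except ValueError as err:
    fixed_ok = False
    print("hard-coded batch size fails:", err)

# Letting the batch axis be inferred (-1) works for any batch size.
inferred = last.reshape(-1, 1, hidden)
print(inferred.shape)
```

The -1 in the last reshape plays the same role as the batch axis that tf.keras.layers.Reshape leaves out of its target shape.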
Simply changing:
```python
l = tf.reshape(bert_output, [batch_size, 1, 768])
```
into:
```python
l = tf.keras.layers.Reshape((1, 768))(bert_output)
```
should work.
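One way to check the fix without downloading BERT is to replace the hub encoder with a plain Input of shape (768,) standing in for pooled_output. That stub (the name pooled_output_stub is hypothetical) is an assumption made only so the sketch runs offline; the rest of the head mirrors the layers in the question. The same model is then called on a full batch of 128 and on a short batch of 77.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for BERT's pooled_output: a (batch, 768) float tensor.
pooled = tf.keras.layers.Input(shape=(768,), name="pooled_output_stub")

# Reshape's target shape excludes the batch axis, so the batch dimension stays None.
x = tf.keras.layers.Reshape((1, 768))(pooled)
x = tf.keras.layers.LSTM(32, activation="relu")(x)
x = tf.keras.layers.Dropout(0.1)(x)
out = tf.keras.layers.Dense(8, activation="softmax")(x)
model = tf.keras.Model(pooled, out)

# A full batch and a short final batch both pass through the same graph.
full_batch = model(np.zeros((128, 768), dtype="float32"))
last_batch = model(np.zeros((77, 768), dtype="float32"))
print(full_batch.shape, last_batch.shape)
```

With this change, model.summary() should report (None, 1, 768) for the reshape step instead of (128, 1, 768). Keeping tf.reshape but passing [-1, 1, 768] should also work, since -1 lets TensorFlow infer the batch dimension.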