I have been trying to stack a single LSTM layer on top of BERT embeddings. The model starts to train, but it fails on the last batch and throws the following error message:
Node: 'model/tf.reshape/Reshape'
Input to reshape is a tensor with 59136 values, but the requested shape has 98304
[[{{node model/tf.reshape/Reshape}}]] [Op:__inference_train_function_70500]
This is how I build the model and I honestly cannot figure out what is going wrong here:
batch_size = 128

bert_preprocess = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
bert_encoder = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4', trainable=True)

text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text) #shape=(None, 768)
bert_output = outputs['pooled_output']

l = tf.reshape(bert_output, [batch_size, 1, 768])

l = tf.keras.layers.LSTM(32, activation='relu')(l)

l = tf.keras.layers.Dropout(0.1, name='dropout')(l)
l = tf.keras.layers.Dense(8, activation='softmax', name="output")(l)
model = tf.keras.Model(inputs=[text_input], outputs = [l])

print(model.summary())
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
This is the full output:
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape          Param #     Connected to
==================================================================================================
 text (InputLayer)              [(None,)]             0           []

 keras_layer (KerasLayer)       {'input_mask': (Non   0           ['text[0][0]']
                                e, 128),
                                 'input_word_ids':
                                (None, 128),
                                 'input_type_ids':
                                (None, 128)}

 keras_layer_1 (KerasLayer)     {'sequence_output':   109482241   ['keras_layer[0][0]',
                                 (None, 128, 768),                 'keras_layer[0][1]',
                                 'default': (None,                 'keras_layer[0][2]']
                                768),
                                 'encoder_outputs':
                                 [(None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768),
                                 (None, 128, 768)],
                                 'pooled_output': (
                                None, 768)}

 tf.reshape (TFOpLambda)        (128, 1, 768)         0           ['keras_layer_1[0][13]']

 lstm (LSTM)                    (128, 32)             102528      ['tf.reshape[0][0]']

 dropout (Dropout)              (128, 32)             0           ['lstm[0][0]']

 output (Dense)                 (128, 8)              264         ['dropout[0][0]']

==================================================================================================
Total params: 109,585,033
Trainable params: 102,792
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
None
WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x7fc4ff809440> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
18/19 [===========================>..] - ETA: 25s - loss: 1.5747 - accuracy: 0.5456
Traceback (most recent call last):
  File "bert-test-lstm.py", line 62, in <module>
    model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 55, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Detected at node 'model/tf.reshape/Reshape' defined at (most recent call last):
  File "bert-test-lstm.py", line 62, in <module>
    model.fit(x_train, y_train, epochs=1, batch_size = batch_size)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
    return fn(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1384, in fit
    tmp_logs = self.train_function(iterator)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1021, in train_function
    return step_function(self, iterator)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1010, in step_function
    outputs = model.distribute_strategy.run(run_step, args=(data,))
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 1000, in run_step
    outputs = model.train_step(data)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/training.py", line 859, in train_step
    y_pred = self(x, training=True)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
    return fn(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 452, in call
    inputs, training=training, mask=mask)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/functional.py", line 589, in _run_internal_graph
    outputs = node.layer(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 64, in error_handler
    return fn(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/engine/base_layer.py", line 1096, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 92, in error_handler
    return fn(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 226, in _call_wrapper
    return self._call_wrapper(*args, **kwargs)
  File "/Users/user/opt/anaconda3/envs/project/lib/python3.7/site-packages/keras/layers/core/tf_op_layer.py", line 261, in _call_wrapper
    result = self.function(*args, **kwargs)
Node: 'model/tf.reshape/Reshape'
Input to reshape is a tensor with 59136 values, but the requested shape has 98304
The code runs perfectly fine if I just drop the LSTM and reshape layers – any help is appreciated.
Answer
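You should use tf.keras.layers.Reshape to reshape bert_output into a 3D tensor. Its target shape excludes the batch dimension, so it automatically takes whatever batch size it receives. The hard-coded batch_size in tf.reshape breaks as soon as a batch is smaller than 128, which is what happens on the last batch here (59136 / 768 = 77 samples instead of 128).
Simply changing: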
l = tf.reshape(bert_output, [batch_size, 1, 768])
into:
l = tf.keras.layers.Reshape((1, 768))(bert_output)
should work.
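
The arithmetic behind the error can be checked without TensorFlow. The sketch below is plain Python and only mimics the element-count check a reshape performs; `batch_sizes` and `resolve_shape` are hypothetical helper names, not TensorFlow functions. The dataset size of 2381 is an assumption inferred from the 19 steps in the progress bar and the 59136 values in the error, not a number stated in the question.

```python
HIDDEN_SIZE = 768   # width of pooled_output, taken from the model summary
BATCH_SIZE = 128    # batch_size used in the question
NUM_SAMPLES = 2381  # assumption: 18 full batches of 128 plus one batch of 77


def batch_sizes(num_samples, batch_size):
    """Return the number of samples in each batch of one epoch."""
    full, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


def resolve_shape(num_values, shape):
    """Mimic a reshape's element-count check (hypothetical helper).

    At most one entry may be -1; it is inferred from the other dimensions.
    A mismatch raises ValueError, similar to the error in the question.
    """
    if list(shape).count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    known = 1
    for dim in shape:
        if dim != -1:
            known *= dim
    if -1 in shape:
        if known == 0 or num_values % known != 0:
            raise ValueError(
                f"cannot infer -1: {num_values} values do not divide by {known}"
            )
        return [num_values // known if dim == -1 else dim for dim in shape]
    if known != num_values:
        raise ValueError(
            f"Input to reshape is a tensor with {num_values} values, "
            f"but the requested shape has {known}"
        )
    return list(shape)


sizes = batch_sizes(NUM_SAMPLES, BATCH_SIZE)
last_batch_values = sizes[-1] * HIDDEN_SIZE  # values in the final batch

# A batch-agnostic target shape works for the short last batch as well.
flexible = resolve_shape(last_batch_values, [-1, 1, HIDDEN_SIZE])

# The hard-coded shape from the question only fits full batches.
try:
    resolve_shape(last_batch_values, [BATCH_SIZE, 1, HIDDEN_SIZE])
    fixed_shape_ok = True
except ValueError:
    fixed_shape_ok = False
```

Here `sizes` holds 19 entries with a final batch of 77, `flexible` resolves to `[77, 1, 768]`, and `fixed_shape_ok` ends up `False`. In the real model, the same batch-agnostic behaviour should also be available through `tf.reshape(bert_output, [-1, 1, 768])` or `tf.expand_dims(bert_output, axis=1)`; the Reshape layer shown above is simply the most Keras-idiomatic of the three.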