I have an exploding-gradient problem that I couldn't solve after several days of trying. I implemented a custom message passing graph neural network in TensorFlow, which is used to predict a continuous value from graph data. Each graph is associated with one target value. Each node of a graph is represented by a node attribute vector, and each edge between nodes is represented by an edge attribute vector.
Within a message passing layer, node attributes are updated in a certain way (e.g., by aggregating other node/edge attributes), and these updated node attributes are returned.
Now, I managed to figure out where the gradient problem occurs in my code. I have the below snippet.
```python
to_concat = [neighbors_mean, e]
z = K.concatenate(to_concat, axis=-1)
output = self.Net(z)
```
Here, `neighbors_mean` is the element-wise mean of the two node attribute vectors v_i and v_j that form the edge with edge attribute `e`. `Net` is a single-layer feed-forward network. With this, the training loss suddenly jumps to NaN after about 30 epochs with a batch size of 32. With a batch size of 128, the gradients still explode, just later, after about 200 epochs.
I found that, in this case, the gradients explode because of the edge attribute `e`. If I didn't concatenate `neighbors_mean` with `e` and just used the code below, there would be no gradient explosion.

```python
output = self.Net(neighbors_mean)
```
I can also avoid the gradient explosion by sending `e` through a sigmoid function, as follows. But this degrades the performance (final MAE) because the values in `e` are mapped to the 0-1 range non-linearly. Note that a Rectified Linear Unit (ReLU) instead of sigmoid didn't work.

```python
to_concat = [neighbors_mean, tf.math.sigmoid(e)]
z = K.concatenate(to_concat, axis=-1)
output = self.Net(z)
```
Just to mention that `e` carries a single value relating to the distance between the two corresponding nodes, and this distance is always in the range 0.5-4. There are no large values or NaNs in `e`.
I have a custom loss function to train this model, but I found that the loss is not the problem (other losses led to the same issue). Below is my custom loss function. Note that although this is a single-output regression network, the final layer of my NN has two neurons, corresponding to the mean and log(sigma) of the prediction.
```python
def robust_loss(y_true, y_pred):
    """Computes the robust loss between labels and predictions."""
    mean, sigma = tf.split(y_pred, 2, axis=-1)
    # tried limiting 'sigma' with sigma = tf.clip_by_value(sigma, -4, 1.0) but the gradients still explode
    loss = np.sqrt(2.0) * K.abs(mean - y_true) * K.exp(-sigma) + sigma
    return K.mean(loss)
```
I basically tried everything suggested online to avoid gradient explosion.
- Applied gradient clipping, with `Adam(lr, clipnorm=1, clipvalue=5)` and also with `tf.clip_by_global_norm(gradients, 1.0)` in a custom training step (see the sketch after this list)
- My target variables are always scaled
- Weights are initialized with the `glorot_uniform` distribution
- Applied regularisation to the weights
- Tried larger batch sizes (up to 256; the gradient explosion is delayed but still happens at some point)
- Tried a reduced learning rate
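For reference, this is a minimal sketch of how `tf.clip_by_global_norm` fits into a custom training step; `model`, `optimizer` and the loss name are placeholders, not my exact code:

```python
import tensorflow as tf

@tf.function
def train_step(x, y_true):
    # `model`, `optimizer` and `robust_loss` are placeholders for my actual objects
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = robust_loss(y_true, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    # clip the global norm of all gradients to 1.0 before applying them
    gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```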
What am I missing here? I definitely know it has something to do with concatenating `e`. But given that 0.5 < e < 4, why do the gradients explode in this case? This feature `e` is important to me. What else can I do to avoid numerical overflow in my model?
Answer
I solved the problem thanks to this cool debugging tool, `tf.debugging.check_numerics`.
I initially identified that concatenating `e` was the problem, and then realised that the values passed in as `e` are considerably larger than the values in `neighbors_mean`, which `e` is concatenated with. Once they are concatenated and sent through a neural network (`Net()` in my code), some outputs are in the order of hundreds, slowly reaching thousands as training progresses.
This is problematic because I have a softmax operation within the message passing layer. Note that softmax computes an exponential: softmax(x_i) = exp(x_i) / Σ_j exp(x_j). Anything above exp(709) results in a numerical overflow in Python. This was producing `inf` values, and eventually everything becoming `nan` was the problem in my code. So, this is technically not an exploding-gradient problem, which is why it couldn't be solved with gradient clipping.
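The overflow is easy to see in isolation (a small standalone check, not part of my model):

```python
import numpy as np

# float64 can represent values up to ~1.8e308, so exp() overflows just above 709.78
# (float32, which TF often uses, overflows even earlier, just above exp(88.7))
print(np.exp(709.0))     # ~8.2e+307, still finite
print(np.exp(710.0))     # inf (RuntimeWarning: overflow encountered in exp)
print(np.inf / np.inf)   # nan -- an overflowed softmax numerator/denominator does exactly this
```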
How did I track the issue?
I put `tf.debugging.check_numerics()` snippets under several layers/tensors I thought were producing nan values. Something like this:

```python
tf.debugging.check_numerics(layerN, "LayerN is producing nans!")
```
This produces an `InvalidArgumentError` as soon as the layer outputs become `inf` or `nan` during training.
```
Traceback (most recent call last):
  File "trainer.py", line 506, in <module>
    worker.train_model()
  File "trainer.py", line 211, in train_model
    l, tmae = train_step(*batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 855, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: LayerN is producing nans! : Tensor had NaN values
```
Now we know where the problem is.
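For context, here is a minimal sketch of where such a check can sit inside a subclassed Keras layer; the dense layer is just a placeholder, not my actual message passing layer:

```python
import tensorflow as tf
from tensorflow.keras import layers

class CheckedLayer(layers.Layer):
    """Placeholder layer showing where the numeric check goes."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.dense = layers.Dense(units)

    def call(self, inputs):
        out = self.dense(inputs)
        # Raises InvalidArgumentError the moment `out` contains inf or nan
        out = tf.debugging.check_numerics(out, "LayerN is producing nans!")
        return out
```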
How to solve the issue
I applied a kernel constraint to the weights of the layer whose output gets passed to the softmax function.
```python
layers.Dense(x, name="layer1",
             kernel_regularizer=regularizers.l2(1e-6),
             kernel_constraint=min_max_norm(min_value=1e-30, max_value=1.0))
```
This constrains the norm of the layer's incoming weight vectors to at most 1, so the layer cannot produce very large outputs. This resolved the problem without degrading the performance.
Alternatively, one could use the numerically stable implementation of the softmax function.
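A minimal sketch of that idea: subtracting the maximum logit before exponentiating changes nothing mathematically, but keeps every exponent non-positive so exp() cannot overflow. The segment-wise version needed inside a message passing layer would follow the same pattern.

```python
import tensorflow as tf

def stable_softmax(logits, axis=-1):
    # exp(x - max(x)) has exponents <= 0, so it can never overflow,
    # and the result is identical to the naive exp(x) / sum(exp(x)).
    shifted = logits - tf.reduce_max(logits, axis=axis, keepdims=True)
    exp = tf.exp(shifted)
    return exp / tf.reduce_sum(exp, axis=axis, keepdims=True)
```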