
Gradient exploding problem in a graph neural network

I have a gradient exploding problem which I couldn’t solve after trying for several days. I implemented a custom message passing graph neural network in TensorFlow which is used to predict a continuous value from graph data. Each graph is associated with one target value. Each node of a graph is represented by a node attribute vector, and the edges between nodes are represented by an edge attribute vector.

Within a message passing layer, node attributes are updated in a certain way (e.g., by aggregating other node/edge attributes), and these updated node attributes are returned.

Now, I managed to figure out where the gradient problem occurs in my code. I have the below snippet.

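In simplified form (a sketch; variable names may differ slightly, and Net is the single-layer feed-forward network described below):

```python
# inside the message passing layer's call() (sketch)
z = tf.concat([neighbors_mean, e], axis=-1)  # concatenate node-pair mean with edge attribute
output = self.Net(z)                         # Net: single-layer feed-forward network
```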

Here, neighbors_mean is the element-wise mean of the two node attribute vectors v_i and v_j that form the edge carrying the edge attribute e, and Net is a single-layer feed-forward network. With this, the training loss suddenly jumps to NaN after about 30 epochs with a batch size of 32. With a batch size of 128, the gradients still explode, only later, after about 200 epochs.

I found that the gradients explode because of the edge attribute e. If I don't concatenate neighbors_mean with e and instead use the code below, there is no gradient explosion.

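That is, the same update without the edge attribute (sketch):

```python
# same update, but without concatenating the edge attribute e
output = self.Net(neighbors_mean)
```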

I can also avoid the gradient explosion by passing e through a sigmoid function, as follows. But this degrades the performance (final MAE), because the values in e are mapped non-linearly into the 0-1 range. Using a Rectified Linear Unit (ReLU) instead of the sigmoid didn't work either.

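A sketch of the sigmoid variant:

```python
# squash e through a sigmoid before concatenating (sketch)
z = tf.concat([neighbors_mean, tf.math.sigmoid(e)], axis=-1)
output = self.Net(z)
```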

For context, e carries a single value: the distance between the two corresponding nodes, which is always in the range 0.5-4. There are no large values or NaNs in e.

I train this model with a custom loss function, but I found that the loss is not the problem (other losses led to the same behaviour). Below is my custom loss function. Note that although this is a single-output regression network, the final layer of my NN has two neurons, corresponding to the mean and log(sigma) of the prediction.

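It has roughly this structure (an illustrative sketch, not the exact formula; the two output neurons are split into a predicted mean and a log(sigma)):

```python
import tensorflow as tf

def custom_loss(y_true, y_pred):
    # the two output neurons: predicted mean and log(sigma)
    mean = y_pred[:, 0]
    log_sigma = y_pred[:, 1]
    y = tf.reshape(y_true, tf.shape(mean))  # make shapes match
    # Gaussian negative log-likelihood (up to an additive constant)
    nll = 0.5 * tf.square(y - mean) * tf.exp(-2.0 * log_sigma) + log_sigma
    return tf.reduce_mean(nll)
```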

I basically tried everything suggested online to avoid gradient explosion.

  1. Applied gradient clipping – with Adam(lr, clipnorm=1, clipvalue=5) and also with tf.clip_by_global_norm(gradients, 1.0) (see the sketch after this list)
  2. My target variables are always scaled
  3. Weights are initialized with the glorot_uniform initializer
  4. Applied regularisation to the weights
  5. Tried larger batch sizes (up to 256; this delays the explosion, but it still happens at some point)
  6. Tried a reduced learning rate
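
For reference, the two clipping variants from item 1 look like this (a sketch; the custom-loop variant assumes a GradientTape named tape and a model already defined):

```python
import tensorflow as tf

# variant 1: clipping built into the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

# variant 2: global-norm clipping inside a custom training loop
# (assumes `tape`, `loss`, and `model` exist in the surrounding loop)
gradients = tape.gradient(loss, model.trainable_variables)
gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```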

What am I missing here? I definitely know it has something to do with concatenating e. But given that 0.5 < e < 4, why do the gradients explode in this case? This feature e is important to me. What else can I do to avoid numerical overflow in my model?


Answer

I solved the problem thanks to this cool debugging tool: tf.debugging.check_numerics.

I had initially identified that concatenating e was the problem, and then realised that the values in e are considerably larger than the values in neighbors_mean, which it gets concatenated with. Once they are concatenated and sent through a neural network (Net() in my code), I observed outputs on the order of hundreds, slowly reaching thousands as training progresses.

This is problematic because I have a softmax operation within the message passing layer. Note that softmax exponentiates its inputs: softmax(x_i) = e^(x_i) / Σ_j e^(x_j). Anything above e^709 results in a numerical overflow in Python (float64). This was producing inf values, and eventually everything became nan, which was the real problem in my code. So this is technically not a gradient exploding problem, which is why it couldn't be solved with gradient clipping.
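
A quick way to see the overflow threshold:

```python
import numpy as np

print(np.exp(709.0))  # ~8.2e307, still representable as a float64
print(np.exp(710.0))  # inf: float64 overflows past ~1.8e308
# float32 (TensorFlow's default dtype) overflows even earlier, around exp(89)
```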

How did I track the issue?

I put tf.debugging.check_numerics() calls after several layers/tensors I thought were producing nan values.

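For example (the wrapped tensor here is just illustrative; in my case the checks went on the tensors inside the message passing layer):

```python
# wherever a suspect tensor is computed (e.g. inside the message passing
# layer's call()), wrap it in a check
output = self.Net(z)
output = tf.debugging.check_numerics(output, "Net output contains inf or nan")
```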

This produces an InvalidArgumentError as soon as the layer outputs become inf or nan during training.

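The error looks roughly like this (the first part of the message is whatever string was passed to check_numerics, so the exact wording will differ):

```
InvalidArgumentError: Net output contains inf or nan : Tensor had NaN values
```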

Now we know where the problem is.

How to solve the issue

I applied kernel constraints to the weights of the neural network whose output gets passed on to the softmax function.

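For example (a sketch; the layer size and the exact constraint value are illustrative, here a max-norm constraint of 1 on both kernel and bias):

```python
from tensorflow.keras.layers import Dense
from tensorflow.keras.constraints import max_norm

# the layer whose output feeds the softmax (in my code, self.Net inside the
# message passing layer)
net = Dense(
    64,
    activation="relu",
    kernel_constraint=max_norm(1.0),  # keep kernel column norms <= 1
    bias_constraint=max_norm(1.0),    # keep the bias small as well
)
```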

This should make sure that all weights are less than 1 in magnitude and the layer does not produce large outputs. This resolved the problem without degrading the performance.

Alternatively, one could use a numerically stable implementation of the softmax function.
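
A sketch of such an implementation (assuming the softmax is computed by hand over the last axis):

```python
import tensorflow as tf

def stable_softmax(logits, axis=-1):
    # subtracting the per-row maximum leaves the softmax unchanged but keeps
    # every exp() argument <= 0, so the exponential can never overflow
    shifted = logits - tf.reduce_max(logits, axis=axis, keepdims=True)
    exps = tf.exp(shifted)
    return exps / tf.reduce_sum(exps, axis=axis, keepdims=True)
```

tf.nn.softmax should already apply this shift internally, so using it directly (rather than exponentiating by hand) is another option.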
