
Is there any way to implement an early stopping callback for TensorFlow 2 model_main_tf2.py?

Hello, I’m working on object detection using the TensorFlow 2 Object Detection API and its model_main_tf2.py file. Normally we can pass an early stopping callback to model.fit(), but when I train with model_main_tf2.py and a pipeline .config file I’m not able to do that, because I can’t locate a model.fit() call in the main file. Is there any way I can implement early stopping for model_main_tf2.py?

Link to the file: https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py


Answer

I had a look inside the model_main_tf2.py file. Let’s take the following piece of code:

JavaScript

Instead of running the training through fit, the script uses a custom training loop. The code above calls model_lib_v2.train_loop, the function that executes the training operation. model_lib_v2 is just another file in the repo that you’ve linked.

If you have a look at the train_loop function, you’ll see that at some point it executes the following code:

JavaScript

GradientTape records the operations of the forward pass and computes the gradients needed to update the weights of the model during the training phase. I won’t go into much detail, but if you are interested you can have a look at the linked documentation.
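To make the idea concrete, here is a minimal sketch of a training step built on tf.GradientTape. It is not the code from model_lib_v2: the one-weight model, the toy data, and the plain gradient-descent update are all assumptions chosen only to show how the tape is used.

```python
import tensorflow as tf

# Toy setup (assumption): learn w so that w * x matches y = 2 * x.
w = tf.Variable(0.0)
x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([2.0, 4.0, 6.0])
learning_rate = 0.1


def train_step():
    # Operations executed inside the tape context are recorded,
    # so the loss can be differentiated with respect to w afterwards.
    with tf.GradientTape() as tape:
        predictions = w * x
        loss = tf.reduce_mean(tf.square(predictions - y))
    # Gradient of the loss with respect to the trainable variable.
    grad = tape.gradient(loss, w)
    # Plain gradient-descent update, written by hand instead of using an optimizer.
    w.assign_sub(learning_rate * grad)
    return float(loss)


losses = [train_step() for _ in range(50)]
```

A real training loop does the same thing with a full model and an optimizer, but the structure (forward pass inside the tape, then gradients, then a weight update) stays the same.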

Now, you are interested in adding an early stopping callback, but there is no fit call to pass it to. You can still add early stopping, just in a different way: by checking the monitored metric yourself inside the training loop.

You can follow a strategy like the one below (refer to this TensorFlow tutorial for the full code):

JavaScript
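A patience-based check of this kind can be sketched in plain Python. This is only an illustration, not the tutorial’s code and not part of the Object Detection API: the EarlyStopper class, its parameters, and the placeholder validation losses are hypothetical names and values introduced here.

```python
class EarlyStopper:
    """Hypothetical helper: patience-based early stopping for a custom loop."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # evaluations to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")      # best (lowest) value seen so far
        self.wait = 0                 # evaluations since the last improvement

    def should_stop(self, value):
        if value < self.best - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best = value
            self.wait = 0
            return False
        # No improvement: stop once patience is exhausted.
        self.wait += 1
        return self.wait >= self.patience


# Placeholder validation losses (assumption); in a real loop each value
# would come from evaluating the model after a round of training steps.
val_losses = [0.90, 0.70, 0.60, 0.61, 0.62, 0.63, 0.50]

stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    # ... run the GradientTape training steps for this epoch here ...
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break
```

To use something like this with model_main_tf2.py, you would have to edit the loop inside train_loop yourself and decide which loss to monitor, since the script does not take a callback argument.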
User contributions licensed under: CC BY-SA