Is there any way to implement early stopping callback for Tensorflow 2 model_main_tf.py?

Hello, I’m working on object detection with the TensorFlow 2 Object Detection API, training through the model_main_tf2.py script and a pipeline .config file. Normally I can pass an early stopping callback to model.fit(), but with model_main_tf2.py I can’t find any model.fit() call in the script, so I don’t know where to hook the callback in. Is there any way to implement early stopping when training with model_main_tf2.py? Please help me.
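
For reference, the usual pattern I mean is something like this (just a minimal sketch, assuming model, train_ds and val_ds are already built; it is not my actual training code):

import tensorflow as tf

# The standard Keras approach: pass an EarlyStopping callback to model.fit().
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",          # stop when the validation loss stops improving
    patience=5,                  # epochs without improvement before stopping
    restore_best_weights=True)

model.fit(train_ds,
          validation_data=val_ds,
          epochs=100,
          callbacks=[early_stop])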

Link to the file: https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py

Answer

I had a look inside the model_main_tf2.py file. Let’s take the following piece of code:

model_lib_v2.train_loop(
    pipeline_config_path=FLAGS.pipeline_config_path,
    model_dir=FLAGS.model_dir,
    train_steps=FLAGS.num_train_steps,
    use_tpu=FLAGS.use_tpu,
    checkpoint_every_n=FLAGS.checkpoint_every_n,
    record_summaries=FLAGS.record_summaries)

Instead of running the training through fit, a custom training loop is used. The code above calls the function that actually executes the training; model_lib_v2 is just another file in the repo that you linked.

If you have a look at the train_loop function, you’ll see that at some point the following code is executed:

with tf.GradientTape() as tape:
    losses_dict, _ = _compute_losses_and_predictions_dicts(
        detection_model, features, labels,
        training_step=training_step,
        add_regularization_loss=add_regularization_loss)

    losses_dict = normalize_dict(losses_dict, num_replicas)

trainable_variables = detection_model.trainable_variables

total_loss = losses_dict['Loss/total_loss']
gradients = tape.gradient(total_loss, trainable_variables)

GradientTape basically computes the gradients needed to update the weights of the model during the training phase. I won’t go into much detail, but if you are interested you can have a look at the linked documentation.
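
To give an idea of the mechanism, here is a minimal, self-contained GradientTape example with a toy variable (my own illustration, not code from the Object Detection API):

import tensorflow as tf

# A toy "model": a single trainable weight.
w = tf.Variable(3.0)
x = tf.constant(2.0)

with tf.GradientTape() as tape:
    # The tape records the operations applied to w during the forward pass.
    loss = (w * x - 1.0) ** 2

# d(loss)/dw, which an optimizer would then use to update w,
# just like train_loop does for the detection model's weights.
grad = tape.gradient(loss, w)
print(grad.numpy())  # 2 * (w*x - 1) * x = 20.0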

Now, you are interested in adding an early stopping callback, but you don’t have a fit call. You can still add early stopping, but in a different way.

You can follow a strategy like the one below (refer to this tutorial by TensorFlow for the full code):

import time

epochs = 100
patience = 5  # you can play with these values to find the best configuration
wait = 0
best = float('inf')  # lowest validation loss seen so far

for epoch in range(epochs):
    start_time = time.time()

    # training (calling the function that holds the GradientTape)
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)

    # updating the metrics after the whole training loop on a single epoch
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy(),))

    # evaluating the model just trained in a new epoch, on the validation data
    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)

    # updating the metrics for validation
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over `patience` consecutive epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
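
If you prefer to keep the bookkeeping out of the loop body, the same logic can be wrapped in a tiny helper class (a sketch of my own, not something provided by the Object Detection API), which you call once per epoch with the current validation loss:

class EarlyStopping:
    """Minimal early-stopping tracker for custom training loops."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')  # lowest validation loss seen so far
        self.wait = 0

    def should_stop(self, val_loss):
        # An improvement resets the counter; otherwise the counter grows.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience

With stopper = EarlyStopping(patience=5) created before the loop, the wait/best bookkeeping above becomes a single if stopper.should_stop(float(val_loss)): break. The same idea applies if you decide to edit the loop inside model_lib_v2.train_loop directly: evaluate periodically, feed the validation loss to the tracker, and break out of the loop when it tells you to stop.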