Skip to content
Advertisement

How to continue training from checkpoints using object_detector.EfficientDetLite4Spec (TensorFlow Lite Model Maker)

Previously I set “grad_checkpoint=true” for my EfficientDetLite4 model in config.yaml, and training successfully generated some checkpoints. However, I can’t figure out how to use these checkpoints to continue training from where they left off.

Every time I train the model, it starts from the beginning instead of from my checkpoints.

The following picture shows my colab file system structure:

my colab file system structure

The following picture shows where my checkpoints are stored:

model file system here

The following code shows how I configure the model and how I train with the model.

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

# Load the train/validation/test splits from the annotations CSV.
# Parenthesized so the statement can legally span two lines.
train_data, validation_data, test_data = (
    object_detector.DataLoader.from_csv('csv_path'))

# EfficientDet-Lite4 spec. model_dir is where training writes checkpoints;
# passing it here does not make create() restore from them.
# NOTE(review): `spec_strategy` is not defined anywhere in this snippet; it must
# be assigned (e.g. None or 'gpus') before this call, otherwise it raises NameError.
spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, strategy=spec_strategy
)

# Build and train the detector (training always starts from scratch here).
model = object_detector.create(train_data, model_spec=spec, batch_size=3,
    train_whole_model=True, validation_data=validation_data)

Advertisement

Answer

The source code is the answer!

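I ran into the same problem and found out that the model_dir we pass to the TFLite Model Maker object detector API is only used for saving the model’s weights: that’s why the API never restores from checkpoints. (Note also that grad_checkpoint only enables gradient checkpointing, a memory-saving technique; it is not what writes the checkpoint files.)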

Having a look at the source code of this API, I noticed it internally uses the standard model.compile and model.fit functions and it saves the model’s weights through the callbacks parameter of model.fit.
This means that, provided we can get the internal Keras model, we can restore our checkpoints simply by using model.load_weights!

These are the links to the source code if you want to know more about what some of the functions I use below do:

This is the code:

#Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

#Import the same libs that TFLiteModelMaker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib



# --- Training settings -------------------------------------------------------
batch_size = 6                   # pick any batch size that fits in memory
epochs = 50
checkpoint_dir = "/content/..."  # directory holding (and receiving) the checkpoints

# --- Model spec ----------------------------------------------------------------
# Any object detector spec works; this one is EfficientDet-Lite4.
spec = object_detector.EfficientDetLite4Spec(
    # Which model / backbone to load.
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    hparams='',  # e.g. 'grad_checkpoint=true' to turn on gradient checkpointing
    # Training behaviour.
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    model_dir=checkpoint_dir,
    # Export behaviour.
    tflite_max_detections=25,
    # Hardware / distribution.
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    # Diagnostics and reproducibility.
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1,
)



# Load your dataset splits from the annotations CSV.
(train_data, validation_data, test_data) = object_detector.DataLoader.from_csv('/path/to/csv.csv')

# Build the object detector without training it (do_train=False);
# the training loop is driven manually with model.fit.
detector = object_detector.create(
    train_data,
    model_spec=spec,
    validation_data=validation_data,
    batch_size=batch_size,
    epochs=epochs,
    train_whole_model=True,
    do_train=False,
)



"""
From here on we use internal/"private" functions of the API,
you can tell because the methods' names begin with an underscore
"""

#Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training = True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training = False)




#Get the internal keras model    
model = detector.create_model()




#Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config) #This is the model.compile call basically
model.summary()




"""
Here we restore the weights
"""

#Load the weights from the latest checkpoint.
#In my case:
#checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/" 
#specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
try:
  #Option A:
  #load the weights from the last successfully completed epoch
  latest = tf.train.latest_checkpoint(checkpoint_dir) 

  #Option B:
  #load the weights from a specific checkpoint
  #latest = specific_checkpoint_dir

  completed_epochs = int(latest.split("/")[-1].split("-")[1]) #the epoch the training was at when the training was last interrupted
  model.load_weights(latest)

  print("Checkpoint found {}".format(latest))
except Exception as e:
  print("Checkpoint not found: ", e)



# Model Maker's default callbacks, including the one that saves a checkpoint at
# the end of every epoch.
all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)

# Optional: add extra callbacks that run during training, for example
# TensorBoard logging:
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
#all_callbacks.append(tensorboard_callback)

# Train, continuing the epoch counter from `completed_epochs`.
model.fit(
    train_ds,
    validation_data=validation_ds,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    initial_epoch=completed_epochs,
    epochs=epochs,
    callbacks=all_callbacks,
)




"""
Save/export the trained model
Tip: for integer quantization you simply have to NOT SPECIFY 
the quantization_config parameter of the detector.export method.
In this case it would be: 
detector.export(export_dir = export_dir, tflite_filename='model.tflite')
"""
export_dir = "/content/..." #save the tflite wherever you want
quant_config = QuantizationConfig.for_float16() #or whatever quantization you want
detector.model = model #inject our trained model into the object detector
detector.export(export_dir = export_dir, tflite_filename='model.tflite', quantization_config = quant_config)
Advertisement