Previously, I set `grad_checkpoint=true` for my EfficientDet-Lite4 model in config.yaml, and it successfully generated some checkpoints. However, I can’t figure out how to use these checkpoints to continue training from them.
Every time I train the model, it starts from the beginning instead of from my checkpoints.
The following picture shows my Colab file system structure:
The following picture shows where my checkpoints are stored:
The following code shows how I configure and train the model:
```python
import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

train_data, validation_data, test_data = object_detector.DataLoader.from_csv('csv_path')

spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50,
    batch_size=3,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=spec_strategy
)

model = object_detector.create(train_data, model_spec=spec, batch_size=3,
                               train_whole_model=True, validation_data=validation_data)
```
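It may help to look at what the `hparams` string above actually contains. The sketch below uses a hypothetical helper (`parse_hparams_overrides` is not part of Model Maker or EfficientDet, whose own config code does the real parsing) to split such a comma-separated `key=value` string. Neither key asks the trainer to resume: as far as I can tell from the EfficientDet hparams, `grad_checkpoint` turns on gradient checkpointing (recomputing activations to save memory), and the `ckpt-N` files in `model_dir` are written by a Keras checkpoint callback whether or not it is set.

```python
def parse_hparams_overrides(overrides: str) -> dict:
    """Hypothetical helper: split 'a=1,b=true' into {'a': '1', 'b': 'true'}.

    Values are kept as strings; this only illustrates the override format and
    is NOT the parser EfficientDet's config code uses.
    """
    result = {}
    for item in overrides.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate an empty string or a trailing comma
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        result[key.strip()] = value.strip()
    return result


if __name__ == "__main__":
    print(parse_hparams_overrides("grad_checkpoint=true,strategy=gpus"))
    # {'grad_checkpoint': 'true', 'strategy': 'gpus'}
```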
The source code is the answer!
I ran into the same problem and found out that the `model_dir` we pass to TFLite Model Maker’s object detector API is only used for saving the model’s weights: that’s why the API never restores from checkpoints.
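Since nothing reads `model_dir` back, finding the checkpoint to resume from is up to us. `tf.train.latest_checkpoint(model_dir)` does this by reading the `checkpoint` state file TensorFlow writes next to the weights. As a dependency-free illustration, the sketch below does the same job by scanning for `ckpt-N.index` files. It assumes the `ckpt-<epoch>` naming visible in my checkpoint directory (for example `ckpt-35`), and `find_latest_checkpoint` is a hypothetical helper, not part of the API.

```python
import os
import re
import tempfile
from typing import Optional, Tuple

# Matches the index file of a TF-format weights checkpoint, e.g. "ckpt-35.index".
_CKPT_INDEX = re.compile(r"^ckpt-(\d+)\.index$")


def find_latest_checkpoint(model_dir: str) -> Tuple[Optional[str], int]:
    """Hypothetical helper: return (checkpoint_prefix, epoch) for the highest ckpt-N.

    Returns (None, 0) when model_dir is missing or holds no ckpt-N files.
    """
    if not os.path.isdir(model_dir):
        return None, 0
    best_prefix, best_epoch = None, 0
    for name in os.listdir(model_dir):
        match = _CKPT_INDEX.match(name)
        # Compare numerically so that ckpt-10 wins over ckpt-9.
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            # load_weights() expects the prefix, i.e. the path without ".index".
            best_prefix = os.path.join(model_dir, name[: -len(".index")])
    return best_prefix, best_epoch


if __name__ == "__main__":
    # Simulate a model_dir with three saved epochs (empty placeholder files).
    with tempfile.TemporaryDirectory() as model_dir:
        for n in (1, 9, 10):
            for suffix in (".index", ".data-00000-of-00001"):
                open(os.path.join(model_dir, f"ckpt-{n}{suffix}"), "w").close()
        prefix, epoch = find_latest_checkpoint(model_dir)
        print(os.path.basename(prefix), epoch)  # ckpt-10 10
```

The returned prefix is what `model.load_weights` takes, and the epoch number is what `initial_epoch` needs.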
Looking at the source code of this API, I noticed that it internally uses the standard `model.compile` and `model.fit` functions, and that it saves the model’s weights through the `callbacks` parameter of `model.fit`.
This means that, provided we can get the internal Keras model, we can restore our checkpoints simply by calling `model.load_weights`.
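Once the weights are loaded, `initial_epoch` tells `model.fit` where to pick up. The Keras docs describe `epochs`, when combined with `initial_epoch`, as the final epoch index rather than a number of additional epochs. Assuming the checkpoint callback is a standard Keras `ModelCheckpoint` with a `ckpt-{epoch}` pattern (Keras fills `{epoch}` starting from 1), `ckpt-35` means 35 epochs are done and `initial_epoch=35` continues with the 36th. A small sketch of that bookkeeping, where `resume_plan` is a hypothetical helper:

```python
import re
from typing import Optional


def resume_plan(latest_checkpoint: Optional[str], total_epochs: int) -> dict:
    """Hypothetical helper: derive model.fit resume arguments from a ckpt-N path.

    Assumes N counts completed epochs (Keras' ModelCheckpoint numbers {epoch}
    from 1), so N is also the 0-based index of the next epoch to run.
    """
    completed = 0
    if latest_checkpoint:
        match = re.search(r"ckpt-(\d+)$", latest_checkpoint.rstrip("/"))
        if match:
            completed = int(match.group(1))
    return {
        "initial_epoch": completed,
        # With initial_epoch, `epochs` is the final epoch index, not extra epochs.
        "epochs": total_epochs,
        "remaining": max(total_epochs - completed, 0),
    }


if __name__ == "__main__":
    ckpt = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
    print(resume_plan(ckpt, 50))  # {'initial_epoch': 35, 'epochs': 50, 'remaining': 15}
```

When no checkpoint exists (`None`), the plan falls back to `initial_epoch=0`, i.e. training from scratch, instead of raising.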
These are the links to the source code if you want to know more about what some of the functions I use below do:
- Object Detector Documentation
- Object Detector Source Code
- EfficientDetSpec Source Code
- How the TFLite Model Maker API Compiles your Model
This is the code:
```python
# Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

# Import the same libs that TFLite Model Maker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib

# Setup variables
batch_size = 6  # or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..."  # whatever your checkpoint directory is

# Create whichever object detector's spec you want
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='',
    hparams='',  # enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1
)

# Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')

# Create the object detector without training it (do_train=False)
detector = object_detector.create(train_data,
                                  model_spec=spec,
                                  batch_size=batch_size,
                                  train_whole_model=True,
                                  validation_data=validation_data,
                                  epochs=epochs,
                                  do_train=False)

"""
From here on we use internal/"private" functions of the API,
you can tell because the methods' names begin with an underscore
"""

# Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training=False)

# Get the internal Keras model
model = detector.create_model()

# Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config)  # This is basically the model.compile call
model.summary()

"""
Here we restore the weights
"""

# Load the weights from the latest checkpoint.
# In my case:
# checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
# specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
completed_epochs = 0  # train from scratch if no checkpoint can be loaded
try:
    # Option A: load the weights from the last successfully completed epoch
    latest = tf.train.latest_checkpoint(checkpoint_dir)

    # Option B: load the weights from a specific checkpoint
    # latest = specific_checkpoint_dir

    # The epoch the training was at when it was last interrupted
    last_epoch = int(latest.split("/")[-1].split("-")[1])
    model.load_weights(latest)
    completed_epochs = last_epoch  # only set once the weights actually loaded
    print("Checkpoint found {}".format(latest))
except Exception as e:
    print("Checkpoint not found: ", e)

# Retrieve the needed default callbacks
all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)

"""
Optional step.
Add callbacks that get executed at the end of every N epochs:
in this case I want to log the training results to TensorBoard.
"""
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
# all_callbacks.append(tensorboard_callback)

"""
Train the model
"""
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=all_callbacks  # saves checkpoints at the end of every epoch + runs the callbacks added above
)

"""
Save/export the trained model
Tip: for integer quantization you simply have to NOT SPECIFY
the quantization_config parameter of the detector.export method.
In this case it would be:
detector.export(export_dir=export_dir, tflite_filename='model.tflite')
"""
export_dir = "/content/..."  # save the tflite wherever you want
quant_config = QuantizationConfig.for_float16()  # or whatever quantization you want
detector.model = model  # inject our trained model into the object detector
detector.export(export_dir=export_dir, tflite_filename='model.tflite', quantization_config=quant_config)
```