Previously, I set "grad_checkpoint=true" for my EfficientDetLite4 model in config.yaml, and it successfully generated some checkpoints. However, I can't figure out how to use these checkpoints when I want to continue training from them.
Every time I train the model, it just starts from the beginning instead of from my checkpoints.
The following picture shows my colab file system structure:
The following picture shows where my checkpoints store:
The following code shows how I configure and train the model:

```python
import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

train_data, validation_data, test_data = object_detector.DataLoader.from_csv('csv_path')

spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50,
    batch_size=3,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=spec_strategy  # spec_strategy is not defined in this snippet (e.g. None or 'gpus')
)

model = object_detector.create(train_data,
                               model_spec=spec,
                               batch_size=3,
                               train_whole_model=True,
                               validation_data=validation_data)
```
Answer
The source code is the answer!
I ran into the same problem and found out that the `model_dir` we pass to TFLite Model Maker's object detector API is only used for saving the model's weights: that's why the API never restores from checkpoints.
Having a look at the source code of this API, I noticed that it internally uses the standard `model.compile` and `model.fit` functions, and that it saves the model's weights through the `callbacks` parameter of `model.fit`.
This means that, provided we can get the internal Keras model, we can simply restore our checkpoints with `model.load_weights`!
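If you would rather not rely on `tf.train.latest_checkpoint` (which reads the `checkpoint` state file in the directory), a minimal stdlib sketch can pick the newest checkpoint by scanning for `.index` files. It assumes the callbacks wrote TensorFlow-format weights named `ckpt-<epoch>`, as the `ckpt-35` example in the answer's code suggests; `find_latest_checkpoint` is a hypothetical helper, not part of TensorFlow or Model Maker. Epoch numbers are compared as integers, so `ckpt-35` wins over `ckpt-9` even though it sorts earlier as a string.

```python
import os
import re

# Assumption: weights were saved in TF format as "ckpt-<epoch>.index" plus
# "ckpt-<epoch>.data-*" files inside checkpoint_dir.
_CKPT_RE = re.compile(r"^ckpt-(\d+)\.index$")


def find_latest_checkpoint(checkpoint_dir):
    """Return (checkpoint_prefix, epoch) for the highest epoch found, or (None, 0)."""
    best = None  # (epoch, prefix)
    for name in os.listdir(checkpoint_dir):
        match = _CKPT_RE.match(name)
        if match:
            epoch = int(match.group(1))  # integer comparison, not string comparison
            if best is None or epoch > best[0]:
                # model.load_weights expects the prefix without the ".index" suffix
                prefix = os.path.join(checkpoint_dir, name[: -len(".index")])
                best = (epoch, prefix)
    if best is None:
        return None, 0
    return best[1], best[0]
```

With the Keras model in hand, you would call `model.load_weights(prefix)` when `prefix` is not `None` and pass the returned epoch to `model.fit` as `initial_epoch`.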
These are the links to the source code if you want to know more about what some of the functions I use below do:
- Object Detector Documentation
- Object Detector Source Code
- EfficientDetSpec Source Code
- How the TFLite Model Maker API Compiles your Model
This is the code:
```python
# Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

# Import the same libs that TFLite Model Maker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib

# Setup variables
batch_size = 6  # or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..."  # whatever your checkpoint directory is

# Create whichever object detector's spec you want
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    hparams='',  # enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1
)

# Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')

# Create the object detector without training it (do_train=False)
detector = object_detector.create(train_data,
                                  model_spec=spec,
                                  batch_size=batch_size,
                                  train_whole_model=True,
                                  validation_data=validation_data,
                                  epochs=epochs,
                                  do_train=False)

# From here on we use internal/"private" functions of the API:
# you can tell because the methods' names begin with an underscore.

# Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training=False)

# Get the internal Keras model
model = detector.create_model()

# Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config)  # this is basically the model.compile call
model.summary()

# Here we restore the weights from the latest checkpoint. In my case:
# checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
# specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
completed_epochs = 0  # train from scratch if no checkpoint can be restored
try:
    # Option A: load the weights from the last successfully completed epoch
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    # Option B: load the weights from a specific checkpoint
    # latest = specific_checkpoint_dir

    # The epoch the training was at when it was last interrupted (e.g. "ckpt-35" -> 35)
    epoch = int(latest.split("/")[-1].split("-")[1])
    model.load_weights(latest)
    completed_epochs = epoch  # only set once the weights were actually loaded
    print("Checkpoint found {}".format(latest))
except Exception as e:
    print("Checkpoint not found: ", e)

# Retrieve the needed default callbacks
all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)

# Optional step: add callbacks that get executed at the end of every N epochs.
# In this case I want to log the training results to TensorBoard.
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
# all_callbacks.append(tensorboard_callback)

# Train the model, resuming from completed_epochs
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=all_callbacks  # saves a checkpoint at the end of every epoch + runs the callbacks added above
)

# Save/export the trained model.
# Tip: for integer quantization you simply have to NOT specify the quantization_config
# parameter of the detector.export method. In this case it would be:
# detector.export(export_dir=export_dir, tflite_filename='model.tflite')
export_dir = "/content/..."  # save the tflite wherever you want
quant_config = QuantizationConfig.for_float16()  # or whatever quantization you want
detector.model = model  # inject our trained model into the object detector
detector.export(export_dir=export_dir, tflite_filename='model.tflite', quantization_config=quant_config)
```
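One detail of the `model.fit` call worth spelling out: `epochs` is the index of the last epoch to train, not the number of additional epochs, and `initial_epoch` is 0-indexed. Keras' `ModelCheckpoint` fills `{epoch}` in file names with a 1-indexed value, so a checkpoint named `ckpt-35` means 35 epochs finished, and `initial_epoch=35` resumes at the epoch Keras prints as "Epoch 36/50", leaving 15 epochs, not 50. The hypothetical helper below only reproduces that arithmetic in plain Python so it can be checked without TensorFlow; it is not part of Keras.

```python
def epochs_to_run(total_epochs, completed_epochs):
    """Return the 1-indexed epoch numbers Keras would display when resuming.

    model.fit(epochs=total_epochs, initial_epoch=completed_epochs) trains the
    0-indexed epochs completed_epochs .. total_epochs - 1, which Keras prints
    as "Epoch completed_epochs + 1/total_epochs" onward.
    """
    if not 0 <= completed_epochs <= total_epochs:
        raise ValueError("completed_epochs must be between 0 and total_epochs")
    return list(range(completed_epochs + 1, total_epochs + 1))


# Resuming from "ckpt-35" with epochs=50:
remaining = epochs_to_run(50, 35)
print(remaining[0], remaining[-1], len(remaining))
```

If you want, say, 50 more epochs on top of a checkpoint, `epochs` would have to be `completed_epochs + 50` rather than 50.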