
How to test a trained model saved in .pth.tar files?

I am working with CORnet-Z and I am building a separate test file.

The model appears to be saved as .pth.tar files by this part of the training script:

if FLAGS.output_path is not None:
    records.append(results)
    if len(results) > 1:
        pickle.dump(records, open(os.path.join(FLAGS.output_path, 'results.pkl'),
                                  'wb'))

    ckpt_data = {}
    ckpt_data['flags'] = FLAGS.__dict__.copy()
    ckpt_data['epoch'] = epoch
    ckpt_data['state_dict'] = model.state_dict()
    ckpt_data['optimizer'] = trainer.optimizer.state_dict()

    if save_model_secs is not None:
        if time.time() - recent_time > save_model_secs:
            torch.save(ckpt_data, os.path.join(FLAGS.output_path,
                                               'latest_checkpoint.pth.tar'))
            recent_time = time.time()
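A .pth.tar checkpoint written this way is simply the ckpt_data dictionary serialized by torch.save; the .tar suffix does not change how torch.save writes the file. The minimal sketch below uses a toy nn.Linear model and made-up flag values as stand-ins for CORnet-Z and the real FLAGS (both are assumptions for illustration). It round-trips a dictionary with the same four keys and inspects what torch.load returns.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-in model; CORnet-Z itself is not needed to show the file format.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Same structure as ckpt_data in the question; the flag values are hypothetical.
ckpt_data = {
    'flags': {'lr': 0.1, 'epochs': 1},
    'epoch': 0,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

path = os.path.join(tempfile.mkdtemp(), 'latest_checkpoint.pth.tar')
torch.save(ckpt_data, path)

# torch.load gives back the same dictionary; map_location='cpu' avoids
# needing a GPU when the checkpoint was written from CUDA tensors.
loaded = torch.load(path, map_location='cpu')
print(sorted(loaded.keys()))
print(loaded['epoch'])
```

Printing the keys of a real checkpoint the same way is a quick check of what it contains before writing any loading code.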

What would be the best approach to load this model and run evaluation and testing?


Answer

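import os

import torch


def load_checkpoint(checkpoint, model, optimizer=None):
    """Load a checkpoint saved by the training script into model (and optimizer)."""
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File does not exist: {}".format(checkpoint))

    # map_location='cpu' lets a GPU-trained checkpoint load on any machine
    checkpoint = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

    # The training script stores the optimizer state under the 'optimizer' key
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return checkpoint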
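If the checkpoint was written from a model wrapped in nn.DataParallel, the parameter names in state_dict start with 'module.' and will not match an unwrapped model. CORnet's training setup may do this when GPUs are used, so print a few keys of checkpoint['state_dict'] to check. The sketch below is one possible approach, not part of CORnet. strip_module_prefix, load_for_eval and evaluate are hypothetical helper names. evaluate assumes a classifier whose outputs are class logits and whose loader yields (inputs, targets) pairs. A toy nn.Linear model and random data stand in for CORnet-Z and a real test set.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel adds to parameter names."""
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def load_for_eval(path, model, device='cpu'):
    """Load weights from a checkpoint dict and switch the model to eval mode."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(strip_module_prefix(ckpt['state_dict']))
    model.to(device)
    model.eval()  # dropout off, BatchNorm uses its running statistics
    return ckpt


@torch.no_grad()  # no autograd bookkeeping while testing
def evaluate(model, loader, device='cpu'):
    """Return top-1 accuracy of a classifier over a DataLoader."""
    correct, total = 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    return correct / total


# Demo: save from a DataParallel wrapper, then load into a plain model.
torch.manual_seed(0)
trained = nn.DataParallel(nn.Linear(4, 2))
path = os.path.join(tempfile.mkdtemp(), 'latest_checkpoint.pth.tar')
torch.save({'epoch': 3, 'state_dict': trained.state_dict()}, path)

fresh = nn.Linear(4, 2)
ckpt = load_for_eval(path, fresh)

data = TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,)))
acc = evaluate(fresh, DataLoader(data, batch_size=4))
print(ckpt['epoch'], fresh.training, 0.0 <= acc <= 1.0)
```

Stripping the prefix keeps the test script independent of how many GPUs were used during training. The alternative is to wrap the fresh model in nn.DataParallel before loading, which also works but ties evaluation to the wrapper.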

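To test a model, you only need to load the state dictionary of the trained model, then call model.eval() so that layers such as dropout and batch normalization behave as they do at inference time, and run the forward passes inside torch.no_grad() to skip gradient tracking. The optimizer state is needed only if you want to resume training from that point, and if you also use a learning-rate scheduler, you need to save and load its state as well.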
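For resuming training rather than testing, here is a minimal sketch of saving and restoring a scheduler alongside the model and optimizer. The toy model, SGD settings and StepLR scheduler are chosen only for illustration. The 'scheduler' key is hypothetical: the question's ckpt_data does not store one, so the training script would need to add it. make_training_objects and train_one_epoch are also hypothetical names.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_training_objects():
    """Hypothetical setup: toy model, SGD optimizer and StepLR scheduler."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    return model, optimizer, scheduler


def train_one_epoch(model, optimizer):
    """One toy optimization step on random data."""
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# First run: train for 3 epochs, then save model, optimizer and scheduler state.
model, optimizer, scheduler = make_training_objects()
for epoch in range(3):
    train_one_epoch(model, optimizer)
    scheduler.step()  # after optimizer.step(), once per epoch

path = os.path.join(tempfile.mkdtemp(), 'resume.pth.tar')
torch.save({
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),  # hypothetical extra key
}, path)

# Second run: rebuild the same objects, then restore all three states.
model2, optimizer2, scheduler2 = make_training_objects()
ckpt = torch.load(path, map_location='cpu')
model2.load_state_dict(ckpt['state_dict'])
optimizer2.load_state_dict(ckpt['optimizer'])
scheduler2.load_state_dict(ckpt['scheduler'])
start_epoch = ckpt['epoch'] + 1  # continue with the next epoch

print(start_epoch, scheduler2.get_last_lr(), optimizer2.param_groups[0]['lr'])
```

Without the scheduler state, a freshly built StepLR would restart its step count, and the learning rate would follow the schedule from the beginning instead of continuing where training stopped.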
