Save best checkpoint with fixed name

Hi,
I am training an ASR model (Tokeniser / LM / ASR). I am trying several changes in the LM to see how they impact the ASR. Is there any way to save the best LM checkpoint with a fixed name, so the ASR can load it without needing to change the ASR's YAML every time?

That is, after finishing the LM training, how can I save the best checkpoint in LM/save/<lm_model_version>/best_ckpt
so that the pretrainer in the ASR YAML can always use pretrained_lm_file=../LM/save/<lm_model_version>/best_ckpt/model_ckpt?

Thank you so much
Cheers

That’s a question for our checkpointer master @Gastron

I found a workaround. It is not pretty, but it does what I need.
In the LM experiment, I override the on_evaluate_start method so that, after loading the best checkpoint, I save the path to the model in a file.

def on_evaluate_start(self, max_key=None, min_key=None):
    # Recover the best checkpoint (according to max_key/min_key) before evaluation.
    if self.checkpointer is not None:
        checkpoint = self.checkpointer.recover_if_possible(
            max_key=max_key,
            min_key=min_key,
            device=torch.device(self.device),
        )
        # Write the path of the best model parameter file where the ASR run can read it.
        best_paramfile = str(checkpoint.paramfiles["model"])
        with open(f"{self.hparams.output_folder}/best_ckpt", "w") as best_file:
            best_file.write(best_paramfile)

Then I read the path from that file in a bash script and run the ASR experiment, passing the path to the lm_pretrained parameter.
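Roughly, the glue step looks like this (sketched in Python instead of bash, purely for illustration; the script name, the hparams file, and the output folder are placeholders for my actual setup):

import subprocess

# The LM recipe wrote the best parameter-file path into <output_folder>/best_ckpt.
with open("LM/results/lm_v1/best_ckpt") as best_file:  # placeholder output_folder
    best_lm_path = best_file.read().strip()

# Launch the ASR recipe and override the pretrained-LM path on the command line.
subprocess.run(
    [
        "python",
        "train_asr.py",            # placeholder recipe script
        "hparams/train_asr.yaml",  # placeholder hparams file
        f"--lm_pretrained={best_lm_path}",
    ],
    check=True,
)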

You can save the full checkpoint with a fixed name with checkpointer.save_checkpoint(name=...). Note that this will still change the name to CKPT+<name>. Or you can use a checkpointer to find the best checkpoint automatically:

checkpointer = Checkpointer(path_to_savedir)
ckpt = checkpointer.find_checkpoint(min_key=...)
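For example (the save directory and the metric key are placeholders, and the "model" key assumes the LM was registered under that name in the checkpointer):

from speechbrain.utils.checkpoints import Checkpointer

# Point a checkpointer at the LM recipe's save directory (placeholder path).
checkpointer = Checkpointer("../LM/save/lm_v1")

# Pick the checkpoint with the lowest recorded loss; the key must match the
# meta that was stored when the checkpoints were saved.
ckpt = checkpointer.find_checkpoint(min_key="loss")

# Path of the parameter file for the recoverable registered as "model".
print(ckpt.paramfiles["model"])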

@Gastron thank you for your response. In the end, I preferred to save the path to the best model instead of the whole checkpoint. This is because I need the path to the best LM checkpoint so I can pass it dynamically to the lm_pretrained parameter when running the ASR experiment. If I saved the whole best checkpoint at the end of the LM training, I would end up with two copies of the best checkpoint, which is not efficient in terms of storage.

@Gastron, @titouan.parcollet, thank you again for your help.
My final approach was to save the best checkpoint at the start of the evaluation.

def on_evaluate_start(self, max_key=None, min_key=None):
    logging.info("Starting Evaluation...")
    if self.checkpointer is not None:
        # Load the best checkpoint according to max_key/min_key.
        self.checkpointer.recover_if_possible(
            max_key=max_key,
            min_key=min_key,
            device=torch.device(self.device),
        )
        # Re-save it under a fixed name, so it ends up in CKPT+best_ckpt.
        try:
            self.checkpointer.save_checkpoint(name="best_ckpt")
        except FileExistsError:
            logging.info("Best checkpoint saved in previous run... (ignoring)")

And, in case I need to rerun the evaluation, I add a predicate in the on_fit_start() method so this fixed-name checkpoint is ignored when resuming:

def on_fit_start(self):
    """Gets called at the beginning of ``fit()``, on multiple processes
    if ``distributed_count > 0`` and backend is ddp.
    Default implementation compiles the jit modules, initializes
    optimizers, and loads the latest checkpoint to resume training.
    """
    # Run this *after* starting all processes since jit modules cannot be
    # pickled.
    self._compile_jit()

    # Wrap modules with parallel backend after jit
    self._wrap_distributed()

    # Initialize optimizers after parameters are configured
    self.init_optimizers()

    # Load latest checkpoint to resume training if interrupted, skipping the
    # fixed-name best checkpoint saved by on_evaluate_start.
    if self.checkpointer is not None:
        self.checkpointer.recover_if_possible(
            ckpt_predicate=lambda ckpt: "best_ckpt" not in str(ckpt.paramfiles["model"]),
            device=torch.device(self.device),
        )

In this way, I can always call the ASR using:

pretrained_lm_file: ../LM/<save_path>/CKPT+best_ckpt/model.ckpt
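On the ASR side, the pretrainer then just collects and loads from that fixed path. A rough Python equivalent of what the YAML does (the "lm" loadable name, the dummy module, and the concrete paths are placeholders for whatever the ASR YAML actually defines):

import torch
from speechbrain.utils.parameter_transfer import Pretrainer

# Placeholder for the ASR recipe's LM module (normally built from the YAML).
lm_model = torch.nn.Module()

pretrainer = Pretrainer(
    collect_in="pretrained_models",  # placeholder collection folder
    loadables={"lm": lm_model},
    paths={"lm": "../LM/save/lm_v1/CKPT+best_ckpt/model.ckpt"},  # placeholder path
)
pretrainer.collect_files()
pretrainer.load_collected()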

Hope that this makes sense
