Source code for vak.common.trainer

from __future__ import annotations

import pathlib

import pytorch_lightning as lightning


def get_default_train_callbacks(
    ckpt_root: str | pathlib.Path,
    ckpt_step: int,
    patience: int,
):
    """Return the default set of callbacks used when training a model.

    Parameters
    ----------
    ckpt_root : str, pathlib.Path
        Directory where checkpoint files are saved.
    ckpt_step : int
        Save a checkpoint every time this many training steps have run.
    patience : int
        Number of validation checks with no improvement in ``val_acc``
        after which training is stopped early.

    Returns
    -------
    callbacks : list
        Three ``pytorch_lightning`` callbacks: a periodic checkpoint,
        a best-validation-accuracy checkpoint, and early stopping.
    """
    # Periodic checkpoint: written every `ckpt_step` training steps,
    # and the most recent one is always kept (`save_last=True`).
    periodic_ckpt = lightning.callbacks.ModelCheckpoint(
        dirpath=ckpt_root,
        filename="checkpoint",
        every_n_train_steps=ckpt_step,
        save_last=True,
        verbose=True,
    )
    # Name the "last" checkpoint the same as the periodic one, and use
    # a ``.pt`` extension instead of lightning's default ``.ckpt``.
    periodic_ckpt.CHECKPOINT_NAME_LAST = "checkpoint"
    periodic_ckpt.FILE_EXTENSION = ".pt"

    # Best-model checkpoint: keeps only the single checkpoint with the
    # highest validation accuracy seen so far.
    best_val_acc_ckpt = lightning.callbacks.ModelCheckpoint(
        monitor="val_acc",
        dirpath=ckpt_root,
        save_top_k=1,
        mode="max",
        filename="max-val-acc-checkpoint",
        auto_insert_metric_name=False,
        verbose=True,
    )
    best_val_acc_ckpt.FILE_EXTENSION = ".pt"

    # Stop training when ``val_acc`` has not improved for `patience` checks.
    stop_early = lightning.callbacks.EarlyStopping(
        mode="max",
        monitor="val_acc",
        patience=patience,
        verbose=True,
    )

    return [periodic_ckpt, best_val_acc_ckpt, stop_early]
def get_default_trainer(
    max_steps: int,
    log_save_dir: str | pathlib.Path,
    val_step: int,
    default_callback_kwargs: dict | None = None,
    device: str = "cuda",
) -> lightning.Trainer:
    """Return an instance of ``lightning.Trainer``
    with a default set of callbacks.

    Used by ``vak.core`` functions.

    Parameters
    ----------
    max_steps : int
        Maximum number of training steps to run.
    log_save_dir : str, pathlib.Path
        Directory where the TensorBoard logger writes its output.
    val_step : int
        Interval (in training steps) at which validation runs.
    default_callback_kwargs : dict, optional
        If provided (and truthy), passed as keyword arguments to
        :func:`get_default_train_callbacks`; otherwise no callbacks
        are attached.
    device : str
        Either ``"cuda"`` (maps to the ``"gpu"`` accelerator) or any
        other value (maps to ``"auto"``). Default is ``"cuda"``.

    Returns
    -------
    trainer : lightning.Trainer
    """
    callbacks = (
        get_default_train_callbacks(**default_callback_kwargs)
        if default_callback_kwargs
        else None
    )

    # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691
    accelerator = "gpu" if device == "cuda" else "auto"

    logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)

    return lightning.Trainer(
        callbacks=callbacks,
        val_check_interval=val_step,
        max_steps=max_steps,
        accelerator=accelerator,
        logger=logger,
    )