# Source code for vak.eval.eval_

"""High-level function that evaluates trained models."""
from __future__ import annotations

import logging
import pathlib

from .. import models
from ..common import validators
from .frame_classification import eval_frame_classification_model
from .parametric_umap import eval_parametric_umap_model

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


[docs] def eval( model_name: str, model_config: dict, dataset_path: str | pathlib.Path, checkpoint_path: str | pathlib.Path, output_dir: str | pathlib.Path, num_workers: int, labelmap_path: str | pathlib.Path | None = None, batch_size: int | None = None, transform_params: dict | None = None, dataset_params: dict | None = None, split: str = "test", spect_scaler_path: str | pathlib.Path = None, post_tfm_kwargs: dict | None = None, device: str | None = None, ) -> None: """Evaluate a trained model. Parameters ---------- model_name : str Model name, must be one of vak.models.registry.MODEL_NAMES. model_config : dict Model configuration in a ``dict``, as loaded from a .toml file, and used by the model method ``from_config``. dataset_path : str, pathlib.Path Path to dataset, e.g., a csv file generated by running ``vak prep``. checkpoint_path : str, pathlib.Path path to directory with checkpoint files saved by Torch, to reload model output_dir : str, pathlib.Path Path to location where .csv files with evaluation metrics should be saved. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. labelmap_path : str, pathlib.Path, optional Path to 'labelmap.json' file. Optional, default is None. batch_size : int, optional. Number of samples per batch fed into model. Optional, default is None. transform_params: dict, optional Parameters for data transform. Passed as keyword arguments. Optional, default is None. dataset_params: dict, optional Parameters for dataset. Passed as keyword arguments. Optional, default is None. split : str split of dataset on which model should be evaluated. One of {'train', 'val', 'test'}. Default is 'test'. spect_scaler_path : str, pathlib.Path path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. Default is None. post_tfm_kwargs : dict Keyword arguments to post-processing transform. 
If None, then no additional clean-up is applied when transforming labeled timebins to segments, the default behavior. The transform used is ``vak.transforms.frame_labels.PostProcess`. Valid keyword argument names are 'majority_vote' and 'min_segment_dur', and should be appropriate values for those arguments: Boolean for ``majority_vote``, a float value for ``min_segment_dur``. See the docstring of the transform for more details on these arguments and how they work. device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. Notes ----- Note that unlike ``core.predict``, this function can modify ``labelmap`` so that metrics like edit distance are correctly computed, by converting any string labels in ``labelmap`` with multiple characters to (mock) single-character labels, with ``vak.labels.multi_char_labels_to_single_char``. """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( (checkpoint_path, labelmap_path, spect_scaler_path), ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: # because `spect_scaler_path` is optional if not validators.is_a_file(path): raise FileNotFoundError( f"value for ``{path_name}`` not recognized as a file: {path}" ) dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) try: model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name] except KeyError as e: raise ValueError( f"No model family found for the model name specified: {model_name}" ) from e if model_family == "FrameClassificationModel": eval_frame_classification_model( model_name=model_name, model_config=model_config, dataset_path=dataset_path, checkpoint_path=checkpoint_path, labelmap_path=labelmap_path, output_dir=output_dir, num_workers=num_workers, 
transform_params=transform_params, dataset_params=dataset_params, split=split, spect_scaler_path=spect_scaler_path, device=device, post_tfm_kwargs=post_tfm_kwargs, ) elif model_family == "ParametricUMAPModel": eval_parametric_umap_model( model_name=model_name, model_config=model_config, dataset_path=dataset_path, checkpoint_path=checkpoint_path, output_dir=output_dir, batch_size=batch_size, num_workers=num_workers, transform_params=transform_params, dataset_params=dataset_params, split=split, device=device, ) else: raise ValueError(f"Model family not recognized: {model_family}")