Source code for vak.eval.frame_classification

"""Function that evaluates trained models in the frame classification family."""
from __future__ import annotations

import json
import logging
import pathlib
from collections import OrderedDict
from datetime import datetime

import joblib
import pandas as pd
import pytorch_lightning as lightning
import torch.utils.data

from .. import datasets, models, transforms
from ..common import validators
from ..datasets.frame_classification import FramesDataset

logger = logging.getLogger(__name__)


def eval_frame_classification_model(
    model_name: str,
    model_config: dict,
    dataset_path: str | pathlib.Path,
    checkpoint_path: str | pathlib.Path,
    labelmap_path: str | pathlib.Path,
    output_dir: str | pathlib.Path,
    num_workers: int,
    transform_params: dict | None = None,
    dataset_params: dict | None = None,
    split: str = "test",
    spect_scaler_path: str | pathlib.Path | None = None,
    post_tfm_kwargs: dict | None = None,
    device: str | None = None,
) -> None:
    """Evaluate a trained model.

    Parameters
    ----------
    model_name : str
        Model name, must be one of vak.models.registry.MODEL_NAMES.
    model_config : dict
        Model configuration in a ``dict``, as loaded from a .toml file,
        and used by the model method ``from_config``.
    dataset_path : str, pathlib.Path
        Path to dataset, e.g., a directory generated by running ``vak prep``.
    checkpoint_path : str, pathlib.Path
        Path to a checkpoint file saved by torch, used to reload the model.
    labelmap_path : str, pathlib.Path
        Path to 'labelmap.json' file.
    output_dir : str, pathlib.Path
        Path to location where .csv files with evaluation metrics should be saved.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader.
    transform_params : dict, optional
        Parameters for data transform. Passed as keyword arguments.
        Optional, default is None.
    dataset_params : dict, optional
        Parameters for dataset. Passed as keyword arguments.
        Optional, default is None.
    split : str
        Split of dataset on which model should be evaluated.
        One of {'train', 'val', 'test'}. Default is 'test'.
    spect_scaler_path : str, pathlib.Path
        Path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided,
        will give incorrect results. Default is None.
    post_tfm_kwargs : dict
        Keyword arguments to post-processing transform.
        If None, then no additional clean-up is applied
        when transforming labeled timebins to segments,
        the default behavior.
        The transform used is ``vak.transforms.frame_labels.PostProcess``.
        Valid keyword argument names are 'majority_vote' and 'min_segment_dur',
        and should be appropriate values for those arguments:
        Boolean for ``majority_vote``, a float value for ``min_segment_dur``.
        See the docstring of the transform for more details
        on these arguments and how they work.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.

    Notes
    -----
    Note that unlike ``core.predict``, this function can modify ``labelmap``
    so that metrics like edit distance are correctly computed,
    by converting any string labels in ``labelmap`` with multiple characters
    to (mock) single-character labels,
    with ``vak.labels.multi_char_labels_to_single_char``.
    """
""" # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( (checkpoint_path, labelmap_path, spect_scaler_path), ("checkpoint_path", "labelmap_path", "spect_scaler_path"), ): if path is not None: # because `spect_scaler_path` is optional if not validators.is_a_file(path): raise FileNotFoundError( f"value for ``{path_name}`` not recognized as a file: {path}" ) dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): raise NotADirectoryError( f"`dataset_path` not found or not recognized as a directory: {dataset_path}" ) # we unpack `frame_dur` to log it, regardless of whether we use it with post_tfm below metadata = datasets.frame_classification.Metadata.from_dataset_path( dataset_path ) frame_dur = metadata.frame_dur logger.info( f"Duration of a frame in dataset, in seconds: {frame_dur}", ) if not validators.is_a_directory(output_dir): raise NotADirectoryError( f"value for ``output_dir`` not recognized as a directory: {output_dir}" ) # ---- get time for .csv file -------------------------------------------------------------------------------------- timenow = datetime.now().strftime("%y%m%d_%H%M%S") # ---------------- load data for evaluation ------------------------------------------------------------------------ if spect_scaler_path: logger.info(f"loading spect scaler from path: {spect_scaler_path}") spect_standardizer = joblib.load(spect_scaler_path) else: logger.info("not using a spect scaler") spect_standardizer = None logger.info(f"loading labelmap from path: {labelmap_path}") with labelmap_path.open("r") as f: labelmap = json.load(f) if transform_params is None: transform_params = {} transform_params.update({"spect_standardizer": spect_standardizer}) item_transform = transforms.defaults.get_default_transform( model_name, "eval", transform_params ) if dataset_params is None: dataset_params = {} val_dataset = FramesDataset.from_dataset_path( dataset_path=dataset_path, split=split, item_transform=item_transform, **dataset_params, ) val_loader = torch.utils.data.DataLoader( dataset=val_dataset, shuffle=False, # batch size 1 because each spectrogram reshaped into a batch of windows batch_size=1, num_workers=num_workers, ) # ---------------- do the actual evaluating ------------------------------------------------------------------------ input_shape = val_dataset.shape # if dataset returns spectrogram reshaped into windows, # throw out the window dimension; just want to tell network (channels, height, width) shape if len(input_shape) == 4: input_shape = input_shape[1:] if post_tfm_kwargs: post_tfm = transforms.frame_labels.PostProcess( timebin_dur=frame_dur, **post_tfm_kwargs, ) else: post_tfm = None model = models.get( model_name, model_config, num_classes=len(labelmap), input_shape=input_shape, labelmap=labelmap, post_tfm=post_tfm, ) logger.info(f"running evaluation for model: {model_name}") model.load_state_dict_from_path(checkpoint_path) # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 if device == "cuda": accelerator = "gpu" else: accelerator = "auto" trainer_logger = lightning.loggers.TensorBoardLogger(save_dir=output_dir) trainer = lightning.Trainer(accelerator=accelerator, logger=trainer_logger) # TODO: check for hasattr(model, test_step) and if so run test # below, [0] because validate returns list of dicts, length of no. 
    # below, [0] because validate returns list of dicts, length of no. of val loaders
    metric_vals = trainer.validate(model, dataloaders=val_loader)[0]
    metric_vals = {f"avg_{k}": v for k, v in metric_vals.items()}
    for metric_name, metric_val in metric_vals.items():
        if metric_name.startswith("avg_"):
            logger.info(f"{metric_name}: {metric_val:0.5f}")

    # create a DataFrame with just one row which we will save as a csv;
    # the idea is to be able to concatenate csvs from multiple runs of eval
    row = OrderedDict(
        [
            ("model_name", model_name),
            ("checkpoint_path", checkpoint_path),
            ("labelmap_path", labelmap_path),
            ("spect_scaler_path", spect_scaler_path),
            ("dataset_path", dataset_path),
        ]
    )
    # TODO: is this still necessary after switching to Lightning? Stop saying "average"?
    # order metrics by name to be extra sure they will be consistent across runs
    row.update(
        sorted(
            [(k, v) for k, v in metric_vals.items() if k.startswith("avg_")]
        )
    )

    # pass index into dataframe, needed when using all scalar values (a single row)
    # throw away index below when saving to avoid extra column
    eval_df = pd.DataFrame(row, index=[0])
    eval_csv_path = output_dir.joinpath(f"eval_{model_name}_{timenow}.csv")
    logger.info(f"saving csv with evaluation metrics at: {eval_csv_path}")
    eval_df.to_csv(
        eval_csv_path, index=False
    )  # index is False to avoid having "Unnamed: 0" column when loading
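
The snippet below is a minimal usage sketch, not part of the module above. Every path, the contents of ``model_config``, and the values passed for ``post_tfm_kwargs`` are hypothetical placeholders: substitute the dataset directory produced by ``vak prep`` and the checkpoint, labelmap, and SpectScaler files saved by ``vak train`` for your own project. Note that ``output_dir`` must already exist, since the function only validates it.

import pathlib

import pandas as pd

from vak.eval.frame_classification import eval_frame_classification_model

output_dir = pathlib.Path("results/eval_demo")  # hypothetical path; directory must already exist

eval_frame_classification_model(
    model_name="TweetyNet",
    # placeholder config; the exact keys expected depend on the model's ``from_config`` method
    model_config={"network": {}, "optimizer": {"lr": 0.001}, "loss": {}, "metrics": {}},
    dataset_path=pathlib.Path("prep/frame_classification_dataset"),  # directory made by ``vak prep``
    checkpoint_path=pathlib.Path("results/train_demo/checkpoints/checkpoint.pt"),
    labelmap_path=pathlib.Path("results/train_demo/labelmap.json"),
    output_dir=output_dir,
    num_workers=2,
    split="test",
    spect_scaler_path=pathlib.Path("results/train_demo/StandardizeSpect"),
    # optional clean-up applied by vak.transforms.frame_labels.PostProcess
    post_tfm_kwargs={"majority_vote": True, "min_segment_dur": 0.02},
    device="cuda",
)

# each run saves a single-row csv named eval_{model_name}_{timestamp}.csv in output_dir,
# so metrics from several runs can be concatenated for comparison
all_evals = pd.concat(
    [pd.read_csv(csv_path) for csv_path in sorted(output_dir.glob("eval_*.csv"))]
)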