# Source code for vak.transforms.defaults.get
"""Helper function that gets default transforms for a model."""
from __future__ import annotations
from ... import models
from . import frame_classification, parametric_umap
# [docs]
def get_default_transform(
    model_name: str,
    mode: str,
    transform_kwargs: dict,
):
    """Get default transforms for a model,
    according to its family and what mode
    the model is being used in.

    The model's family is looked up by name in
    ``models.registry.MODEL_FAMILY_FROM_NAME``, and the call is then
    delegated to the family-specific default-transform function.

    Parameters
    ----------
    model_name : str
        Name of model. Must be a key of
        ``models.registry.MODEL_FAMILY_FROM_NAME``.
    mode : str
        one of {'train', 'eval', 'predict'}. Determines set of transforms.
        Only forwarded for the ``"FrameClassificationModel"`` family;
        it is not passed on for ``"ParametricUMAPModel"``.
    transform_kwargs : dict
        Passed through unchanged to the family-specific
        default-transform function.

    Returns
    -------
    transform, target_transform : callable
        one or more vak transforms to be applied to inputs x and, during training, the target y.
        If more than one transform, they are combined into an instance of torchvision.transforms.Compose.
        Note that when mode is 'predict', the target transform is None.
        The value is whatever the family-specific function returns.

    Raises
    ------
    ValueError
        If ``model_name`` is not found in the model registry, or if the
        model's family has no default transforms defined here.
    """
    try:
        model_family = models.registry.MODEL_FAMILY_FROM_NAME[model_name]
    except KeyError as e:
        raise ValueError(
            f"No model family found for the model name specified: {model_name}"
        ) from e

    if model_family == "FrameClassificationModel":
        return frame_classification.get_default_frame_classification_transform(
            mode, transform_kwargs
        )

    if model_family == "ParametricUMAPModel":
        # The parametric UMAP defaults do not depend on ``mode``.
        return parametric_umap.get_default_parametric_umap_transform(
            transform_kwargs
        )

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unrelated error where the result is used.
    raise ValueError(
        f"No default transforms are defined for model family {model_family!r} "
        f"(model name: {model_name!r})"
    )