# Source code for vak.models.registry
"""Registry for models.
Makes it possible to register a model declared outside of ``vak``
with a decorator, so that the model can be used at runtime.
"""
from __future__ import annotations
import inspect
from typing import Any, Type
from .base import Model
# Registry of model families: maps the ``__name__`` of each registered
# family class to the class itself. Populated by :func:`model_family`.
MODEL_FAMILY_REGISTRY = {}
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not part of the module)
def model_family(family_class: Type) -> Type:
    """Decorator that adds a class to the registry of model families.

    Parameters
    ----------
    family_class : type
        The class to register. It must be a direct subclass of
        :class:`vak.models.base.Model`.

    Returns
    -------
    family_class : type
        The same class that was passed in, unchanged, so that this
        function can be used as a class decorator.

    Raises
    ------
    TypeError
        If ``family_class`` is not among ``Model.__subclasses__()``.
    ValueError
        If a family with the same ``__name__`` is already registered.
    """
    # ``__subclasses__`` only lists *direct* subclasses, so a class that
    # merely inherits from an existing family is rejected here.
    model_subclasses = Model.__subclasses__()
    if family_class not in model_subclasses:
        raise TypeError(
            "The ``family_class`` provided to the `vak.models.model_family` decorator "
            "must be a subclass of `vak.models.base.Model`, "
            f"but the class specified is not: {family_class}. "
            f"Subclasses of `vak.models.base.Model` are: {model_subclasses}"
        )

    model_family_name = family_class.__name__
    if model_family_name in MODEL_FAMILY_REGISTRY:
        raise ValueError(
            f"Attempted to register a model family with the name '{model_family_name}', "
            f"but this name is already in the registry:\n{MODEL_FAMILY_REGISTRY}"
        )

    MODEL_FAMILY_REGISTRY[model_family_name] = family_class
    # need to return class after we register it or we replace it with None
    # when this function is used as a decorator
    return family_class
# Registry of models: maps the ``__name__`` of each registered model class
# to the class itself. Populated by :func:`register_model`.
MODEL_REGISTRY = {}
# [docs]  (Sphinx "viewcode" link marker left over from HTML extraction; not part of the module)
def register_model(model_class: Type) -> Type:
    """Decorator that registers a model in the model registry.

    This function is called by :func:`vak.models.decorator.model`,
    that creates a model class from a model definition.
    So you will not usually need to use this decorator directly,
    and should prefer to use :func:`vak.models.decorator.model` instead.

    Parameters
    ----------
    model_class : type
        The model class to register. Its immediate parent, taken as the
        second entry of its method resolution order, must be a class
        that is already in the model family registry.

    Returns
    -------
    model_class : type
        The same class that was passed in, unchanged, so that this
        function can be used as a class decorator.

    Raises
    ------
    TypeError
        If the parent class of ``model_class`` is not a registered
        model family.
    ValueError
        If a model with the same ``__name__`` is already registered.
    """
    model_family_classes = list(MODEL_FAMILY_REGISTRY.values())
    # Entry 0 of the MRO is the class itself; entry 1 is its first parent.
    model_parent_class = inspect.getmro(model_class)[1]
    if model_parent_class not in model_family_classes:
        raise TypeError(
            "The parent class of ``model_class`` passed to the ``model`` decorator "
            f"is not recognized as a model family. Class was: {model_class} and "
            f"parent is {model_parent_class}, as determined with "
            f"``inspect.getmro(model_class)[1]``. "
            f"Please specify a class that is a sub-class of a model family. "
            f"Valid model family classes are: {model_family_classes}"
        )

    model_name = model_class.__name__
    if model_name in MODEL_REGISTRY:
        # This is the *model* registry, so the message says "model"
        # and shows the registry contents, as ``model_family`` does.
        raise ValueError(
            f"Attempted to register a model with the name '{model_name}', "
            f"but this name is already in the registry:\n{MODEL_REGISTRY}"
        )

    MODEL_REGISTRY[model_name] = model_class
    # need to return class after we register it or we replace it with None
    # when this function is used as a decorator
    return model_class
def __getattr__(name: str) -> Any:
    """Compute dynamic module attributes from the current model registry.

    Two attribute names are supported:

    * ``MODEL_NAMES``: a list of the names of all registered models.
    * ``MODEL_FAMILY_FROM_NAME``: a dict mapping each registered model's
      name to the name of its parent class, i.e. the second entry of the
      model class's method resolution order.

    Both are rebuilt from ``MODEL_REGISTRY`` on every access.

    Raises
    ------
    AttributeError
        For any other attribute name.
    """
    if name == "MODEL_NAMES":
        return list(MODEL_REGISTRY)

    if name == "MODEL_FAMILY_FROM_NAME":
        return {
            registered_name: inspect.getmro(registered_class)[1].__name__
            for registered_name, registered_class in MODEL_REGISTRY.items()
        }

    raise AttributeError(
        f"Not an attribute of `vak.models.registry`: {name}"
    )