# Source code for vak.prep.sequence_dataset
"""Helper functions for datasets annotated as sequences."""
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import pandas as pd
from ..common import annotation
# [docs]
def where_unlabeled_segments(dataset_df: pd.DataFrame) -> npt.NDArray:
    """Find which annotated sequences in a dataset
    contain unlabeled segments.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        A dataframe representing a source dataset of audio signals
        or spectrograms, as returned by
        :func:`vak.prep.audio_dataset.prep_audio_dataset` or
        :func:`vak.prep.spectrogram_dataset.prep_spectrogram_dataset`.
        Its ``duration`` column is paired, row by row,
        with the annotations loaded from the dataframe.

    Returns
    -------
    where_unlabeled : numpy.ndarray
        Boolean vector with one element per annotated sequence.
        An element is True when the sequence at that index
        has segments that are unlabeled.
    """
    # One annotation per row, matched with that row's duration.
    annot_duration_pairs = zip(
        annotation.from_df(dataset_df),
        dataset_df.duration.values,
    )
    unlabeled_flags = [
        annotation.has_unlabeled(annot, duration)
        for annot, duration in annot_duration_pairs
    ]
    # Explicit cast so the result has Boolean dtype
    # even when there are no sequences at all.
    return np.array(unlabeled_flags).astype(bool)
# [docs]
def has_unlabeled_segments(dataset_df: pd.DataFrame) -> bool:
    r"""Check whether any sequence in an annotated dataset
    contains unlabeled segments.

    The result is used to decide whether an extra class must be added
    to the set of labels :math:`Y = \{y_1, y_2, \dots, y_n\}`,
    where the added class :math:`y_{n+1}`
    stands for the unlabeled "background" periods.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        A dataframe representing a source dataset of audio signals
        or spectrograms, as returned by
        :func:`vak.prep.audio_dataset.prep_audio_dataset` or
        :func:`vak.prep.spectrogram_dataset.prep_spectrogram_dataset`.

    Returns
    -------
    has_unlabeled : bool
        True if at least one annotation in the dataset
        has unlabeled periods, False otherwise.
    """
    where_unlabeled = where_unlabeled_segments(dataset_df)
    # NOTE: ``np.any`` returns a ``numpy.bool_``, not a built-in ``bool``,
    # so an identity check such as ``result is True`` evaluates to False.
    # Convert explicitly so callers always get a plain Python ``bool``.
    # (Unclear whether this depends on the numpy version.)
    return bool(np.any(where_unlabeled))