Source code for vak.nets.conv_encoder

from __future__ import annotations

import torch
import torch.nn as nn


class ConvEncoder(nn.Module):
    """Convolutional encoder, used by Parametric UMAP model."""
    def __init__(
        self,
        input_shape: tuple[int, int, int],
        conv1_filters: int = 32,
        conv2_filters: int = 64,
        conv_kernel_size: int = 3,
        conv_stride: int = 2,
        conv_padding: int = 1,
        n_features_linear: int = 256,
        n_components: int = 2,
    ):
        """Initialize a ConvEncoder instance.

        Parameters
        ----------
        input_shape : tuple
            Tuple with 3 elements corresponding to dimensions of
            spectrogram: (channels, frequency bins, time bins).
            I.e., we assume input is a spectrogram and treat it like
            an image, typically with one channel; the rows are
            frequency bins, and the columns are time bins.
        conv1_filters : int
            Number of filters in first convolutional layer.
            Default is 32.
        conv2_filters : int
            Number of filters in second convolutional layer.
            Default is 64.
        conv_kernel_size : int
            Size of kernels, i.e. filters, in convolutional layers.
            Default is 3.
        conv_stride : int
            Stride for convolutional layers. Default is 2.
        conv_padding : int
            Amount of padding for convolutional layers. Default is 1.
        n_features_linear : int
            Number of features in the linear layers. Default is 256.
        n_components : int
            Number of components of latent space that encoder maps to.
            Default is 2.
        """
        super().__init__()
        if len(input_shape) != 3:
            raise ValueError(
                "Expected input_shape with length 3, (channels, height, width), "
                f"but input shape was length {len(input_shape)}. "
                f"Input shape was: {input_shape}"
            )
        self.input_shape = input_shape
        self.num_input_channels = input_shape[0]

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=self.num_input_channels,
                out_channels=conv1_filters,
                kernel_size=conv_kernel_size,
                stride=conv_stride,
                padding=conv_padding,
            ),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(
                in_channels=conv1_filters,
                out_channels=conv2_filters,
                kernel_size=conv_kernel_size,
                stride=conv_stride,
                padding=conv_padding,
            ),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
        )

        # Pass a mock input through the conv stack to find the size of its
        # flattened output, so the first linear layer can be sized without
        # computing the post-convolution shape analytically.
        mock_input = torch.rand((1, *input_shape))
        mock_conv_out = self.conv(mock_input)
        in_features = mock_conv_out.shape[-1]

        self.encoder = nn.Sequential(
            nn.Linear(in_features, n_features_linear),
            nn.ReLU(),
            nn.Linear(n_features_linear, n_features_linear),
            nn.ReLU(),
            nn.Linear(n_features_linear, n_components),
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.encoder(x)
        return x
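
For illustration, a minimal usage sketch (not part of the vak source): the spectrogram shape and batch size below are arbitrary example values, not defaults taken from vak.

    import torch

    from vak.nets.conv_encoder import ConvEncoder

    # Hypothetical input: 1-channel spectrograms with 128 frequency bins
    # and 128 time bins.
    encoder = ConvEncoder(input_shape=(1, 128, 128))

    batch = torch.rand(16, 1, 128, 128)  # (batch, channels, freq bins, time bins)
    embedding = encoder(batch)
    assert embedding.shape == (16, 2)  # n_components defaults to 2

Note that because `in_features` is inferred at construction time from a mock forward pass, each encoder instance is tied to the `input_shape` it was built with; inputs of a different height or width will not match the first linear layer.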