chatter.trainer#

Classes

Trainer(config)

Trainer class for the unified variational autoencoder implementation.

class chatter.trainer.Trainer(config)[source]#

Trainer class for the unified variational autoencoder implementation.

This class manages model initialization, training, evaluation, and feature extraction for the shared Encoder, which can be configured as either a convolutional or vector-based VAE. It is designed to operate on a single device (CPU, CUDA GPU, or Apple MPS).
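The device-selection order among the documented targets (CUDA GPU, Apple MPS, CPU fallback) is not specified here; a minimal sketch of one plausible ordering, with availability flags passed in explicitly so it stays independent of the actual PyTorch probing calls:

```python
def pick_device(cuda_ok: bool, mps_ok: bool) -> str:
    """Hypothetical device-selection logic covering the documented
    targets (CPU, CUDA GPU, Apple MPS). The priority order shown here
    is an assumption, not confirmed by this reference."""
    if cuda_ok:
        return "cuda"   # prefer a CUDA GPU when one is available
    if mps_ok:
        return "mps"    # otherwise fall back to Apple MPS
    return "cpu"        # CPU as the final fallback

device_name = pick_device(cuda_ok=False, mps_ok=True)
```

In real code the flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the string would be wrapped in `torch.device(...)`.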

config#

Configuration dictionary containing model and training parameters.

Type:

dict

device#

Computation device used for training and inference.

Type:

torch.device

ae_type#

Type of autoencoder architecture ('convolutional' or 'vector').

Type:

str

ae_model#

Unified variational autoencoder model encapsulating encoder and decoder components.

Type:

Encoder

__init__(config)[source]#

Initialize the Trainer with a configuration dictionary.

Parameters:

config (dict) – Configuration dictionary containing model and training parameters.
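The exact keys expected in the configuration dictionary are not enumerated in this reference; the sketch below is a hypothetical example (every key name is an assumption) apart from `ae_type`, whose two values are documented above:

```python
# Hypothetical Trainer configuration. Only 'ae_type' and its two values
# ('convolutional' / 'vector') are documented; the remaining keys are
# illustrative assumptions.
config = {
    "ae_type": "convolutional",  # or "vector"
    "latent_dim": 32,            # assumed latent dimensionality
    "batch_size": 64,            # assumed DataLoader batch size
    "epochs": 50,                # assumed number of training epochs
    "learning_rate": 1e-3,       # assumed optimizer learning rate
}
```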

classmethod from_trained(config, model_dir)[source]#

Create a Trainer instance and load a pre-trained model.

This class method instantiates a new Trainer with the provided configuration and immediately loads model weights from the specified path. It enables direct use of methods such as 'extract_and_save_features' or 'plot_reconstructions' without retraining.

Parameters:
  • config (dict) – Configuration dictionary for the model and training.

  • model_dir (str or Path) – Path to the saved model directory.

Returns:

An instance of the Trainer class with the model weights loaded.

Return type:

Trainer
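The instantiate-then-load pattern behind `from_trained` can be sketched with a stand-in class (the real Trainer is not reproduced here; `StandInTrainer` and its attributes are illustrative only):

```python
class StandInTrainer:
    """Minimal stand-in showing the factory pattern used by
    Trainer.from_trained: construct first, then load weights."""

    def __init__(self, config):
        self.config = config
        self.weights_loaded = False

    def load_ae(self, model_dir):
        # In the real Trainer this reads saved weights from model_dir
        # and puts the model into evaluation mode.
        self.weights_loaded = True

    @classmethod
    def from_trained(cls, config, model_dir):
        trainer = cls(config)        # instantiate with the configuration
        trainer.load_ae(model_dir)   # immediately load pre-trained weights
        return trainer

trainer = StandInTrainer.from_trained({"ae_type": "vector"}, "models/vae")
```

The payoff of the pattern is that feature extraction or plotting can be called right away on the returned instance, with no training step in between.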

train_ae(unit_df, h5_path, model_dir, subset=None)[source]#

Train the variational autoencoder using an HDF5 dataset.

This method creates a SpectrogramDataset from an HDF5 file, constructs a DataLoader, and runs a standard training loop for a configured number of epochs. It optionally trains on a random subset of units, and saves the trained model and loss history to disk.

Parameters:
  • unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column 'h5_index' referring to indices in the HDF5 'spectrograms' dataset.

  • h5_path (str or Path) – Path to the HDF5 file containing spectrograms.

  • model_dir (str or Path) – Directory in which to save the trained model and loss history CSV.

  • subset (float, optional) – Proportion of units to use for training. A float in the open interval (0, 1) selects a random subset of that proportion; if None or outside this range, the full dataset is used. The default is None.

Returns:

This method trains a model and writes the results to disk but does not return a value.

Return type:

None
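The documented `subset` semantics (a float in (0, 1) selects a random fraction of units; anything else falls back to the full dataset) can be sketched as a small helper. `select_subset` is a hypothetical name, not part of the library:

```python
import random

def select_subset(indices, subset=None, seed=0):
    """Mirror the documented subset semantics: a float in the open
    interval (0, 1) picks a random fraction of units; None or an
    out-of-range value returns the full dataset."""
    indices = list(indices)
    if subset is None or not (0 < subset < 1):
        return indices
    rng = random.Random(seed)          # seeded here only for reproducibility
    k = max(1, int(len(indices) * subset))
    return rng.sample(indices, k)

picked = select_subset(range(100), subset=0.25)  # 25 random unit indices
```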

load_ae(model_dir)[source]#

Load pre-trained weights into the variational autoencoder model.

Parameters:

model_dir (str or Path) – Path to the saved model directory.

Returns:

This method loads model weights and sets the model to evaluation mode. It prints status messages and does not return a value.

Return type:

None

plot_reconstructions(unit_df, h5_path, num_examples=8)[source]#

Plot a side-by-side comparison of original and reconstructed spectrograms.

This method samples a set of unit spectrograms from the HDF5 dataset, passes them through the autoencoder, and visualizes the original and reconstructed spectrograms for qualitative inspection of model performance.

Parameters:
  • unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column 'h5_index' referring to indices in the HDF5 'spectrograms' dataset.

  • h5_path (str or Path) – Path to the HDF5 file containing spectrograms.

  • num_examples (int, optional) – Number of examples to plot. If the dataset contains fewer than 'num_examples' units, all available units are plotted. The default is 8.

Returns:

This method displays a matplotlib figure and does not return a value.

Return type:

None

extract_and_save_features(unit_df, h5_path, model_dir, output_csv_path)[source]#

Extract latent features for all units using the HDF5 file and save them.

This method loads a trained autoencoder model, iterates through all spectrograms in the HDF5 dataset, encodes them into latent features, and writes a combined DataFrame containing both metadata and latent features to CSV.

Parameters:
  • unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column 'h5_index' referring to indices in the HDF5 'spectrograms' dataset.

  • h5_path (str or Path) – Path to the HDF5 file containing spectrograms.

  • model_dir (str or Path) – Path to the saved model directory.

  • output_csv_path (str or Path) – Path to the CSV file in which to store metadata and latent features.

Returns:

DataFrame containing metadata and extracted features for all units, or None if the model could not be loaded successfully.

Return type:

pd.DataFrame or None
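The shape of the combined output (metadata columns followed by per-unit feature columns, written as CSV) can be sketched with the standard library. The `feat_{i}` column names here are an assumption; this reference only documents the `cv_feat_{i}` naming for the vision-model variant:

```python
import csv
import io

latent_dim = 4  # assumed latent dimensionality
feature_cols = [f"feat_{i}" for i in range(latent_dim)]  # hypothetical names

# One row per unit: metadata ('h5_index') plus latent feature columns.
rows = [
    {"h5_index": 0, **{c: i for i, c in enumerate(feature_cols)}},
    {"h5_index": 1, **{c: i + 10 for i, c in enumerate(feature_cols)}},
]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["h5_index"] + feature_cols)
writer.writeheader()
writer.writerows(rows)
csv_text = buf.getvalue()
```

In the real method the feature values come from encoding each spectrogram with the trained autoencoder, and the output is written to `output_csv_path`.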

extract_and_save_comp_viz_features(unit_df, h5_path, output_csv_path, checkpoint=None)[source]#

Extract features using a Hugging Face computer vision model for all units and save them.

This method loads a pretrained computer vision model, iterates through all spectrograms in the HDF5 dataset, encodes them into fixed-length embeddings, and writes a combined DataFrame containing both metadata and features to CSV. The features are stored in columns named 'cv_feat_{i}'.

Parameters:
  • unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column 'h5_index' referring to indices in the HDF5 'spectrograms' dataset.

  • h5_path (str or Path) – Path to the HDF5 file containing spectrograms.

  • output_csv_path (str or Path) – Path to the CSV file in which to store metadata and features.

  • checkpoint (str, optional) – Hugging Face model checkpoint name. If None, a default checkpoint from self.config['vision_checkpoint'] is used, or 'facebook/dinov3-vitb16-pretrain-lvd1689m' if that key is not present. The default is None.

Returns:

DataFrame containing metadata and features for all units, or None if the model could not be loaded successfully.

Return type:

pd.DataFrame or None
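The documented checkpoint fallback (explicit argument, then the 'vision_checkpoint' config key, then the stated default) can be sketched as follows; `resolve_checkpoint` is an illustrative helper, not the library's actual implementation:

```python
def resolve_checkpoint(config, checkpoint=None):
    """Sketch of the documented fallback order: an explicit checkpoint
    argument wins, then config['vision_checkpoint'], then the default
    checkpoint named in this reference."""
    if checkpoint is not None:
        return checkpoint
    return config.get(
        "vision_checkpoint",
        "facebook/dinov3-vitb16-pretrain-lvd1689m",
    )

name = resolve_checkpoint({"vision_checkpoint": "my-org/my-vit"})
```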