chatter.trainer#
Classes

- `Trainer(config)` – Trainer class for the unified variational autoencoder implementation.
- class chatter.trainer.Trainer(config)[source]#
Trainer class for the unified variational autoencoder implementation.
This class manages model initialization, training, evaluation, and feature extraction for the shared Encoder, which can be configured as either a convolutional or vector-based VAE. It is designed to operate on a single device (CPU, CUDA GPU, or Apple MPS).
- config#
Configuration dictionary containing model and training parameters.
- Type:
dict
- device#
Computation device used for training and inference.
- Type:
torch.device
- ae_type#
Type of autoencoder architecture (‘convolutional’ or ‘vector’).
- Type:
str
- ae_model#
Unified variational autoencoder model encapsulating encoder and decoder components.
- Type:
- __init__(config)[source]#
Initialize the Trainer with a configuration dictionary.
- Parameters:
config (dict) – Configuration dictionary containing model and training parameters.
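A minimal construction sketch. The configuration keys shown here are illustrative assumptions, not a definitive schema; consult your own configuration for the required entries.

```python
from chatter.trainer import Trainer

# Hypothetical configuration; exact keys depend on the model setup.
config = {
    "ae_type": "convolutional",   # or "vector"
    "latent_dim": 32,
    "batch_size": 64,
    "epochs": 50,
}

trainer = Trainer(config)
print(trainer.device)   # e.g. cuda, mps, or cpu, depending on availability
```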
- classmethod from_trained(config, model_dir)[source]#
Create a Trainer instance and load a pre-trained model.
This class method instantiates a new Trainer with the provided configuration and immediately loads model weights from the specified path. It enables direct use of methods such as ‘extract_and_save_features’ or ‘plot_reconstructions’ without retraining.
- Parameters:
config (dict) – Configuration dictionary for the model and training.
model_dir (str or Path) – Path to the saved model directory.
- Returns:
An instance of the Trainer class with the model weights loaded.
- Return type:
Trainer
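A usage sketch, assuming a previously saved model directory; the path and configuration keys below are illustrative.

```python
from chatter.trainer import Trainer

config = {"ae_type": "convolutional"}   # illustrative; reuse the training configuration
trainer = Trainer.from_trained(config, model_dir="models/vae_run_01")
# The returned Trainer is ready for feature extraction or reconstruction plots.
```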
- train_ae(unit_df, h5_path, model_dir, subset=None)[source]#
Train the variational autoencoder using an HDF5 dataset.
This method creates a SpectrogramDataset from an HDF5 file, constructs a DataLoader, and runs a standard training loop for a configured number of epochs. It optionally trains on a random subset of units, and saves the trained model and loss history to disk.
- Parameters:
unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column ‘h5_index’ referring to indices in the HDF5 ‘spectrograms’ dataset.
h5_path (str or Path) – Path to the HDF5 file containing spectrograms.
model_dir (str or Path) – Directory in which to save the trained model and loss history CSV.
subset (float, optional) – Proportion of units to use for training. Must be in the range (0, 1) if provided. If None or outside this range, the full dataset is used. The default is None.
- Returns:
This method trains a model and writes the results to disk but does not return a value.
- Return type:
None
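A training sketch under stated assumptions: `unit_df` holds one row per unit with an ‘h5_index’ column, and the HDF5 file contains a ‘spectrograms’ dataset. The file paths and configuration keys are illustrative.

```python
import pandas as pd
from chatter.trainer import Trainer

config = {"ae_type": "convolutional"}                 # illustrative keys
unit_df = pd.read_csv("features/unit_metadata.csv")   # must contain an 'h5_index' column

trainer = Trainer(config)
trainer.train_ae(
    unit_df,
    h5_path="features/units.h5",
    model_dir="models/vae_run_01",
    subset=0.5,   # optional: train on a random 50% of the units
)
# The trained model and a loss-history CSV are written to model_dir.
```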
- load_ae(model_dir)[source]#
Load pre-trained weights into the variational autoencoder model.
- Parameters:
model_dir (str or Path) – Path to the saved model directory.
- Returns:
This method loads model weights and sets the model to evaluation mode. It prints status messages and does not return a value.
- Return type:
None
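A short sketch of loading weights into an existing Trainer (the directory name is illustrative).

```python
from chatter.trainer import Trainer

config = {"ae_type": "convolutional"}   # illustrative
trainer = Trainer(config)
trainer.load_ae("models/vae_run_01")    # loads weights and sets the model to evaluation mode
```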
- plot_reconstructions(unit_df, h5_path, num_examples=8)[source]#
Plot a side-by-side comparison of original and reconstructed spectrograms.
This method samples a set of unit spectrograms from the HDF5 dataset, passes them through the autoencoder, and visualizes the original and reconstructed spectrograms for qualitative inspection of model performance.
- Parameters:
unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column ‘h5_index’ referring to indices in the HDF5 ‘spectrograms’ dataset.
h5_path (str or Path) – Path to the HDF5 file containing spectrograms.
num_examples (int, optional) – Number of examples to plot. If the dataset contains fewer than ‘num_examples’ units, all available units are plotted. The default is 8.
- Returns:
This method displays a matplotlib figure and does not return a value.
- Return type:
None
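A quick qualitative check after loading a trained model; `config` and `unit_df` are assumed to be set up as in the earlier sketches, and the paths are illustrative.

```python
# config and unit_df as in the construction and training sketches above.
trainer = Trainer.from_trained(config, model_dir="models/vae_run_01")
trainer.plot_reconstructions(unit_df, h5_path="features/units.h5", num_examples=8)
```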
- extract_and_save_features(unit_df, h5_path, model_dir, output_csv_path)[source]#
Extract latent features for all units using the HDF5 file and save them.
This method loads a trained autoencoder model, iterates through all spectrograms in the HDF5 dataset, encodes them into latent features, and writes a combined DataFrame containing both metadata and latent features to CSV.
- Parameters:
unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column ‘h5_index’ referring to indices in the HDF5 ‘spectrograms’ dataset.
h5_path (str or Path) – Path to the HDF5 file containing spectrograms.
model_dir (str or Path) – Path to the saved model directory.
output_csv_path (str or Path) – Path to the CSV file in which to store metadata and latent features.
- Returns:
DataFrame containing metadata and extracted features for all units, or None if the model could not be loaded successfully.
- Return type:
pd.DataFrame or None
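A feature-extraction sketch under the same assumptions (trained model on disk, `unit_df` with an ‘h5_index’ column); the paths are illustrative.

```python
features_df = trainer.extract_and_save_features(
    unit_df,
    h5_path="features/units.h5",
    model_dir="models/vae_run_01",
    output_csv_path="features/latent_features.csv",
)
if features_df is not None:
    print(features_df.shape)   # metadata columns plus latent feature columns
```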
- extract_and_save_comp_viz_features(unit_df, h5_path, output_csv_path, checkpoint=None)[source]#
Extract features using a Hugging Face computer vision model for all units and save them.
This method loads a pretrained computer vision model, iterates through all spectrograms in the HDF5 dataset, encodes them into fixed-length embeddings, and writes a combined DataFrame containing both metadata and features to CSV. The features are stored in columns named ‘cv_feat_{i}’.
- Parameters:
unit_df (pd.DataFrame) – DataFrame containing unit metadata with a column ‘h5_index’ referring to indices in the HDF5 ‘spectrograms’ dataset.
h5_path (str or Path) – Path to the HDF5 file containing spectrograms.
output_csv_path (str or Path) – Path to the CSV file in which to store metadata and features.
checkpoint (str, optional) – Hugging Face model checkpoint name. If None, a default checkpoint from self.config[‘vision_checkpoint’] is used, or ‘facebook/dinov3-vitb16-pretrain-lvd1689m’ if that key is not present. The default is None.
- Returns:
DataFrame containing metadata and features for all units, or None if the model could not be loaded successfully.
- Return type:
pd.DataFrame or None
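A sketch of the vision-model feature path, with illustrative paths; passing `checkpoint=None` defers to the configuration or the documented default checkpoint.

```python
cv_df = trainer.extract_and_save_comp_viz_features(
    unit_df,
    h5_path="features/units.h5",
    output_csv_path="features/cv_features.csv",
    checkpoint=None,   # falls back to config['vision_checkpoint'] or the documented default
)
if cv_df is not None:
    cv_cols = [c for c in cv_df.columns if c.startswith("cv_feat_")]
    print(f"{len(cv_cols)} vision feature columns written")
```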