CondVAEModel

class aixd.mlmodel.architecture.cond_vae_model.CondVAEModel(*args: Any, **kwargs: Any)

Bases: CondAEModel

Class representing a Conditional Variational Autoencoder model.
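For orientation, the "variational" part of a variational autoencoder is commonly implemented with the reparameterization trick and a closed-form KL term against a standard normal prior. The following is a minimal, library-independent sketch of those two pieces using plain Python floats. The function names `sample_latent` and `kl_to_standard_normal` are hypothetical, and this page does not state that CondVAEModel computes them in exactly this way.

```python
import math
import random
from typing import List, Sequence


def sample_latent(mean: Sequence[float], log_var: Sequence[float], rng: random.Random) -> List[float]:
    """Draw z = mean + std * eps with eps ~ N(0, 1), where std = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


def kl_to_standard_normal(mean: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL divergence between N(mean, diag(exp(log_var))) and N(0, I)."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mean, log_var))


# Illustrative values only; a seeded generator keeps the draw reproducible.
rng = random.Random(0)
z = sample_latent(mean=[0.0, 1.0], log_var=[0.0, -2.0], rng=rng)
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])  # posterior equals the prior
```

When the approximate posterior coincides with the prior (zero mean, unit variance), the KL term vanishes, which is why it acts as a regularizer on the latent space.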

Parameters:
  • input_ml_dblock (InputML) – An input ML data block defining the input heads of the model.

  • output_ml_dblock (OutputML) – An output ML data block defining the output heads of the model.

  • layer_widths (List[int]) – List of integers specifying the number of units in each hidden layer of the autoencoder’s encoder and decoder (i.e., the “core” of the autoencoder). The first element of the list corresponds to the number of units in the first hidden layer of the encoder, the last element corresponds to the number of units in the last hidden layer of the decoder, and the elements in between correspond to the number of units in each hidden layer of the autoencoder in the order they appear (encoder followed by decoder).

  • latent_dim (int) – Integer specifying the number of units in the latent (i.e., encoded) representation of the data.

  • heads_layer_widths (Dict[str, List[int]], optional, default={}) – Dictionary specifying the number of units in the “head” layers that are added to the autoencoder. The keys of the dictionary are the names of the features, the values are a sequence of integers specifying the number of units in each hidden layer of the head.

  • custom_losses (Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional, default=None) – Dictionary containing custom losses to be computed on the outputs.

  • loss_weights (Dict[str, float], optional, default=None) – Dictionary containing the weights with which each loss term should be multiplied before being added to the total loss used for backpropagation, including custom losses.

  • activation (Union[torch.nn.Module, str], optional, default="leaky_relu") – Activation function to be used in the latent layers of the autoencoder.

  • optimizer (torch.optim.Optimizer, optional, default=None) – Optimizer to be used for updating the model’s weights.

  • pass_y_to_encoder (bool, optional, default=False) – Whether to pass the conditional features ‘y’ to the autoencoder’s encoder (vanilla cVAE formulation) or not. If ‘True’, the encoder maps from ‘x’ to ‘z’ and is solely used for finding the latent vector needed to reconstruct ‘x’ given ‘y’. If ‘False’, the encoder represents a surrogate model mapping from ‘x’ to ‘y’ as well as a latent vector ‘z’.

  • name (str, optional, default="CondVAEModel") – Name of the model.

  • save_dir (str, optional, default=None) – Directory where the model-related files will be saved, such as the model’s checkpoint and logs.

  • name_proj (str, optional, default=None) – Name of the project.

  • **kwargs (dict) – Additional arguments passed to pytorch_lightning.core.module.LightningModule.
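
As an illustration of how custom_losses and loss_weights interact, the sketch below combines several loss terms into one weighted total. It is a simplified, library-independent stand-in: plain lists of floats replace torch.Tensor, the loss names "reconstruction", "kl", and "mae" are hypothetical, and giving unlisted terms a weight of 1.0 is an assumption rather than documented behavior.

```python
from typing import Callable, Dict, List, Optional

# Stand-in for Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
# lists of floats play the role of tensors in this sketch.
LossFn = Callable[[List[float], List[float]], float]


def mean_abs_error(pred: List[float], target: List[float]) -> float:
    """Example custom loss: mean absolute error over two equal-length lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def total_loss(
    base_losses: Dict[str, float],
    custom_losses: Optional[Dict[str, LossFn]],
    loss_weights: Optional[Dict[str, float]],
    pred: List[float],
    target: List[float],
) -> float:
    """Multiply every loss term by its weight and add the results."""
    terms = dict(base_losses)
    for name, fn in (custom_losses or {}).items():
        terms[name] = fn(pred, target)
    weights = loss_weights or {}
    # Assumption: a term without an explicit weight counts with weight 1.0.
    return sum(weights.get(name, 1.0) * value for name, value in terms.items())


example = total_loss(
    base_losses={"reconstruction": 0.5, "kl": 2.0},  # hypothetical built-in terms
    custom_losses={"mae": mean_abs_error},
    loss_weights={"kl": 0.25, "mae": 2.0},
    pred=[1.0, 2.0],
    target=[1.5, 2.5],
)
```

Here the total is 1.0 * 0.5 + 0.25 * 2.0 + 2.0 * 0.5, so the custom term is weighted in the same way as the other terms.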

Methods

decode

Decode the latent representation into the original data space.
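
To make the data flow around decoding concrete, the sketch below computes the input and output widths of the encoder and decoder for both settings of pass_y_to_encoder, following the parameter description above. The helper `interface_sizes` and its arguments are hypothetical, and the sketch ignores the feature heads and the mean/variance split of the variational encoder. It also assumes the decoder always consumes ‘z’ together with ‘y’.

```python
from typing import Dict, Tuple


def interface_sizes(dim_x: int, dim_y: int, latent_dim: int, pass_y_to_encoder: bool) -> Dict[str, Tuple[int, int]]:
    """Return (input width, output width) for the encoder and the decoder."""
    if pass_y_to_encoder:
        # Vanilla cVAE: the encoder sees x and y and only produces z.
        encoder = (dim_x + dim_y, latent_dim)
    else:
        # Surrogate-style encoder: it sees x and predicts y alongside z.
        encoder = (dim_x, dim_y + latent_dim)
    # Assumption: decoding reconstructs x from the latent vector z given y.
    decoder = (latent_dim + dim_y, dim_x)
    return {"encoder": encoder, "decoder": decoder}


with_y = interface_sizes(dim_x=10, dim_y=3, latent_dim=4, pass_y_to_encoder=True)
without_y = interface_sizes(dim_x=10, dim_y=3, latent_dim=4, pass_y_to_encoder=False)
```

In both settings the decoder interface is the same; only the role of the encoder changes.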

Inherited Methods

configure_optimizers

Configure the optimizers for the model.

encode

Encode the input data into a latent representation.

evaluate

Evaluate the model on the validation data.

fit

Train the model on the provided data using PyTorch Lightning's Trainer.

forward

Forward pass of the model.

forward_evaluation

Takes values for the InputML features and returns the corresponding OutputML values as predicted by the model.

from_datamodule

Create a model from a data module.

inverse_design

load_model_from_checkpoint

Load a model from a checkpoint file.

on_load_checkpoint

Load the extra hyperparameters from the model checkpoint.

on_save_checkpoint

Save the extra hyperparameters to the model checkpoint.

predict

Make predictions using the model.

save_extra_parameters

Extra parameters that are saved as part of the model checkpoint.

summary

Prints a summary of the encoder and decoder, including the number of parameters, the layers, their names, and the dimensionality.

test

Evaluate the model on the test data.

test_step

Perform a single test step.

training_step

Perform a single training step.

validate

Evaluate the model on the validation data.

validation_step

Perform a single validation step.