CondVAEModel

class aixd.mlmodel.architecture.cond_vae_model.CondVAEModel(*args: Any, **kwargs: Any)

Bases: CondAEModel

Class representing a Conditional Variational Autoencoder model.
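The sketch below illustrates, in plain Python, the two ingredients that usually distinguish a conditional VAE from a conditional autoencoder. The first is sampling the latent code with the reparameterization trick. The second is penalizing the latent code with a KL-divergence term toward a standard normal prior, which is presumably what the ‘kl’ key of loss_weights refers to. This is not the aixd implementation. The function names are hypothetical, lists of floats stand in for torch.Tensor, and conditioning the decoder by concatenating the condition vector to the latent sample is an assumption about a common design.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], logvar: Sequence[float], rng: random.Random) -> List[float]:
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))


def decoder_input(z: Sequence[float], condition: Sequence[float]) -> List[float]:
    """Hypothetical conditioning: concatenate the latent sample with the condition vector."""
    return list(z) + list(condition)


if __name__ == "__main__":
    rng = random.Random(0)
    mu, logvar = [0.5, -0.2], [0.0, -1.0]  # encoder outputs for one sample (latent_dim = 2)
    z = reparameterize(mu, logvar, rng)
    print(len(decoder_input(z, [1.0, 3.0])))  # 4: two latent units plus two condition values
    print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0: posterior equals the prior
```

Because the noise eps is drawn independently of mu and logvar, gradients can flow through both encoder outputs, which is why the reparameterized form is used instead of sampling z directly.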

Parameters:
  • input_ml_dblock (InputML) – An input ML data block defining the input heads of the model.

  • output_ml_dblock (OutputML) – An output ML data block defining the output heads of the model.

  • layer_widths (List[int]) – List of integers specifying the number of units in each hidden layer of the autoencoder’s encoder and decoder (i.e., the “core” of the autoencoder). The elements follow the order in which the layers appear, encoder first and then decoder: the first element gives the width of the first hidden layer of the encoder, and the last element gives the width of the last hidden layer of the decoder.

  • latent_dim (int) – Integer specifying the number of units in the latent (i.e., encoded) representation of the data.

  • heads_layer_widths (Dict[str, List[int]], optional, default={}) – Dictionary specifying the number of units in the “head” layers that are added to the autoencoder. The keys are the feature names, and the values are lists of integers giving the number of units in each hidden layer of the corresponding head.

  • custom_losses (Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional, default=None) – Dictionary mapping loss names to custom loss functions that are computed on the model outputs.

  • loss_weights (Dict[str, float], optional, default=None) – Dictionary containing the weight of each loss term used in backpropagation. Valid keys are ‘x’, ‘y’, ‘decorrelation’, ‘kl’, and any custom loss name defined in custom_losses.

  • activation (Union[torch.nn.Module, str], optional, default="leaky_relu") – Activation function to be used in the latent layers of the autoencoder.

  • optimizer (torch.optim.Optimizer, optional, default=None) – Optimizer to be used for updating the model’s weights.

  • name (str, optional, default="CondVAEModel") – Name of the model.

  • save_dir (str, optional, default=None) – Directory where model-related files, such as the model’s checkpoints and logs, will be saved.

  • name_proj (str, optional, default=None) – Name of the project.

  • **kwargs (dict) – Additional arguments passed to pytorch_lightning.core.module.LightningModule.
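A minimal sketch of how named loss terms might be combined under loss_weights, assuming the total training loss is a weighted sum of the individual terms and that a term without an explicit weight counts with weight 1.0. Both points are assumptions, not confirmed behavior of CondVAEModel. The custom loss mean_abs_error and its key "mae_y" are hypothetical, and plain floats replace torch.Tensor so the example runs without PyTorch.

```python
from typing import Callable, Dict, Optional, Sequence

# Same (prediction, target) -> scalar shape as the custom_losses callables, with floats for tensors.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mean_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical custom loss: mean absolute error between predictions and targets."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def total_loss(loss_terms: Dict[str, float], loss_weights: Optional[Dict[str, float]] = None) -> float:
    """Weighted sum of named loss terms; missing weights are assumed to default to 1.0."""
    weights = loss_weights or {}
    return sum(weights.get(name, 1.0) * value for name, value in loss_terms.items())


if __name__ == "__main__":
    custom_losses: Dict[str, LossFn] = {"mae_y": mean_abs_error}
    y_pred, y_true = [1.0, 2.0], [1.5, 2.0]
    # Built-in terms use the key names listed for loss_weights; the values here are made up.
    terms = {"x": 0.8, "y": 0.4, "decorrelation": 0.1, "kl": 2.0}
    # Each custom loss is evaluated on the outputs and added under its own name.
    terms.update({name: fn(y_pred, y_true) for name, fn in custom_losses.items()})
    weights = {"x": 1.0, "y": 1.0, "decorrelation": 0.5, "kl": 0.1, "mae_y": 2.0}
    print(total_loss(terms, weights))
```

Lowering the ‘kl’ weight relaxes the pull of the latent distribution toward the standard normal prior in favor of reconstruction accuracy, much as the β factor does in a β-VAE.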