CondVAEModel.fit

CondVAEModel.fit(datamodule: DataModule, name_run: str | None = '', max_epochs: int = 100, callbacks: pytorch_lightning.Callback | List[pytorch_lightning.Callback] | None = None, loggers: pytorch_lightning.loggers.Logger | Iterable[pytorch_lightning.loggers.Logger] | None = None, accelerator: str = 'auto', flag_early_stop: bool = False, criteria: str = 'train/loss', flag_wandb: bool = False, wandb_entity: str | None = None, **kwargs) → None

Train the model on the provided data using PyTorch Lightning’s Trainer.

Parameters:
  • datamodule (pl.LightningDataModule) – The data module object that provides the training, validation, and test data.

  • name_run (str, optional, default="NoName") – Name of the current run, used for logging and saving checkpoints. Not used if flag_wandb is True.

  • max_epochs (int, optional, default=100) – The maximum number of epochs to train the model.

  • callbacks (Union[Callback, List[Callback]], optional, default=None) – List of callbacks or a single callback to be used during training.

  • loggers (Union[Logger, Iterable[Logger]], optional, default=None) – List of logger instances or a single logger for logging training progress and metrics.

  • accelerator (str, optional, default="auto") – Accelerator type to train on: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto".

  • flag_early_stop (bool, optional, default=False) – If True, enable early stopping based on the metric named by criteria.

  • criteria (str, optional, default="train/loss") – Name of the logged metric monitored for early stopping. The default is composed as "train" + aixd.mlmodel.constants.SEP_LOSSES + "loss".

  • flag_wandb (bool, optional, default=False) – If True, enable logging using Weights & Biases (wandb).

  • wandb_entity (str, optional, default=None) – If flag_wandb is True, the entity (username or team) to which the run will be logged. If None, the default entity is used.

  • **kwargs – Additional keyword arguments passed on to the Trainer.
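The interplay of flag_early_stop, criteria, and max_epochs can be illustrated with a minimal, library-free sketch. It does not reproduce the actual CondVAEModel.fit or PyTorch Lightning code: as_list, EarlyStopSketch, run_epochs, and the patience parameter are hypothetical names introduced here, and SEP_LOSSES = "/" is an assumption inferred from the 'train/loss' default in the signature. The stopper applies a common patience rule (stop once the monitored metric has not decreased for patience consecutive epochs), and as_list shows one way a single callback or logger, or a list of them, could be normalized.

```python
from collections.abc import Iterable
from typing import Any, Dict, List, Optional

# Assumption: inferred from the signature default 'train/loss'.
SEP_LOSSES = "/"


def as_list(items: Any) -> List[Any]:
    """Hypothetical helper: normalize None, one object, or an iterable into a list."""
    if items is None:
        return []
    if isinstance(items, Iterable) and not isinstance(items, (str, bytes)):
        return list(items)
    return [items]


class EarlyStopSketch:
    """Hypothetical stand-in for an early-stopping callback.

    Stops once the metric named by `criteria` has not decreased
    for `patience` consecutive epochs.
    """

    def __init__(self, criteria: str = f"train{SEP_LOSSES}loss", patience: int = 3):
        self.criteria = criteria
        self.patience = patience
        self.best: Optional[float] = None
        self.wait = 0

    def should_stop(self, logged: Dict[str, float]) -> bool:
        value = logged[self.criteria]  # raises KeyError if the metric is never logged
        if self.best is None or value < self.best:
            self.best = value  # improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1  # no improvement this epoch
        return self.wait >= self.patience


def run_epochs(
    losses: List[float],
    max_epochs: int = 100,
    flag_early_stop: bool = False,
    criteria: str = f"train{SEP_LOSSES}loss",
    patience: int = 3,
) -> int:
    """Return how many epochs a toy loop runs, given one logged loss per epoch."""
    stopper = EarlyStopSketch(criteria, patience) if flag_early_stop else None
    epochs_run = 0
    for epoch, loss in enumerate(losses[:max_epochs]):
        epochs_run = epoch + 1
        if stopper is not None and stopper.should_stop({criteria: loss}):
            break
    return epochs_run
```

For example, with per-epoch losses [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.5] and flag_early_stop=True, run_epochs stops after 6 epochs (three epochs without beating 0.7), whereas without early stopping it runs all 7.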