CondVAEModel.fit
- CondVAEModel.fit(datamodule: DataModule, name_run: str | None = '', max_epochs: int = 100, callbacks: pytorch_lightning.Callback | List[pytorch_lightning.Callback] | None = None, loggers: pytorch_lightning.loggers.Logger | Iterable[pytorch_lightning.loggers.Logger] | None = None, accelerator: str = 'auto', flag_early_stop: bool = False, criteria: str = 'train/loss', flag_wandb: bool = False, wandb_entity: str | None = None, **kwargs) → None
Train the model on the provided data using PyTorch Lightning’s Trainer.
- Parameters:
datamodule (pl.LightningDataModule) – The data module object that provides the training, validation, and test data.
name_run (str, optional, default="NoName") – Name of the current run, used for logging and saving checkpoints. Not used if flag_wandb is True.
max_epochs (int, optional, default=100) – The maximum number of epochs to train the model.
callbacks (Union[Callback, List[Callback]], optional, default=None) – List of callbacks or a single callback to be used during training.
loggers (Union[Logger, Iterable[Logger]], optional, default=None) – List of logger instances or a single logger for logging training progress and metrics.
accelerator (str, optional, default="auto") – The accelerator type to train on: "cpu", "gpu", "tpu", "ipu", "hpu", "mps", or "auto".
flag_early_stop (bool, optional, default=False) – If True, enable early stopping based on the metric named by criteria.
criteria (str, optional, default="train/loss") – The logged metric monitored for early stopping.
flag_wandb (bool, optional, default=False) – If True, enable logging using Weights & Biases (wandb).
wandb_entity (str, optional, default=None) – If flag_wandb is True, the entity (username or team) to which the run will be logged. If None, the default entity is used.
**kwargs – Additional keyword arguments forwarded to the Trainer. Defaults to an empty dictionary.
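The parameter list does not spell out how the arguments are combined before the Trainer is built. The sketch below is a hypothetical illustration, not the aixd implementation: it assumes that a single callback or logger is normalized to a list, that flag_early_stop appends an early-stopping callback watching criteria, and that **kwargs are passed through unchanged. EarlyStopSpec, as_list, and build_trainer_kwargs are invented names; EarlyStopSpec stands in for a real early-stopping callback.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass(frozen=True)
class EarlyStopSpec:
    """Hypothetical stand-in for an early-stopping callback watching one metric."""
    monitor: str


def as_list(obj: Any) -> List[Any]:
    """Normalize None, a single object, or a list/tuple of objects into a list."""
    if obj is None:
        return []
    if isinstance(obj, (list, tuple)):
        return list(obj)
    return [obj]


def build_trainer_kwargs(
    max_epochs: int = 100,
    callbacks: Optional[Any] = None,
    loggers: Optional[Any] = None,
    accelerator: str = "auto",
    flag_early_stop: bool = False,
    criteria: str = "train/loss",
    **kwargs: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble the keyword arguments a Trainer might receive."""
    cbs = as_list(callbacks)
    if flag_early_stop:
        # Assumption: early stopping monitors the metric named by `criteria`.
        cbs.append(EarlyStopSpec(monitor=criteria))
    trainer_kwargs: Dict[str, Any] = {
        "max_epochs": max_epochs,
        "accelerator": accelerator,
        "callbacks": cbs,
        "logger": as_list(loggers),
    }
    # Extra keyword arguments are forwarded to the Trainer unchanged.
    trainer_kwargs.update(kwargs)
    return trainer_kwargs
```

With the real method, a comparable call would be model.fit(datamodule, max_epochs=50, flag_early_stop=True, gradient_clip_val=0.5), where model is a CondVAEModel and datamodule an existing data module. gradient_clip_val reaches the Trainer through **kwargs; it is a standard PyTorch Lightning Trainer argument.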