InverseModel.fit
- InverseModel.fit(datamodule: pytorch_lightning.LightningDataModule, max_epochs: int = 100, callbacks: list | None = None, loggers: list | None = None, accelerator: str = 'auto', flag_wandb=False, **kwargs) → None
Launches fine-tuning of the decoder in self.model. Freezes the encoder prior to calling self.model.fit() and unfreezes it upon completion.
- Parameters:
datamodule (pl.LightningDataModule) – DataModule object that provides the training, validation, and test data.
max_epochs (int, optional, default=100) – The maximum number of epochs to train for.
callbacks (list, optional, default=None) – A list of PyTorch Lightning Callback objects to use during training.
loggers (list, optional, default=None) – A list of PyTorch Lightning Logger objects to use during training.
accelerator (str, optional, default=’auto’) – Which accelerator to use (e.g. 'cpu', 'gpu', 'mps').
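The freeze, fit, unfreeze sequence described above can be sketched in plain Python. This is not InverseModel's actual implementation: Param, EncoderDecoder, frozen, and fit_decoder are hypothetical stand-ins (in PyTorch the flag would be each parameter's requires_grad attribute), and the train callable stands in for self.model.fit().

```python
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Callable, Iterator, List


@dataclass
class Param:
    """Hypothetical stand-in for a trainable tensor (e.g. torch.nn.Parameter)."""
    name: str
    requires_grad: bool = True


@dataclass
class EncoderDecoder:
    """Hypothetical model holding separate encoder and decoder parameters."""
    encoder: List[Param] = field(default_factory=list)
    decoder: List[Param] = field(default_factory=list)


@contextmanager
def frozen(params: List[Param]) -> Iterator[None]:
    """Disable gradients for `params`, restoring each original flag on exit."""
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = False
    try:
        yield
    finally:
        # Runs even if training raises, so the encoder is never left frozen.
        for p, flag in zip(params, saved):
            p.requires_grad = flag


def fit_decoder(model: EncoderDecoder,
                train: Callable[[EncoderDecoder], None]) -> None:
    """Freeze the encoder, run the (hypothetical) training step, then unfreeze."""
    with frozen(model.encoder):
        train(model)


# Usage: record which parameters are trainable while "training" runs.
model = EncoderDecoder(encoder=[Param("enc.weight")], decoder=[Param("dec.weight")])
snapshots = []
fit_decoder(model, lambda m: snapshots.append(
    [p.requires_grad for p in m.encoder + m.decoder]))
print(snapshots)                       # [[False, True]]: encoder frozen during training
print(model.encoder[0].requires_grad)  # True: restored afterwards
```

Restoring each saved flag inside a finally block is a design choice of this sketch: it keeps the encoder from being left frozen if training raises, and it leaves encoder parameters that were already frozen untouched. The documentation above states only that the encoder is unfrozen upon completion.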