CondAEModel.predict
- CondAEModel.predict(data: DataModule | torch.utils.data.DataLoader | tuple, return_untransformed: bool = False, return_postprocessed: bool = True, accelerator: str = 'cpu', enable_progress_bar: bool = False, lightning_logger_level: int = 30, disable_user_warnings: bool = False, **kwargs) -> Dict[str, torch.Tensor] | Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
- Make predictions using the model.
- Parameters:
- data (Union[DataModule, DataLoader, tuple]) – A DataModule object, a PyTorch DataLoader object, or a tuple of two (or three) PyTorch Tensors or numpy arrays, containing data from which to make predictions. 
- return_untransformed (bool, optional, default=False) – If True, the predictions are additionally returned in the original space, by applying the inverse transformation. 
- return_postprocessed (bool, optional, default=True) – If True, the model predictions are returned in post-processed form. To access the raw model outputs (e.g., class logits/probabilities for categorical data), set this flag to False. 
- accelerator (str, optional, default="cpu") – Which accelerator to use (e.g., cpu, gpu, mps). 
- enable_progress_bar (bool, optional, default=False) – If True, enable the progress bar. 
- lightning_logger_level (int, optional, default=logging.WARNING) – The logging level for PyTorch Lightning. 
- disable_user_warnings (bool, optional, default=False) – If True, disable user warnings. 
- **kwargs – Additional keyword arguments that can be passed to the Trainer. Default is an empty dictionary. 
 
- Returns:
- Union[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]] – A dictionary containing the model’s output tensors. If return_untransformed is True, a tuple is returned instead, holding the predicted output data in the transformed space and in the original space. If return_postprocessed is False, the predicted output data in the transformed space is returned without post-processing.
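
Because the return type depends on return_untransformed, calling code may want to normalize it before use. The following is a minimal sketch of one way to do that, not part of the library: the helper name `split_predictions`, the key `"y"`, and the stand-in values are hypothetical, and plain lists stand in for the torch.Tensor values that predict actually returns.

```python
from typing import Any, Dict, Optional, Tuple, Union

# In the real API the dictionary values are torch.Tensor objects;
# Any is used here so the sketch runs without torch installed.
Predictions = Dict[str, Any]


def split_predictions(
    result: Union[Predictions, Tuple[Predictions, Predictions]],
) -> Tuple[Predictions, Optional[Predictions]]:
    """Normalize a predict() result to (transformed, untransformed_or_None)."""
    if isinstance(result, tuple):
        # return_untransformed=True: (transformed space, original space)
        transformed, untransformed = result
        return transformed, untransformed
    # return_untransformed=False: a single dictionary in the transformed space
    return result, None


# Stand-in values. In practice these would come from calls such as
#   model.predict(data)                             -> dict
#   model.predict(data, return_untransformed=True)  -> (dict, dict)
# where `model` and `data` are hypothetical, already-prepared objects.
single = {"y": [0.1, 0.9]}
pair = ({"y": [0.1, 0.9]}, {"y": [12.0, 48.0]})

transformed_only, missing = split_predictions(single)
transformed, untransformed = split_predictions(pair)
```

With this pattern, downstream code can always unpack two values and check whether the second is None, regardless of how predict was called.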