CondAEModel.predict

CondAEModel.predict(data: DataModule | Tuple[torch.Tensor, torch.Tensor] | Tuple[ndarray, ndarray], accelerator: str = 'auto', **kwargs) → Dict[str, Any]

Make predictions using the model.

Parameters:
  • data (Union[DataModule, Tuple[torch.Tensor, torch.Tensor], Tuple[np.ndarray, np.ndarray]]) – The data to make predictions on: a PyTorch DataModule object, a tuple of two PyTorch tensors, or a tuple of two NumPy arrays.

  • accelerator (str, optional, default="auto") – The accelerator to use (e.g. cpu, gpu, mps).

  • **kwargs – Additional keyword arguments passed through to the Trainer. Defaults to an empty dictionary, i.e. no extra arguments.

Returns:

Dict[str, Any] – A dictionary containing the model’s predictions.
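A minimal sketch of preparing the tuple-of-arrays form of `data`. It assumes the two arrays are an input array and a conditioning array with one row per sample; this page does not say what the two elements are, so check the model's training setup. `make_predict_input` is a hypothetical helper, not part of the library. The float32 cast is also an assumption, based on PyTorch's default dtype. The `predict` call is commented out because it needs a trained `CondAEModel` instance.

```python
import numpy as np
from typing import Tuple


def make_predict_input(x: np.ndarray, cond: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: pack inputs and conditions into a 2-tuple for `predict`.

    Assumes `x` holds the model inputs and `cond` the conditioning variables,
    aligned row by row (one row per sample).
    """
    # Assumption: the model's weights use float32, PyTorch's default dtype.
    x = np.asarray(x, dtype=np.float32)
    cond = np.asarray(cond, dtype=np.float32)
    # Both arrays must describe the same samples, so their row counts must agree.
    if x.shape[0] != cond.shape[0]:
        raise ValueError(
            f"row mismatch: {x.shape[0]} inputs vs {cond.shape[0]} conditions"
        )
    return x, cond


# Toy data: 8 samples, 5 input features, 2 conditioning variables.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))
cond = rng.normal(size=(8, 2))
data = make_predict_input(x, cond)

# `model` would be a trained CondAEModel instance (not defined in this sketch):
# preds = model.predict(data, accelerator="cpu")
```

The returned dictionary's keys are not documented on this page, so inspect `preds.keys()` before relying on any particular entry.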