CondAEModel.predict_step
- CondAEModel.predict_step(batch: Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) → Dict[str, torch.Tensor]
Perform a single prediction step. Depending on the structure of the input batch, the method runs the encoder, the decoder, or the entire model.
- Parameters:
batch (Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) – A tuple of two or three tensors containing the data for a single batch.
batch_idx (int) – The index of the current batch.
- Returns:
Dict[str, torch.Tensor] – A dictionary of the model’s output tensors. Depending on which part of the model was run, it may contain the latent representation ‘z’, the reconstructed ‘y’, and the reconstructed ‘x’.
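This reference does not specify which batch layouts select the encoder, the decoder, or the full model. The following is a minimal sketch of one possible way a predict step can dispatch on the length of the batch tuple and return a dictionary of named tensors. Everything in it is hypothetical: the class name `ToyCondAE`, its linear layers, the names `encode`/`decode`, and the convention that a two-tensor batch is `(x, c)` (run the full model) while a three-tensor batch is `(x, c, z)` (decode a supplied latent). It is a plain `nn.Module` rather than a Lightning module, so it stays self-contained, and it does not reproduce `CondAEModel`'s actual behaviour.

```python
from typing import Dict, Tuple, Union

import torch
from torch import nn

# Either (x, c) or (x, c, z); the meaning of each slot is an assumption.
Batch = Union[Tuple[torch.Tensor, torch.Tensor],
              Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]


class ToyCondAE(nn.Module):
    """Hypothetical conditional autoencoder, used only to illustrate dispatch."""

    def __init__(self, x_dim: int = 4, c_dim: int = 2, z_dim: int = 3):
        super().__init__()
        # Encoder maps the input x, concatenated with the condition c, to a latent z.
        self.encoder = nn.Linear(x_dim + c_dim, z_dim)
        # Decoder maps the latent z, concatenated with the condition c, back to x-space.
        self.decoder = nn.Linear(z_dim + c_dim, x_dim)

    def encode(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        return self.encoder(torch.cat([x, c], dim=-1))

    def decode(self, z: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        return self.decoder(torch.cat([z, c], dim=-1))

    @torch.no_grad()  # prediction does not need gradients
    def predict_step(self, batch: Batch, batch_idx: int) -> Dict[str, torch.Tensor]:
        if len(batch) == 2:
            # Assumed layout (x, c): run the full model (encode, then decode).
            x, c = batch
            z = self.encode(x, c)
            return {"z": z, "x": self.decode(z, c)}
        if len(batch) == 3:
            # Assumed layout (x, c, z): use only the decoder on the supplied latent.
            _, c, z = batch
            return {"x": self.decode(z, c)}
        raise ValueError(f"expected a batch of 2 or 3 tensors, got {len(batch)}")


model = ToyCondAE()
x, c = torch.randn(5, 4), torch.randn(5, 2)
full = model.predict_step((x, c), batch_idx=0)                           # keys: 'z', 'x'
decoded = model.predict_step((x, c, torch.randn(5, 3)), batch_idx=1)     # keys: 'x'
```

Dispatching on tuple length keeps the signature identical to the one documented above, but two lengths can only distinguish two modes; how the real method selects an encoder-only path is not documented here.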