CondAEModel.predict_step

CondAEModel.predict_step(batch: Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) → Dict[str, torch.Tensor]

Perform a single prediction step. Depending on the contents of the batch, the encoder, the decoder, or the entire model is used.

Parameters:
  • batch (Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) – A two- or three-element tuple of tensors containing the data for a single batch.

  • batch_idx (int) – The index of the current batch.

Returns:

Dict[str, torch.Tensor] – A dictionary containing the model’s output tensors. Depending on which part of the model was run, it may include the latent representation ‘z’, the reconstructed ‘y’, and the reconstructed ‘x’.
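The sketch below shows one way a predict step can choose between the full model and the decoder alone based on the length of the batch tuple. It is a hypothetical illustration, not the actual CondAEModel implementation: the class TinyCondAE, the layer sizes, and the batch layouts (x, c) and (x, c, z) are all assumptions made for this example, and it returns only ‘z’ and ‘x’ rather than the full set of keys the real method may produce.

```python
from typing import Dict, Tuple, Union

import torch
from torch import nn


class TinyCondAE(nn.Module):
    """Hypothetical conditional autoencoder; NOT the library's CondAEModel."""

    def __init__(self, x_dim: int = 4, c_dim: int = 2, z_dim: int = 3) -> None:
        super().__init__()
        # Encoder sees the data concatenated with the condition.
        self.encoder = nn.Linear(x_dim + c_dim, z_dim)
        # Decoder sees the latent code concatenated with the condition.
        self.decoder = nn.Linear(z_dim + c_dim, x_dim)

    def predict_step(
        self,
        batch: Union[
            Tuple[torch.Tensor, torch.Tensor],
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        ],
        batch_idx: int,
    ) -> Dict[str, torch.Tensor]:
        # batch_idx is unused in this sketch; it is kept to match the signature.
        with torch.no_grad():
            if len(batch) == 2:
                # Assumed layout (x, c): run encoder and decoder (full model).
                x, c = batch
                z = self.encoder(torch.cat([x, c], dim=-1))
                x_hat = self.decoder(torch.cat([z, c], dim=-1))
                return {"z": z, "x": x_hat}
            if len(batch) == 3:
                # Assumed layout (x, c, z): skip the encoder and decode the given z.
                _, c, z = batch
                return {"x": self.decoder(torch.cat([z, c], dim=-1))}
        raise ValueError(f"expected a 2- or 3-tuple, got {len(batch)} elements")
```

For example, `TinyCondAE().predict_step((torch.randn(5, 4), torch.randn(5, 2)), 0)` returns a dictionary with a (5, 3) latent tensor under ‘z’ and a (5, 4) reconstruction under ‘x’. Dispatching on tuple length keeps the call signature fixed, at the cost of making the batch layout an implicit contract between the data loader and the model.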