InverseModel

class aixd.mlmodel.architecture.two_stage_model.InverseModel(model: CondAEModel, gen_z_strategy: str = 'encode', sample_around_std: float = 0.1, **kwargs)

Bases: object

Wrapper class that takes a Cond(V)AEModel and fine-tunes its decoder, using the encoder as a surrogate model. It was observed that the decoder may neglect the conditional features y. A second training stage was therefore suggested: the order of the decoder and encoder is swapped, the encoder is frozen, and the decoder is trained so that the re-encoded output reproduces the correct y.
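
The second-stage objective can be summarized as follows. The notation is introduced here for illustration only: D_θ is the decoder with trainable parameters θ, E_φ^(y) is the y-prediction output of the frozen encoder, and the loss 𝓛 is left generic because this page does not name it (a squared error is one plausible choice).

```latex
\hat{x} = D_\theta(z, y), \qquad
\hat{y} = E_\phi^{(y)}(\hat{x}), \qquad
\theta^\star = \arg\min_{\theta} \; \mathbb{E}_{(x, y),\, z}\!\left[\mathcal{L}\big(\hat{y}, y\big)\right],
\quad \phi \ \text{fixed}
```

Gradients flow through the frozen encoder to reach θ, but φ itself is never updated, so the only way to lower the loss is for the decoder to produce designs x̂ whose predicted y̌ matches the requested y.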

This class copies the complete functionality of the input model, except for the forward and fit methods, which are redefined for fine-tuning.

Parameters:
  • model (CondAEModel) – Model whose decoder is fine-tuned, using its (frozen) encoder as a surrogate.

  • gen_z_strategy (str, default='encode') – Strategy for obtaining the latent vectors z. One of 'encode', 'sample' or 'sample_around':
    - 'encode': z is produced by the (frozen) encoder.
    - 'sample': z is sampled from a normal distribution.
    - 'sample_around': z is produced by the (frozen) encoder and perturbed with random Gaussian noise of standard deviation sample_around_std.

  • sample_around_std (float, default=0.1) – Standard deviation of the random Gaussian noise added to z when `gen_z_strategy` is 'sample_around'.
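
The three gen_z_strategy options can be sketched in plain Python as below. This is not the aixd implementation: gen_z, encode_z and latent_dim are hypothetical names, encode_z stands in for the frozen encoder's latent output, and 'sample' is assumed to mean drawing from a standard normal (the usual VAE prior).

```python
import random
from typing import Callable, List, Optional, Sequence


def gen_z(
    x: Sequence[float],
    encode_z: Callable[[Sequence[float]], Sequence[float]],  # hypothetical frozen encoder, x -> z
    latent_dim: int,
    strategy: str = "encode",
    sample_around_std: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical helper mirroring the three documented strategies."""
    if rng is None:
        rng = random.Random()
    if strategy == "encode":
        # z taken directly from the frozen encoder
        return list(encode_z(x))
    if strategy == "sample":
        # z drawn from a standard normal (assumed prior), independent of x
        return [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    if strategy == "sample_around":
        # encoder output perturbed by Gaussian noise with std sample_around_std
        return [zi + rng.gauss(0.0, sample_around_std) for zi in encode_z(x)]
    raise ValueError(f"unknown gen_z_strategy: {strategy!r}")
```

For example, with a toy encoder halve = lambda x: [0.5 * v for v in x], gen_z([1.0, 2.0], halve, latent_dim=2) returns [0.5, 1.0]; with strategy="sample_around" the same values come back jittered by noise of the given standard deviation.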

Methods

fit

Launches fine-tuning of the decoder in self.model.

forward

Obtains latent vectors z according to self.gen_z_strategy, decodes them (together with the conditional features y) to obtain x_hat, and finally encodes x_hat with the (frozen) encoder to obtain y_hat.
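
The interplay of forward and fit can be illustrated with a scalar toy model. Everything here is a hypothetical stand-in, not the aixd API: FrozenEncoder, Decoder, forward, fine_tune and make_data are illustrative names, the models are linear, and plain SGD on a squared error is assumed as the training rule. The decoder starts out ignoring y entirely, mimicking the failure mode described above, and only its parameters are updated.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class FrozenEncoder:
    """Hypothetical toy encoder x -> (y_hat, z). Frozen: its weights never change."""
    w_y: float = 2.0
    w_z: float = 1.0

    def __call__(self, x: float) -> Tuple[float, float]:
        return self.w_y * x, self.w_z * x


@dataclass
class Decoder:
    """Hypothetical toy conditional decoder (z, y) -> x_hat with trainable a, b.

    The defaults (a=0, b=1) ignore y, mimicking a decoder that neglects it.
    """
    a: float = 0.0
    b: float = 1.0

    def __call__(self, z: float, y: float) -> float:
        return self.a * y + self.b * z


def forward(decoder: Decoder, encoder: FrozenEncoder, x: float, y: float,
            strategy: str, rng: random.Random) -> Tuple[float, float, float]:
    """Return (z, x_hat, y_hat): obtain z, decode to x_hat, re-encode to y_hat."""
    if strategy == "encode":
        _, z = encoder(x)          # z from the frozen encoder
    elif strategy == "sample":
        z = rng.gauss(0.0, 1.0)    # z from a standard normal
    else:
        raise ValueError(f"unsupported strategy: {strategy!r}")
    x_hat = decoder(z, y)          # decode, conditioned on the target y
    y_hat, _ = encoder(x_hat)      # frozen encoder acts as the surrogate
    return z, x_hat, y_hat


def fine_tune(decoder: Decoder, encoder: FrozenEncoder,
              data: List[Tuple[float, float]], epochs: int = 20,
              lr: float = 0.01, strategy: str = "sample", seed: int = 0) -> Decoder:
    """SGD on (y_hat - y)^2 that updates only the decoder's parameters."""
    rng = random.Random(seed)
    for _ in range(epochs):
        for x, y in data:
            z, _, y_hat = forward(decoder, encoder, x, y, strategy, rng)
            # y_hat = w_y * (a*y + b*z), so d(loss)/da = g*y and d(loss)/db = g*z
            g = 2.0 * (y_hat - y) * encoder.w_y
            decoder.a -= lr * g * y
            decoder.b -= lr * g * z
    return decoder


def make_data(n: int = 50, seed: int = 1) -> List[Tuple[float, float]]:
    """Hypothetical (x, y) pairs consistent with the toy encoder: y = 2 * x."""
    rng = random.Random(seed)
    ys = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(y / 2.0, y) for y in ys]
```

Running fine_tune(Decoder(), FrozenEncoder(), make_data()) with the 'sample' strategy drives a toward 1 / w_y = 0.5 and b toward 0, so the re-encoded y_hat matches the requested y. In this linear toy, b shrinks to 0 only because z carries no information about y; a real decoder would still use z to produce diverse designs.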