GlobalSensitivity.calculate

GlobalSensitivity.calculate(data: torch.Tensor, features: str | List[str] | None = None) → Dict[str, Dict[str, torch.Tensor | Dict[str, torch.Tensor]]]

Calculate the global sensitivities by performing a sensitivity analysis of the output features with respect to the input data (a batch of data points). For continuous inputs x, the sensitivities are the gradients of the output y with respect to x, computed by backpropagation through the model. For categorical inputs x, the sensitivities are the finite differences y(x’) - y(x), where x’ is a perturbed copy of x in which the value of a categorical variable has been changed. One such difference is calculated for each option in the domain of that variable.
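Both computations can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the GlobalSensitivity implementation: the toy model, the helper names `continuous_sensitivity` and `categorical_sensitivity`, and the integer coding of the categorical column are all hypothetical.

```python
import torch

# Hypothetical toy model: 3 inputs -> 1 output. Column 2 is assumed to hold an
# integer-coded categorical variable with options {0, 1, 2}; real models may
# encode categories differently (e.g. one-hot).
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))


def continuous_sensitivity(model, x):
    """Gradient of the output w.r.t. every input column, one row per sample."""
    x = x.clone().requires_grad_(True)
    y = model(x)
    # Summing over the batch still gives per-sample gradients, because each
    # sample's output depends only on its own row of x.
    (grad,) = torch.autograd.grad(y.sum(), x)
    return grad  # shape (n_samples, n_input_features)


def categorical_sensitivity(model, x, col, options):
    """Finite differences y(x') - y(x), one tensor per option of column `col`."""
    with torch.no_grad():
        y = model(x)
        out = {}
        for opt in options:
            x_pert = x.clone()
            x_pert[:, col] = opt  # set the categorical variable to this option
            out[str(opt)] = model(x_pert) - y  # shape (n_samples, 1)
    return out


x = torch.randn(5, 3)
x[:, 2] = torch.randint(0, 3, (5,)).float()  # fill the categorical column
grads = continuous_sensitivity(model, x)[:, :2]  # keep the continuous columns only
diffs = categorical_sensitivity(model, x, col=2, options=[0, 1, 2])
```

Summing the outputs before calling `torch.autograd.grad` yields correct per-sample gradients only when samples do not interact inside the model (for example, no batch normalization in training mode). For a sample whose categorical value already equals the chosen option, the finite difference is zero.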

Parameters:
  • data (torch.Tensor) – The data point(s) at which the output sensitivity is calculated. Tensor of shape (n_samples, n_input_features). Typically, n_samples > 1.

  • features (str or List[str], optional, default None) – The output feature(s) for which the sensitivity is calculated. If None, the sensitivity is calculated for all output features of the encoder.

Returns:

sensitivity (Dict[str, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]) – The sensitivity of the output feature(s) with respect to the input.

Examples

>>> global_sensitivity = GlobalSensitivity(model) 
>>> global_sensitivity.calculate(data=datamodule.x_test, features=model.output_ml_dblock.names_list)
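The returned value is a nested dictionary. One way to condense it into global scores is to take the mean absolute sensitivity over the batch. The sketch below assumes that the outer keys are output feature names, the inner keys are input feature names, and that categorical inputs map to a further dictionary keyed by option. This layout is inferred from the return type rather than confirmed by the documentation, and `mean_abs_summary` is a hypothetical helper.

```python
from typing import Dict, Union

import torch

Sensitivity = Dict[str, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]


def mean_abs_summary(sens: Sensitivity) -> Dict[str, Dict[str, float]]:
    """Reduce every sensitivity tensor to its mean absolute value over the batch."""
    summary: Dict[str, Dict[str, float]] = {}
    for out_name, per_input in sens.items():
        summary[out_name] = {}
        for in_name, value in per_input.items():
            if isinstance(value, dict):
                # Assumed categorical case: one tensor per option, keyed "input=option".
                for opt, t in value.items():
                    summary[out_name][f"{in_name}={opt}"] = t.abs().mean().item()
            else:
                # Assumed continuous case: a single gradient tensor.
                summary[out_name][in_name] = value.abs().mean().item()
    return summary


# Hand-built input with the assumed layout (not real model output).
sens = {
    "y": {
        "x1": torch.tensor([1.0, -3.0]),
        "color": {"red": torch.tensor([0.5, -0.5]), "blue": torch.tensor([2.0, 0.0])},
    }
}
summary = mean_abs_summary(sens)
```

Taking the absolute value before averaging keeps positive and negative effects from cancelling out across samples.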