GlobalSensitivity.plot

GlobalSensitivity.plot(data: torch.Tensor | Dict[str, torch.Tensor], features: str | List[str], return_sensitivities: bool = False, renderer: str | None = None) → Dict[str, Dict[str, torch.Tensor | Dict[str, torch.Tensor]]]

Calculate and plot the global sensitivities of one or more output features as a horizontal box plot. The local sensitivity is calculated at each data point, and these local sensitivities are then aggregated into box plots.
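The aggregation idea can be illustrated with a minimal, library-independent sketch. It is not GlobalSensitivity's implementation. It assumes a hypothetical toy model that returns a dict of named outputs. Local sensitivities are estimated by central finite differences, whereas a PyTorch model would more typically use autograd. The sketch then reduces the per-sample values for each input to the five numbers a box plot displays. All names here (toy_model, local_sensitivities, global_sensitivity_summary) are hypothetical.

```python
import statistics
from typing import Callable, Dict, List, Sequence


# Hypothetical stand-in for a trained model: maps one input vector
# to a dict of named output features.
def toy_model(x: Sequence[float]) -> Dict[str, float]:
    return {"y0": 2.0 * x[0] + x[1] ** 2, "y1": x[0] * x[1]}


def local_sensitivities(
    model: Callable[[Sequence[float]], Dict[str, float]],
    x: Sequence[float],
    feature: str,
    eps: float = 1e-5,
) -> List[float]:
    """Central finite-difference estimate of d(feature)/d(x_i) at a single point."""
    grads = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grads.append((model(x_plus)[feature] - model(x_minus)[feature]) / (2 * eps))
    return grads


def global_sensitivity_summary(
    model: Callable[[Sequence[float]], Dict[str, float]],
    data: Sequence[Sequence[float]],
    feature: str,
) -> List[Dict[str, float]]:
    """Aggregate per-sample local sensitivities into box-plot statistics, one entry per input."""
    # Needs at least two samples: quantiles are meaningless for a single point.
    per_sample = [local_sensitivities(model, x, feature) for x in data]
    summary = []
    for i in range(len(data[0])):
        values = [s[i] for s in per_sample]
        q1, median, q3 = statistics.quantiles(values, n=4)
        summary.append(
            {"min": min(values), "q1": q1, "median": median, "q3": q3, "max": max(values)}
        )
    return summary


data = [[0.0, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]
for i, stats in enumerate(global_sensitivity_summary(toy_model, data, "y0")):
    print(f"x{i}:", {k: round(v, 3) for k, v in stats.items()})
```

For y0 = 2·x0 + x1², the sensitivity to x0 is the constant 2, so its box collapses to a line. The sensitivity to x1 is 2·x1, so its box spreads over the sampled range. This spread across samples is what a global sensitivity box plot is meant to reveal.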

Parameters:
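  • data (torch.Tensor or Dict[str, torch.Tensor]) – The data point(s) at which the output sensitivity is calculated. If a tensor, it has shape (n_samples, n_input_features). Typically, n_samples > 1, since the local sensitivities are aggregated across samples into box plots.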

  • features (str or List[str], optional, default None) – The name(s) of the output features for which the sensitivity is calculated. If None, the sensitivity is calculated for all output features of the encoder.

  • return_sensitivities (bool, default False) – If True, the calculated sensitivities are returned.

  • renderer (str or None, default None) – The renderer used to display the plot. If None, the default renderer is used; if “browser”, the plot opens in a web browser.

Returns:

sensitivity (Dict[str, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]) – The sensitivity of the output feature(s) with respect to the input.

Examples

>>> global_sensitivity = GlobalSensitivity(model) 
>>> global_sensitivity.plot(data=datamodule.x_test, features=model.output_ml_dblock.names_list, return_sensitivities=False)