GlobalSensitivity.plot
- GlobalSensitivity.plot(data: torch.Tensor | Dict[str, torch.Tensor], features: str | List[str], return_sensitivities: bool = False, renderer: str | None = None) → Dict[str, Dict[str, torch.Tensor | Dict[str, torch.Tensor]]]
Calculate and plot the global sensitivities of one or more output features as a horizontal box chart. The local sensitivity is calculated at each data point, and the local sensitivities are then aggregated across all data points into box plots.
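The aggregation described above can be illustrated with a dependency-free sketch. It is not this class's implementation. It assumes a hypothetical `model` that maps a list of input floats to a list of output floats, estimates each local sensitivity by central finite differences (the actual class may compute derivatives differently, for example with autograd), and summarizes each input's sensitivities over all samples with the five numbers a box plot draws. The helper names `local_sensitivities` and `global_sensitivity_boxes` are hypothetical, and the quartile method used here may differ from the plotting backend's.

```python
import statistics
from typing import Callable, Dict, List, Sequence

# Hypothetical model signature: one input point -> list of output values.
Model = Callable[[Sequence[float]], List[float]]


def local_sensitivities(model: Model, x: Sequence[float], eps: float = 1e-5) -> List[List[float]]:
    """Central finite-difference Jacobian at a single point x, indexed [output][input]."""
    n_in = len(x)
    n_out = len(model(x))
    jac = [[0.0] * n_in for _ in range(n_out)]
    for i in range(n_in):
        # Perturb input i up and down by eps, keeping the other inputs fixed.
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        y_plus, y_minus = model(x_plus), model(x_minus)
        for j in range(n_out):
            jac[j][i] = (y_plus[j] - y_minus[j]) / (2 * eps)
    return jac


def global_sensitivity_boxes(
    model: Model, data: Sequence[Sequence[float]], output_index: int
) -> Dict[int, Dict[str, float]]:
    """Collect the local sensitivities of one output w.r.t. each input over all
    rows of data, then reduce each input's distribution to box-plot statistics."""
    per_input: Dict[int, List[float]] = {}
    for x in data:
        row = local_sensitivities(model, x)[output_index]
        for i, s in enumerate(row):
            per_input.setdefault(i, []).append(s)
    summary: Dict[int, Dict[str, float]] = {}
    for i, values in per_input.items():
        # Needs at least two samples on older Python versions (n_samples > 1).
        q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
        summary[i] = {"min": min(values), "q1": q1, "median": median, "q3": q3, "max": max(values)}
    return summary
```

For `model = lambda x: [x[0] ** 2 + 3 * x[1]]` evaluated on the rows `[k, 0.0]` for k = 0..4, the box for input 0 spans sensitivities from 0 to 8 with median 4, while the box for input 1 collapses to the constant 3.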
- Parameters:
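data (torch.Tensor or Dict[str, torch.Tensor]) – The data points at which the output sensitivities are calculated. As a tensor, it has shape (n_samples, n_input_features). Typically n_samples > 1, since the box plots summarize the distribution of local sensitivities over the samples.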
features (str or List[str], optional, default None) – The output feature, or list of output features, for which the sensitivity is calculated. If None, the sensitivity is calculated for all output features of the encoder.
return_sensitivities (bool, default False) – If True, the calculated sensitivities are returned.
renderer (str or None, default None) – The renderer used to display the plot. If None, the default renderer is used. If "browser", the plot opens in the web browser.
- Returns:
sensitivity (Dict[str, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]) – The sensitivities of the selected output feature(s) with respect to the input features. Returned when return_sensitivities is True.
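Because the returned value is nested two or three levels deep, a small generic traversal can help when post-processing it. The sketch below relies only on the return type annotation; the meaning of the inner keys is not documented in this section, so the hypothetical helper `iter_sensitivities` walks whatever keys are present instead of assuming names. It is meant for the value returned when `return_sensitivities=True`.

```python
from typing import Any, Dict, Iterator, Tuple


def iter_sensitivities(result: Dict[str, Dict[str, Any]]) -> Iterator[Tuple[Tuple[str, ...], Any]]:
    """Yield (key_path, leaf) pairs from the nested sensitivity mapping.

    The outer level is keyed by output feature name. Each inner value is either
    a leaf (a torch.Tensor in the real return value) or a further dict of
    leaves, as the annotation Dict[str, Dict[str, Tensor | Dict[str, Tensor]]] allows.
    """
    for output_name, inner in result.items():
        for key, value in inner.items():
            if isinstance(value, dict):
                # Third level: a dict of tensors.
                for sub_key, leaf in value.items():
                    yield (output_name, key, sub_key), leaf
            else:
                # Second level is already a tensor.
                yield (output_name, key), value
```

For example, a stand-in result `{"y": {"a": 1.0, "b": {"c": 2.0}}}` yields `(("y", "a"), 1.0)` and then `(("y", "b", "c"), 2.0)`; with the real return value, the leaves are `torch.Tensor` objects.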
Examples
>>> global_sensitivity = GlobalSensitivity(model)
>>> global_sensitivity.plot(data=datamodule.x_test, features=model.output_ml_dblock.names_list, return_sensitivities=False)