LocalSensitivity.plot
- LocalSensitivity.plot(data: torch.Tensor, features: str | List[str] | None = None, return_sensitivities: bool = False, renderer: str | None = None) → Dict[str, Dict[str, torch.Tensor | Dict[str, torch.Tensor]]]
Calculate the local sensitivities for the specified output features and plot them in a horizontal bar chart.
- Parameters:
data (torch.Tensor) – The data point(s) at which the output sensitivity is calculated, given as a tensor of shape (n_samples, n_input_features). If n_samples > 1, the sensitivity is aggregated and calculated with respect to the mean of the output features across the samples.
features (str or List[str], optional, default None) – The name or list of names of the output features for which the sensitivity is calculated. If None, the sensitivity is calculated for all output features of the encoder.
return_sensitivities (bool, optional, default False) – If True, the calculated sensitivities are returned.
renderer (str, optional, default None) – The renderer used to display the plot. If None, the default renderer is used; if “browser”, the plot opens in the web browser.
- Returns:
sensitivity (Dict[str, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]) – The dictionary of local sensitivities.
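One plausible reading of the aggregation rule for n_samples > 1 is that each input is perturbed in every sample at once and the outputs are averaged over the samples, so the sensitivity equals the mean of the per-sample partial derivatives. The sketch below illustrates that reading with central finite differences in plain Python. The helper `local_sensitivity`, the toy model and the step size `eps` are hypothetical and not part of this API, which works on torch.Tensor inputs and may compute derivatives differently.

```python
from typing import Callable, List, Sequence


def local_sensitivity(
    f: Callable[[List[float]], List[float]],
    data: Sequence[Sequence[float]],
    eps: float = 1e-6,
) -> List[List[float]]:
    """Hypothetical helper: S[j][i] ~= d(mean over samples of f(x)[j]) / d x_i.

    f maps one input row to a list of output features. Input i is shifted
    by +/- eps in every sample simultaneously, and the sample-averaged
    outputs are differenced (central finite differences).
    """
    n_in = len(data[0])

    def mean_output(rows: Sequence[Sequence[float]]) -> List[float]:
        # Evaluate the model per sample, then average each output feature.
        outs = [f(list(r)) for r in rows]
        return [sum(col) / len(outs) for col in zip(*outs)]

    n_out = len(mean_output(data))
    sens = [[0.0] * n_in for _ in range(n_out)]
    for i in range(n_in):
        plus = [[x + eps if k == i else x for k, x in enumerate(r)] for r in data]
        minus = [[x - eps if k == i else x for k, x in enumerate(r)] for r in data]
        y_plus, y_minus = mean_output(plus), mean_output(minus)
        for j in range(n_out):
            sens[j][i] = (y_plus[j] - y_minus[j]) / (2 * eps)
    return sens


# Toy model with two inputs and two outputs, evaluated at two samples.
S = local_sensitivity(lambda x: [2 * x[0] + 3 * x[1], x[0] * x[1]],
                      [[1.0, 2.0], [3.0, 4.0]])
print(S)
```

For the second output x0 * x1, the derivative with respect to x0 is x1, so the aggregated value is the mean of 2 and 4, which is 3.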
See also
calculate
Calculates the local sensitivities; see this method for more details.
Examples
>>> local_sensitivity = LocalSensitivity(model)
>>> local_sensitivity.plot(data=datamodule.x_test[0:1, :], features=model.output_ml_dblock.names_list, return_sensitivities=False)
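To show what a horizontal bar chart of sensitivities conveys without depending on a plotting library, the sketch below renders a flat mapping of input names to sensitivity values as text bars, with bar length proportional to the absolute value. The helper `text_barh`, the flat `{input_name: value}` layout and the largest-first ordering are assumptions for illustration; they do not describe the actual figure or the nested dictionary returned by `plot`.

```python
from typing import Dict, List


def text_barh(sens: Dict[str, float], width: int = 20) -> List[str]:
    """Hypothetical helper: render {input_name: sensitivity} as text bars.

    Bars are sorted by absolute sensitivity (largest first) and scaled so
    the largest magnitude spans `width` characters; the sign is shown as
    '+' or '-' in front of each bar.
    """
    ordered = sorted(sens.items(), key=lambda kv: abs(kv[1]), reverse=True)
    max_abs = max(abs(v) for _, v in ordered) or 1.0  # avoid division by zero
    label_w = max(len(k) for k, _ in ordered)
    lines = []
    for name, value in ordered:
        bar = "#" * round(width * abs(value) / max_abs)
        sign = "-" if value < 0 else "+"
        lines.append(f"{name.ljust(label_w)} {sign}{bar} {value:.3g}")
    return lines


bars = text_barh({"x0": 3.0, "x1": -2.0, "x2": 0.5})
for line in bars:
    print(line)
```

Sorting by magnitude keeps the most influential inputs at the top, which is usually what a reader of a sensitivity chart looks for first.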