quantus.metrics.robustness.max_sensitivity module

This module contains the implementation of the Max-Sensitivity metric.

class quantus.metrics.robustness.max_sensitivity.MaxSensitivity(similarity_func: Callable | None = None, norm_numerator: Callable | None = None, norm_denominator: Callable | None = None, nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[ndarray], ndarray] | None = None, normalise_func_kwargs: Dict[str, Any] | None = None, perturb_func: Callable | None = None, lower_bound: float = 0.2, upper_bound: None | float = None, perturb_func_kwargs: Dict[str, Any] | None = None, return_aggregate: bool = False, aggregate_func: Callable | None = None, default_plot_func: Callable | None = None, disable_warnings: bool = False, display_progressbar: bool = False, return_nan_when_prediction_changes: bool = False, **kwargs)

Bases: Metric[List[float]]

Implementation of Max-Sensitivity by Yeh et al., 2019.
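For reference, Yeh et al. define the max-sensitivity of an explanation functional Φ for a model f at input x within radius r as the largest explanation change over all inputs in that ball. A Monte Carlo estimate replaces the maximum over the ball with a maximum over N sampled perturbations δ_i. Reading N as nr_samples is an assumption about how the parameter maps onto the formula.

```latex
\mathrm{SENS}_{\max}(\Phi, f, x, r)
  = \max_{\lVert y - x \rVert \le r} \bigl\lVert \Phi(f, y) - \Phi(f, x) \bigr\rVert
  \approx \max_{i = 1, \dots, N} \bigl\lVert \Phi(f, x + \delta_i) - \Phi(f, x) \bigr\rVert,
  \qquad \lVert \delta_i \rVert \le r
```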

Uses a Monte Carlo sampling-based approximation to measure how much explanations change under slight perturbations of the input, capturing the maximum sensitivity.
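A minimal NumPy sketch of such a Monte Carlo estimate for a single input. It assumes uniform noise in [-radius, radius] per feature and that each explanation change is divided by the norm of the unperturbed input (mirroring the norm_numerator/norm_denominator parameters). The noise model, the norms, and the names max_sensitivity_mc and explain_fn are all hypothetical illustrations, not the Quantus API.

```python
import numpy as np


def max_sensitivity_mc(explain_fn, x, nr_samples=200, radius=0.2, seed=0):
    """Monte Carlo estimate of max-sensitivity for one input (illustrative sketch).

    Assumptions: perturbations are drawn uniformly from [-radius, radius] per
    feature, and each explanation change (L2/Frobenius norm of the difference)
    is divided by the norm of the unperturbed input.
    """
    rng = np.random.default_rng(seed)
    a = explain_fn(x)  # explanation of the original input
    x_norm = np.linalg.norm(x)
    worst = 0.0
    for _ in range(nr_samples):
        delta = rng.uniform(-radius, radius, size=x.shape)
        a_perturbed = explain_fn(x + delta)  # re-explain the perturbed input
        ratio = np.linalg.norm(a_perturbed - a) / x_norm
        worst = max(worst, float(ratio))  # keep the largest change seen so far
    return worst


# Example: the gradient of sum(x**2) is 2x, so its explanation moves with the input.
x = np.ones(4)
print(max_sensitivity_mc(lambda v: 2 * v, x, nr_samples=50, radius=0.1))
```

A constant explanation (for instance the gradient of a linear model) yields a sensitivity of exactly 0, since no perturbation changes it; lower values therefore indicate more robust explanations.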

References:

1) Chih-Kuan Yeh et al. “On the (In)fidelity and Sensitivity of Explanations.” NeurIPS (2019): 10965-10976.
2) Umang Bhatt et al. “Evaluating and Aggregating Feature-based Model Explanations.” IJCAI (2020): 3016-3022.

Attributes:
  • name: The name of the metric.

  • data_applicability: The data types that the metric implementation currently supports.

  • model_applicability: The model types that this metric can work with.

  • score_direction: How to interpret the scores, i.e. whether higher or lower values are considered better.

  • evaluation_category: The property or explanation quality that this metric measures.

Attributes:
disable_warnings

A helper to avoid polluting test outputs with warnings.

display_progressbar

A helper to avoid polluting test outputs with tqdm progress bars.

get_params

List parameters of metric.

Methods

__call__(model, x_batch, y_batch[, a_batch, ...])

This implementation represents the main logic of the metric and makes the class object callable.

batch_preprocess(data_batch)

Computes explanations if data_batch has no a_batch.

custom_batch_preprocess(*, model, x_batch, ...)

Implement this method if you need custom preprocessing of data or simply for creating/initialising additional attributes or assertions before a data_batch can be evaluated.

custom_postprocess(*, model, x_batch, ...)

Implement this method if you need custom postprocessing of results or additional attributes.

custom_preprocess(**kwargs)

Implementation of custom_preprocess_batch.

evaluate_batch(model, x_batch, y_batch, ...)

Evaluates model and attributes on a single data batch and returns the batched evaluation result.

explain_batch(model, x_batch, y_batch)

Computes explanations, then normalises them and takes their absolute value (if configured so during metric initialisation). This method should primarily be used if you need to generate additional explanations in the metric's body. It encapsulates the typical Quantus pre- and postprocessing approach and does the following:

  • call model.shape_input (if a ModelInterface instance was provided)

  • unwrap model (if a ModelInterface instance was provided)

  • call explain_func

  • expand the attribution channel

  • (optionally) normalise a_batch

  • (optionally) take np.abs of a_batch

general_preprocess(model, x_batch, y_batch, ...)

Prepares all necessary variables for evaluation.

generate_batches(data, batch_size)

Creates an iterator over all batched instances in the data dictionary.

interpret_scores()

Get an interpretation of the scores.

plot([plot_func, show, path_to_save])

Basic plotting functionality for Metric class.
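The post-processing steps that explain_batch is described as performing (expanding the attribution channel, optional normalisation, optional absolute value) can be sketched in NumPy. postprocess_attributions is a hypothetical helper, not the Quantus implementation; it assumes channel-first inputs of shape (N, C, ...) and a divide-by-maximum-absolute-value normalisation in the spirit of normalise_by_max.

```python
import numpy as np


def postprocess_attributions(a_batch, x_batch, normalise=False, abs_=False):
    """Hypothetical helper mirroring the post-processing listed for explain_batch.

    Assumes channel-first inputs of shape (N, C, ...) and attributions that may
    lack the channel axis.
    """
    # Expand a missing channel axis so attributions line up with x_batch.
    if a_batch.ndim == x_batch.ndim - 1:
        a_batch = np.expand_dims(a_batch, axis=1)
    if normalise:
        # Assumed normalise_by_max-style step: divide each instance by its max |value|.
        axes = tuple(range(1, a_batch.ndim))
        max_abs = np.max(np.abs(a_batch), axis=axes, keepdims=True)
        a_batch = a_batch / np.where(max_abs == 0, 1.0, max_abs)
    if abs_:
        # Taking the absolute value comes after normalisation, as in the list above.
        a_batch = np.abs(a_batch)
    return a_batch


a = np.arange(18, dtype=float).reshape(2, 3, 3) - 9  # attributions without a channel axis
x = np.zeros((2, 1, 3, 3))  # channel-first inputs
out = postprocess_attributions(a, x, normalise=True, abs_=True)
print(out.shape)  # (2, 1, 3, 3)
```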

__call__(model, x_batch: ndarray, y_batch: ndarray, a_batch: ndarray | None = None, s_batch: ndarray | None = None, channel_first: bool | None = None, explain_func: Callable | None = None, explain_func_kwargs: Dict | None = None, model_predict_kwargs: Dict | None = None, softmax: bool | None = False, device: str | None = None, batch_size: int = 64, **kwargs) → List[float]

This implementation represents the main logic of the metric and makes the class object callable. It completes instance-wise evaluation of explanations (a_batch) with respect to input data (x_batch), output labels (y_batch) and a torch or tensorflow model (model).

Calls general_preprocess() with all relevant arguments, calls () on each instance, and saves results to evaluation_scores. Calls custom_postprocess() afterwards. Finally returns evaluation_scores.

Parameters:
model: torch.nn.Module, tf.keras.Model

A torch or tensorflow model that is subject to explanation.

x_batch: np.ndarray

A np.ndarray which contains the input data that are explained.

y_batch: np.ndarray

A np.ndarray which contains the output labels that are explained.

a_batch: np.ndarray, optional

A np.ndarray which contains pre-computed attributions i.e., explanations.

s_batch: np.ndarray, optional

A np.ndarray which contains segmentation masks that match the input.

channel_first: boolean, optional

Indicates whether the image dimensions are channel first or channel last. Inferred from the input shape if None.

explain_func: callable

Callable generating attributions.

explain_func_kwargs: dict, optional

Keyword arguments to be passed to explain_func on call.

model_predict_kwargs: dict, optional

Keyword arguments to be passed to the model’s predict method.

softmax: boolean

Indicates whether to use softmax probabilities or logits in model prediction. This is used for this __call__ only and won’t be saved as attribute. If None, self.softmax is used.

device: string

Indicates the device on which a torch.Tensor is or will be allocated: “cpu” or “gpu”.

kwargs: optional

Keyword arguments.

Returns:
evaluation_scores: list

A list with the evaluation scores of the concerned batch.

Examples:

# Minimal imports.
>> import quantus
>> from quantus import LeNet
>> import torch
>> import torchvision
>> from captum.attr import Saliency

# Enable GPU.
>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
>> model = LeNet()
>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model", map_location=device))
>> model = model.to(device).eval()

# Load the MNIST test set and make a loader.
>> test_set = torchvision.datasets.MNIST(root="./sample_data", train=False, download=True, transform=torchvision.transforms.ToTensor())
>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for XAI evaluation.
>> x_batch, y_batch = next(iter(test_loader))

# Generate Saliency attributions for the batch, then convert everything to numpy.
>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch.to(device), target=y_batch.to(device), abs=True).sum(dim=1)
>> a_batch_saliency = a_batch_saliency.detach().cpu().numpy()
>> x_batch, y_batch = x_batch.numpy(), y_batch.numpy()

# Initialise the metric and evaluate explanations by calling the metric instance.
# Max-Sensitivity re-explains every perturbed input, so it also needs an explain_func.
>> metric = quantus.MaxSensitivity(nr_samples=10, lower_bound=0.2, abs=True, normalise=False)
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency, explain_func=quantus.explain, explain_func_kwargs={"method": "Saliency"}, device=device)

__init__(similarity_func: Callable | None = None, norm_numerator: Callable | None = None, norm_denominator: Callable | None = None, nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[ndarray], ndarray] | None = None, normalise_func_kwargs: Dict[str, Any] | None = None, perturb_func: Callable | None = None, lower_bound: float = 0.2, upper_bound: None | float = None, perturb_func_kwargs: Dict[str, Any] | None = None, return_aggregate: bool = False, aggregate_func: Callable | None = None, default_plot_func: Callable | None = None, disable_warnings: bool = False, display_progressbar: bool = False, return_nan_when_prediction_changes: bool = False, **kwargs)
Parameters:
similarity_func: callable

Similarity function applied to compare the explanations of the input and the perturbed input. If None, the default value is used, default=difference.

norm_numerator: callable

Function for norm calculations on the numerator. If None, the default value is used, default=fro_norm

norm_denominator: callable

Function for norm calculations on the denominator. If None, the default value is used, default=fro_norm

nr_samples: integer

The number of perturbed samples drawn, default=200.

abs: boolean

Indicates whether absolute operation is applied on the attribution, default=False.

normalise: boolean

Indicates whether normalise operation is applied on the attribution, default=False.

normalise_func: callable

Attribution normalisation function applied in case normalise=True. If normalise_func=None, the default value is used, default=normalise_by_max.

normalise_func_kwargs: dict

Keyword arguments to be passed to normalise_func on call, default={}.

perturb_func: callable

Input perturbation function. If None, the default value is used, default=gaussian_noise.

lower_bound: float

The lower bound of the perturbation noise, default=0.2.

upper_bound: float, optional

The upper bound of the perturbation noise, default=None.

perturb_func_kwargs: dict

Keyword arguments to be passed to perturb_func, default={}.

return_aggregate: boolean

Indicates if an aggregated score should be computed over all instances.

aggregate_func: callable

Callable that aggregates the scores given an evaluation call.

default_plot_func: callable

Callable that plots the metrics result.

disable_warnings: boolean

Indicates whether warnings are suppressed, default=False.

display_progressbar: boolean

Indicates whether a tqdm progress bar is displayed, default=False.

return_nan_when_prediction_changes: boolean

If True, the metric evaluates to NaN if the prediction changes after the perturbation is applied, default=False.

kwargs: optional

Keyword arguments.
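A sketch of how return_nan_when_prediction_changes could affect the per-instance reduction over the Monte Carlo samples. reduce_sensitivities is a hypothetical helper; it assumes the predicted class is compared for every perturbed sample and that, when the flag is off, the maximum is simply taken over all samples.

```python
import numpy as np


def reduce_sensitivities(sensitivities, pred, preds_perturbed,
                         return_nan_when_prediction_changes=False):
    """Reduce the per-sample sensitivities of one instance to a single score.

    Hypothetical helper: `pred` is the predicted class for the original input,
    `preds_perturbed` the predicted class for each perturbed sample.
    """
    s = np.asarray(sensitivities, dtype=float)
    changed = np.asarray(preds_perturbed) != pred
    if return_nan_when_prediction_changes and changed.any():
        # At least one perturbation flipped the prediction: report NaN.
        return float("nan")
    # Otherwise keep the worst-case (maximum) sensitivity.
    return float(np.max(s))


print(reduce_sensitivities([0.1, 0.5, 0.3], 2, [2, 1, 2], True))   # nan
print(reduce_sensitivities([0.1, 0.5, 0.3], 2, [2, 1, 2], False))  # 0.5
```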

custom_preprocess(**kwargs) → None

Implementation of custom_preprocess_batch.

Parameters:
kwargs:

Unused.

Returns:
None
data_applicability: ClassVar[Set[DataType]] = {DataType.IMAGE, DataType.TABULAR, DataType.TIMESERIES}
evaluate_batch(model: ModelInterface, x_batch: ndarray, y_batch: ndarray, a_batch: ndarray, **kwargs) → ndarray

Evaluates model and attributes on a single data batch and returns the batched evaluation result.

Parameters:
model: ModelInterface

A ModelInterface that is subject to explanation.

x_batch: np.ndarray

The input to be evaluated on an instance-basis.

y_batch: np.ndarray

The output to be evaluated on an instance-basis.

a_batch: np.ndarray

The explanation to be evaluated on an instance-basis.

kwargs:

Unused.

Returns:
scores_batch: np.ndarray

The batched evaluation results.

evaluation_category: ClassVar[EvaluationCategory] = 'Robustness'
model_applicability: ClassVar[Set[ModelType]] = {ModelType.TF, ModelType.TORCH}
name: ClassVar[str] = 'Max-Sensitivity'
score_direction: ClassVar[ScoreDirection] = 'lower'