quantus.metrics.robustness.relative_input_stability module
- final class quantus.metrics.robustness.relative_input_stability.RelativeInputStability(nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[np.ndarray], np.ndarray] | None = None, normalise_func_kwargs: Dict[str, ...] | None = None, perturb_func: Callable | None = None, perturb_func_kwargs: Dict[str, ...] | None = None, return_aggregate: bool = False, aggregate_func: Callable[[np.ndarray], np.float] | None = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-06, default_plot_func: Callable | None = None, return_nan_when_prediction_changes: bool = True, **kwargs)
Bases: Metric[List[float]]
Relative Input Stability measures the stability of an explanation with respect to changes in the input data.
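$$\mathrm{RIS}(x, x', e_x, e_{x'}) = \max_{x'} \frac{\left\| \frac{e_x - e_{x'}}{e_x} \right\|_p}{\max\left( \left\| \frac{x - x'}{x} \right\|_p, \epsilon_{\min} \right)}$$
The divisions inside the norms are element-wise, and the maximum is taken over perturbed inputs x' in the neighborhood of x.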
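As a worked instance of the formula, the snippet below evaluates the ratio inside the max for a single perturbation x', so the max itself is trivial. It is a minimal sketch assuming p = 2 and inputs and explanations without zero entries; ris_ratio is a hypothetical helper, not part of Quantus.

```python
import numpy as np

def ris_ratio(x, x_prime, e_x, e_x_prime, p=2, eps_min=1e-6):
    # Hypothetical helper, not a Quantus function.
    # Relative (element-wise) change of the explanation, measured with a p-norm.
    num = np.linalg.norm((e_x - e_x_prime) / e_x, ord=p)
    # Relative change of the input, floored at eps_min to avoid division by zero.
    den = max(np.linalg.norm((x - x_prime) / x, ord=p), eps_min)
    return num / den

x = np.array([1.0, 2.0])
x_prime = np.array([1.1, 2.0])      # 10% change in the first feature
e_x = np.array([0.5, 1.0])
e_x_prime = np.array([0.6, 1.0])    # 20% change in the first attribution

print(ris_ratio(x, x_prime, e_x, e_x_prime))  # approximately 2.0
```

A value of 2 means the attribution changed, in relative terms, twice as much as the input. Since score_direction is 'lower', smaller values indicate a more stable explanation method.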
- References:
1) Chirag Agarwal et al., 2022. "Rethinking Stability for Attribution-based Explanations." https://arxiv.org/abs/2203.06877
- Attributes:
_name: The name of the metric.
_data_applicability: The data types that the metric implementation currently supports.
_models: The model types that this metric can work with.
score_direction: How to interpret the scores, i.e., whether higher or lower values are considered better.
evaluation_category: The property or explanation quality that this metric measures.
- Attributes:
disable_warnings: A helper to avoid polluting test outputs with warnings.
display_progressbar: A helper to avoid polluting test outputs with tqdm progress bars.
get_params: List the parameters of the metric.
Methods
__call__(model, x_batch, y_batch[, ...]): For each image x, computes the maximum relative input stability objective over a set of perturbed inputs.
batch_preprocess(data_batch): If data_batch has no a_batch, computes explanations.
custom_batch_preprocess(*, model, x_batch, ...): Implement this method if you need custom preprocessing of data, or simply to create or initialise additional attributes or assertions before a data_batch can be evaluated.
custom_postprocess(*, model, x_batch, ...): Implement this method if you need custom postprocessing of results or additional attributes.
custom_preprocess(*, model, x_batch, ...): Implement this method if you need custom preprocessing of data, model alteration, or simply to create or initialise additional attributes or assertions.
evaluate_batch(model, x_batch, y_batch, ...): Evaluates the metric on a single batch of inputs and returns the batched results.
explain_batch(model, x_batch, y_batch): Compute explanations, normalise them and take their absolute values (if configured so during metric initialization). This method should primarily be used if you need to generate additional explanations in the metric's body. It encapsulates the typical Quantus pre- and postprocessing approach and does the following, in order: call model.shape_input (if a ModelInterface instance was provided); unwrap the model (if a ModelInterface instance was provided); call explain_func; expand the attribution channel; (optionally) normalise a_batch; (optionally) take np.abs of a_batch.
general_preprocess(model, x_batch, y_batch, ...): Prepares all necessary variables for evaluation.
generate_batches(data, batch_size): Creates an iterator over all batched instances in the data dictionary.
interpret_scores(): Get an interpretation of the scores.
plot([plot_func, show, path_to_save]): Basic plotting functionality for the Metric class.
relative_input_stability_objective(x, xs, ...): Computes the relative input stability maximization objective as defined in https://arxiv.org/pdf/2203.06877.pdf.
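The post-processing order listed for explain_batch (expand the attribution channel, optionally normalise, optionally take np.abs) can be sketched in NumPy. This is not Quantus's implementation: postprocess_attributions and normalise_by_max_abs are hypothetical helpers, and a channel-first layout is assumed.

```python
import numpy as np

def normalise_by_max_abs(a: np.ndarray) -> np.ndarray:
    # Hypothetical normaliser: scale each sample so that its max |value| is 1.
    axes = tuple(range(1, a.ndim))
    scale = np.max(np.abs(a), axis=axes, keepdims=True)
    return a / np.where(scale == 0, 1.0, scale)

def postprocess_attributions(a_batch, x_batch, normalise=False, abs_=False,
                             normalise_func=normalise_by_max_abs):
    # Hypothetical helper. Add a missing channel axis (channel-first assumed).
    if a_batch.ndim < x_batch.ndim:
        a_batch = np.expand_dims(a_batch, axis=1)
    # Optionally normalise, then optionally take the absolute value.
    if normalise:
        a_batch = normalise_func(a_batch)
    if abs_:
        a_batch = np.abs(a_batch)
    return a_batch

x = np.zeros((2, 3, 4, 4))                          # 2 RGB images, 4x4 pixels
a = np.random.default_rng(0).normal(size=(2, 4, 4))  # attributions without channel axis
out = postprocess_attributions(a, x, normalise=True, abs_=True)
print(out.shape)  # (2, 1, 4, 4)
```

The parameter is named abs_ in this sketch only to avoid shadowing Python's built-in abs.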
- __call__(model: tf.keras.Model | torch.nn.Module, x_batch: np.ndarray, y_batch: np.ndarray, model_predict_kwargs: Dict[str, ...] | None = None, explain_func: Callable | None = None, explain_func_kwargs: Dict[str, ...] | None = None, a_batch: np.ndarray | None = None, device: str | None = None, softmax: bool = False, channel_first: bool = True, batch_size: int = 64, **kwargs) List[float]
- For each image x:
1) Generate nr_samples perturbed inputs xs in the neighborhood of x.
2) Compute the explanations e_x and e_xs.
3) Compute the relative input stability objective and take its maximum with respect to xs.
In practice, the maximum is taken over a finite xs_batch.
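The three steps above can be sketched end to end with a toy model. Everything here is an assumption for illustration: the linear model W, the gradient-times-input explainer, the Gaussian noise level, and the helper names are hypothetical and do not reproduce Quantus's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 8))  # hypothetical linear model: 3 classes, 8 features

def predict(x):
    return np.argmax(x @ W.T, axis=-1)

def explain(x, y):
    # Hypothetical explainer: gradient * input of the class-y logit, i.e. W[y] * x.
    return W[y] * x

def ris_objective(x, xs, e_x, e_xs, eps_min=1e-6):
    # Element-wise relative changes, L2 norm over the feature axis.
    num = np.linalg.norm((e_x - e_xs) / e_x, axis=-1)
    den = np.linalg.norm((x - xs) / x, axis=-1)
    return num / np.maximum(den, eps_min)

def relative_input_stability(x_batch, y_batch, nr_samples=20, noise_std=0.01):
    scores = []
    for x, y in zip(x_batch, y_batch):
        # Step 1: nr_samples perturbed copies of x (Gaussian noise; std is an assumption).
        xs = x + rng.normal(scale=noise_std, size=(nr_samples, x.size))
        # Step 2: explanations for x and for every perturbed input.
        e_x, e_xs = explain(x, y), explain(xs, y)
        # Step 3: maximum of the objective over the finite set of perturbations.
        scores.append(np.max(ris_objective(x, xs, e_x, e_xs)))
    return np.array(scores)

x_batch = rng.normal(size=(4, 8))
y_batch = predict(x_batch)
scores = relative_input_stability(x_batch, y_batch)
print(scores)  # all approximately 1.0 for this linear gradient * input explainer
```

Because this toy explanation scales linearly with the input, the relative change of e_x equals the relative change of x feature by feature, so every score is about 1.0; a nonlinear model or explainer would generally produce other values.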
- Parameters:
- model: tf.keras.Model, torch.nn.Module
A torch or tensorflow model that is subject to explanation.
- x_batch: np.ndarray
A 4D tensor representing a batch of input images.
- y_batch: np.ndarray
A 1D tensor representing the predicted labels for x_batch.
- model_predict_kwargs: dict, optional
Keyword arguments to be passed to the model’s predict method.
- explain_func: callable, optional
Function used to generate explanations.
- explain_func_kwargs: dict, optional
Keyword arguments to be passed to explain_func on call.
- a_batch: np.ndarray, optional
4D tensor with pre-computed explanations for the x_batch.
- device: str, optional
Device on which torch should perform computations.
- softmax: boolean, optional
Indicates whether to use softmax probabilities or logits in model prediction. This is used for this __call__ only and won’t be saved as attribute. If None, self.softmax is used.
- channel_first: boolean, optional
Indicates whether the image dimensions are channel first or channel last. Inferred from the input shape if None.
- batch_size: int
The batch size to be used.
- kwargs:
Not used; deprecated.
- Returns:
- relative_input_stability: float, np.ndarray
A float if return_aggregate=True, otherwise an np.ndarray of floats.
- __init__(nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[np.ndarray], np.ndarray] | None = None, normalise_func_kwargs: Dict[str, ...] | None = None, perturb_func: Callable | None = None, perturb_func_kwargs: Dict[str, ...] | None = None, return_aggregate: bool = False, aggregate_func: Callable[[np.ndarray], np.float] | None = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-06, default_plot_func: Callable | None = None, return_nan_when_prediction_changes: bool = True, **kwargs)
- Parameters:
- nr_samples: int
The number of perturbed samples generated per input, default=200.
- abs: boolean
Indicates whether the absolute value is applied to the attributions.
- normalise: boolean
Flag stating whether the attributions should be normalised.
- normalise_func: callable
Attribution normalisation function applied in case normalise=True.
- normalise_func_kwargs: dict
Keyword arguments to be passed to normalise_func on call, default={}.
- perturb_func: callable
Input perturbation function. If None, the default value is used, default=gaussian_noise.
- perturb_func_kwargs: dict
Keyword arguments to be passed to perturb_func, default={}.
- return_aggregate: boolean
Indicates if an aggregated score should be computed over all instances.
- aggregate_func: callable
Callable that aggregates the scores given an evaluation call.
- disable_warnings: boolean
Indicates whether warnings are disabled, default=False.
- display_progressbar: boolean
Indicates whether a tqdm-progress-bar is printed, default=False.
- default_plot_func: callable
Callable that plots the metric's results.
- eps_min: float
Small constant to prevent division by 0 in relative_input_stability_objective, default=1e-6.
- return_nan_when_prediction_changes: boolean
When set to True, the metric evaluates to NaN if the prediction changes after the perturbation is applied, default=True.
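The interplay of return_nan_when_prediction_changes and aggregate_func can be illustrated with plain NumPy. This is a sketch under assumptions: mask_changed_predictions is a hypothetical helper, the per-instance score is assumed to be a plain np.max over perturbations (so a single changed prediction makes the whole instance NaN, consistent with the description above), and np.nanmean is only one possible aggregate_func.

```python
import numpy as np

def mask_changed_predictions(objective, y_x, y_xs, return_nan_when_prediction_changes=True):
    # Hypothetical helper: NaN out perturbations whose predicted label differs from x's.
    objective = np.asarray(objective, dtype=float).copy()
    if return_nan_when_prediction_changes:
        objective[np.asarray(y_xs) != y_x] = np.nan
    return objective

# Objective values for 4 perturbations of one input; the third one flips the label.
obj = np.array([0.8, 1.5, 9.0, 1.2])
masked = mask_changed_predictions(obj, y_x=2, y_xs=np.array([2, 2, 0, 2]))
instance_score = np.max(masked)  # np.max propagates NaN, so this instance scores NaN

# Aggregating per-instance scores: np.mean propagates NaN, np.nanmean skips it.
scores = np.array([1.3, instance_score, 0.9])
print(np.mean(scores))     # nan
print(np.nanmean(scores))  # approximately 1.1
```

Choosing a NaN-aware aggregate_func keeps instances with stable predictions in the summary, whereas a plain mean turns the whole aggregate into NaN as soon as one instance is discarded.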
- data_applicability: ClassVar[Set[DataType]] = {DataType.IMAGE, DataType.TABULAR, DataType.TIMESERIES}
- evaluate_batch(model: ModelInterface, x_batch: ndarray, y_batch: ndarray, a_batch: ndarray, **kwargs) ndarray
- Parameters:
- model: tf.keras.Model, torch.nn.Module
A torch or tensorflow model that is subject to explanation.
- x_batch: np.ndarray
4D tensor representing batch of input images.
- y_batch: np.ndarray
A 1D tensor representing the predicted labels for x_batch.
- a_batch: np.ndarray, optional
4D tensor with pre-computed explanations for the x_batch.
- kwargs:
Unused.
- Returns:
- ris: np.ndarray
The batched evaluation results.
- evaluation_category: ClassVar[EvaluationCategory] = 'Robustness'
- name: ClassVar[str] = 'Relative Input Stability'
- relative_input_stability_objective(x: ndarray, xs: ndarray, e_x: ndarray, e_xs: ndarray) ndarray
Computes the relative input stability maximization objective as defined in https://arxiv.org/pdf/2203.06877.pdf.
- Parameters:
- x: np.ndarray
Batch of images.
- xs: np.ndarray
Batch of perturbed images.
- e_x: np.ndarray
Explanations for x.
- e_xs: np.ndarray
Explanations for xs.
- Returns:
- ris_obj: np.ndarray
RIS maximization objective.
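A batched version of the objective, with the same argument names as relative_input_stability_objective, might look as follows. This is a sketch, not the library code: it assumes x and xs share the shape (batch, ...), uses the L2 norm over all non-batch axes, and replaces exact zeros in the element-wise denominators x and e_x with eps_min, a zero-handling choice made here for illustration.

```python
import numpy as np

def relative_input_stability_objective(x, xs, e_x, e_xs, eps_min=1e-6):
    # Sketch only: one RIS objective value per sample in the batch.
    def safe(d):
        # Assumption: replace exact zeros by eps_min before element-wise division.
        return np.where(d == 0, eps_min, d)

    num = (e_x - e_xs) / safe(e_x)
    den = (x - xs) / safe(x)
    # L2 norm over all non-batch axes (channels, height, width, ...).
    num_norm = np.linalg.norm(num.reshape(len(num), -1), axis=1)
    den_norm = np.linalg.norm(den.reshape(len(den), -1), axis=1)
    return num_norm / np.maximum(den_norm, eps_min)

x = np.ones((2, 1, 2, 2))
xs = x * 1.01                     # 1% relative input change everywhere
e_x = np.full((2, 1, 2, 2), 2.0)
e_xs = e_x * 1.03                 # 3% relative explanation change everywhere
print(relative_input_stability_objective(x, xs, e_x, e_xs))  # approximately [3. 3.]
```

If the objectives for nr_samples perturbed batches are stacked into an array of shape (nr_samples, batch), the per-input maximum from the __call__ description corresponds to np.max(objectives, axis=0).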
- score_direction: ClassVar[ScoreDirection] = 'lower'