quantus.metrics.robustness.relative_output_stability module

final class quantus.metrics.robustness.relative_output_stability.RelativeOutputStability(nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[np.ndarray], np.ndarray] | None = None, normalise_func_kwargs: Dict[str, ...] | None = None, perturb_func: Callable | None = None, perturb_func_kwargs: Dict[str, ...] | None = None, return_aggregate: bool = False, aggregate_func: Callable[[np.ndarray], np.float] | None = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-06, default_plot_func: Callable | None = None, return_nan_when_prediction_changes: bool = True, **kwargs)

Bases: Metric[List[float]]

Relative Output Stability leverages the stability of an explanation with respect to the change in the output logits.

ROS(x, x', e_x, e_{x'}) = \max_{x'} \frac{\left\| \frac{e_x - e_{x'}}{e_x} \right\|_p}{\max\left( \| h(x) - h(x') \|_p, \epsilon_{min} \right)},

where h(x) and h(x’) are the output logits, and e_x and e_x’ the explanations, for x and x’ respectively.
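As a worked illustration of the ratio inside the maximum, the following sketch evaluates it for a single pair (x, x’) with p = 2. The arrays are made-up toy values, not Quantus output, and the variable names are chosen for this example only.

```python
import numpy as np

# Toy values, chosen only for illustration (not produced by any model).
e_x = np.array([1.0, 2.0])    # explanation for x
e_xp = np.array([1.5, 1.0])   # explanation for the perturbed input x'
h_x = np.array([2.0, 0.0])    # output logits for x
h_xp = np.array([1.7, 0.4])   # output logits for x'
eps_min = 1e-6

# Numerator: L2 norm of the element-wise relative change in the explanation,
# here ||[-0.5, 0.5]||_2.
numerator = np.linalg.norm((e_x - e_xp) / e_x)

# Denominator: L2 norm of the logit change, floored at eps_min,
# here ||[0.3, -0.4]||_2 = 0.5.
denominator = max(np.linalg.norm(h_x - h_xp), eps_min)

ros_pair = numerator / denominator  # approximately 1.4142
```

The eps_min floor only matters when the logits barely move; in that case the ratio is bounded by the numerator divided by eps_min instead of diverging.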

References:

1) Chirag Agarwal et al., 2022. “Rethinking stability for attribution based explanations.” https://arxiv.org/pdf/2203.06877.pdf

Attributes:
  • _name: The name of the metric.

  • _data_applicability: The data types that the metric implementation currently supports.

  • _models: The model types that this metric can work with.

  • score_direction: How to interpret the scores, i.e. whether higher or lower values are considered better.

  • evaluation_category: The property or explanation quality that this metric measures.

Attributes
disable_warnings

A helper to avoid polluting test outputs with warnings.

display_progressbar

A helper to avoid polluting test outputs with tqdm progress bars.

get_params

List parameters of metric.

Methods

__call__(model, x_batch, y_batch[, ...])

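For each image x, generate perturbed inputs, compute their explanations, and take the maximum of the relative output stability objective.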

batch_preprocess(data_batch)

Computes explanations if data_batch has no a_batch.

custom_batch_preprocess(*, model, x_batch, ...)

Implement this method if you need custom preprocessing of data or simply for creating/initialising additional attributes or assertions before a data_batch can be evaluated.

custom_postprocess(*, model, x_batch, ...)

Implement this method if you need custom postprocessing of results or additional attributes.

custom_preprocess(*, model, x_batch, ...)

Implement this method if you need custom preprocessing of data, model alteration or simply for creating/initialising additional attributes or assertions.

evaluate_batch(model, x_batch, y_batch, ...)

explain_batch(model, x_batch, y_batch)

Compute explanations, then normalise them and take absolute values (if so configured during metric initialisation). This method should primarily be used if you need to generate additional explanations in the metric's body. It encapsulates the pre- and postprocessing approach typical for Quantus. It will:
  • call model.shape_input (if a ModelInterface instance was provided)

  • unwrap the model (if a ModelInterface instance was provided)

  • call explain_func

  • expand the attribution channel

  • (optionally) normalise a_batch

  • (optionally) take np.abs of a_batch.

general_preprocess(model, x_batch, y_batch, ...)

Prepares all necessary variables for evaluation.

generate_batches(data, batch_size)

Creates iterator to iterate over all batched instances in data dictionary.

interpret_scores()

Get an interpretation of the scores.

plot([plot_func, show, path_to_save])

Basic plotting functionality for Metric class.

relative_output_stability_objective(h_x, ...)

Computes the relative output stability maximization objective as defined by the authors in https://arxiv.org/pdf/2203.06877.pdf.

__call__(model: tf.keras.Model | torch.nn.Module, x_batch: np.ndarray, y_batch: np.ndarray, model_predict_kwargs: Dict[str, ...] | None = None, explain_func: Callable | None = None, explain_func_kwargs: Dict[str, ...] | None = None, a_batch: np.ndarray | None = None, device: str | None = None, softmax: bool = False, channel_first: bool = True, batch_size: int = 64, **kwargs) → List[float]
For each image x:
  • Generate nr_samples perturbed inputs xs in the neighborhood of x.

  • Compute explanations e_x and e_xs.

  • Compute the relative output stability objective and find its maximum value with respect to xs.

  • In practice, the maximum is taken over a finite xs_batch.
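The steps above can be sketched for a single input as follows. This is a minimal NumPy illustration under stated assumptions, not the Quantus implementation: `ros_single_input`, `predict`, `explain`, `noise_std` and the linear toy model are hypothetical names introduced here, Gaussian noise stands in for the perturbation function, and p = 2 is assumed for the norms.

```python
import numpy as np

def ros_single_input(x, predict, explain, nr_samples=200, noise_std=0.01,
                     eps_min=1e-6, seed=0):
    """Max of the ROS ratio over a finite set of Gaussian perturbations of x."""
    rng = np.random.default_rng(seed)
    h_x, e_x = predict(x), explain(x)
    # Replace exact zeros so the element-wise division below is defined.
    safe_e_x = np.where(e_x == 0, eps_min, e_x)
    ratios = []
    for _ in range(nr_samples):
        xs = x + rng.normal(0.0, noise_std, size=x.shape)  # perturbed input
        h_xs, e_xs = predict(xs), explain(xs)
        numerator = np.linalg.norm((e_x - e_xs) / safe_e_x)
        denominator = max(np.linalg.norm(h_x - h_xs), eps_min)
        ratios.append(numerator / denominator)
    # The maximum over the neighbourhood is approximated by the sample maximum.
    return max(ratios)

# Hypothetical stand-ins: a linear "model" and a gradient-times-input explainer
# for a fixed target class (class 1).
W = np.array([[1.0, -2.0, 0.5], [0.3, 0.8, -1.0]])

def predict(x):
    return W @ x

def explain(x):
    return W[1] * x

x = np.array([1.0, 0.5, -1.5])
score = ros_single_input(x, predict, explain, nr_samples=50)
```

Because only finitely many perturbations are drawn, the result is a lower bound on the true maximum over the neighbourhood, and drawing more samples can only increase it.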

Parameters:
model: tf.keras.Model, torch.nn.Module

A torch or tensorflow model that is subject to explanation.

x_batch: np.ndarray

4D tensor representing batch of input images.

y_batch: np.ndarray

1D tensor, representing predicted labels for the x_batch.

model_predict_kwargs: dict, optional

Keyword arguments to be passed to the model’s predict method.

explain_func: callable, optional

Function used to generate explanations.

explain_func_kwargs: dict, optional

Keyword arguments to be passed to explain_func on call.

a_batch: np.ndarray, optional

4D tensor with pre-computed explanations for the x_batch.

device: str, optional

Device on which torch should perform computations.

softmax: boolean, optional

Indicates whether to use softmax probabilities or logits in model prediction. This is used for this __call__ only and won’t be saved as an attribute. If None, self.softmax is used.

channel_first: boolean, optional

Indicates whether the image dimensions are channel first or channel last. Inferred from the input shape if None.

batch_size: int

The batch size to be used.

kwargs:

Not used; deprecated.

Returns:
relative output stability: float, np.ndarray

A float if return_aggregate=True, otherwise an np.ndarray of floats.

__init__(nr_samples: int = 200, abs: bool = False, normalise: bool = False, normalise_func: Callable[[np.ndarray], np.ndarray] | None = None, normalise_func_kwargs: Dict[str, ...] | None = None, perturb_func: Callable | None = None, perturb_func_kwargs: Dict[str, ...] | None = None, return_aggregate: bool = False, aggregate_func: Callable[[np.ndarray], np.float] | None = None, disable_warnings: bool = False, display_progressbar: bool = False, eps_min: float = 1e-06, default_plot_func: Callable | None = None, return_nan_when_prediction_changes: bool = True, **kwargs)
Parameters:
nr_samples: int

The number of samples iterated, default=200.

abs: boolean

Indicates whether absolute operation is applied on the attribution.

normalise: boolean

Flag stating whether the attributions should be normalised.

normalise_func: callable

Attribution normalisation function applied in case normalise=True.

normalise_func_kwargs: dict

Keyword arguments to be passed to normalise_func on call, default={}.

perturb_func: callable

Input perturbation function. If None, the default value is used, default=gaussian_noise.

perturb_func_kwargs: dict

Keyword arguments to be passed to perturb_func, default={}.

return_aggregate: boolean

Indicates if an aggregated score should be computed over all instances.

aggregate_func: callable

Callable that aggregates the scores given an evaluation call.

disable_warnings: boolean

Indicates whether warnings are suppressed, default=False.

display_progressbar: boolean

Indicates whether a tqdm-progress-bar is printed, default=False.

default_plot_func: callable

Callable that plots the metrics result.

eps_min: float

Small constant to prevent division by 0 in relative_output_stability_objective, default=1e-6.

return_nan_when_prediction_changes: boolean

When set to true, the metric will be evaluated to NaN if the prediction changes after the perturbation is applied, default=True.
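The effect of return_nan_when_prediction_changes can be pictured with the sketch below. It assumes that "the prediction changes" means the arg-max class of the logits differs between x and its perturbation; `mask_changed_predictions` is a hypothetical helper written for this example and is not part of Quantus, and the arrays are toy values.

```python
import numpy as np

def mask_changed_predictions(scores, h_x, h_xs):
    """Set scores to NaN where the arg-max class differs after perturbation."""
    changed = np.argmax(h_x, axis=-1) != np.argmax(h_xs, axis=-1)
    return np.where(changed, np.nan, scores)

scores = np.array([0.8, 1.3, 2.1])                      # toy ROS ratios
h_x = np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.9]])    # logits for x
h_xs = np.array([[1.9, 0.2], [0.3, 1.4], [0.8, 1.1]])   # third prediction flips

masked = mask_changed_predictions(scores, h_x, h_xs)

# A NaN-aware reduction skips the masked entry.
largest_valid = np.nanmax(masked)
```

If NaN entries can occur, a NaN-aware reduction such as np.nanmean or np.nanmax is one way to aggregate the remaining scores without the result itself becoming NaN.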

data_applicability: ClassVar[Set[DataType]] = {DataType.IMAGE, DataType.TABULAR, DataType.TIMESERIES}
evaluate_batch(model: ModelInterface, x_batch: ndarray, y_batch: ndarray, a_batch: ndarray, **kwargs) → ndarray
Parameters:
model: tf.keras.Model, torch.nn.Module

A torch or tensorflow model that is subject to explanation.

x_batch: np.ndarray

4D tensor representing batch of input images.

y_batch: np.ndarray

1D tensor, representing predicted labels for the x_batch.

a_batch: np.ndarray, optional

4D tensor with pre-computed explanations for the x_batch.

kwargs:

Unused.

Returns:
ros: np.ndarray

The relative output stability scores for the batch.

evaluation_category: ClassVar[EvaluationCategory] = 'Robustness'
model_applicability: ClassVar[Set[ModelType]] = {ModelType.TF, ModelType.TORCH}
name: ClassVar[str] = 'Relative Output Stability'
relative_output_stability_objective(h_x: ndarray, h_xs: ndarray, e_x: ndarray, e_xs: ndarray) → ndarray

Computes the relative output stability maximization objective as defined by the authors in https://arxiv.org/pdf/2203.06877.pdf.

Parameters:
h_x: np.ndarray

Output logits for x_batch.

h_xs: np.ndarray

Output logits for xs_batch.

e_x: np.ndarray

Explanations for x.

e_xs: np.ndarray

Explanations for xs.

Returns:
ros_obj: np.ndarray

ROS maximization objective.
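A batched version of this objective might look like the sketch below. It is an illustration under assumptions rather than the library's code: `ros_objective` is a hypothetical free function, the norms are taken with p = 2 over all non-batch axes, and exact zeros in e_x are replaced by eps_min so that the division is defined.

```python
import numpy as np

def ros_objective(h_x, h_xs, e_x, e_xs, eps_min=1e-6):
    """Batched ROS ratio, returning one value per batch element.

    h_x, h_xs: (batch, num_classes) logits.
    e_x, e_xs: (batch, ...) explanations of identical shape.
    """
    batch = e_x.shape[0]
    # Replace exact zeros in e_x so the element-wise division is defined.
    safe_e_x = np.where(e_x == 0, eps_min, e_x)
    # Flatten all non-batch axes and take the L2 norm per batch element.
    relative_change = ((e_x - e_xs) / safe_e_x).reshape(batch, -1)
    numerator = np.linalg.norm(relative_change, axis=-1)
    # Floor the logit-change norm at eps_min to avoid division by zero.
    denominator = np.maximum(np.linalg.norm(h_x - h_xs, axis=-1), eps_min)
    return numerator / denominator

# Toy batch: 4 single-channel 8x8 explanations and 10-class logits.
rng = np.random.default_rng(0)
e_x = rng.uniform(0.5, 1.5, size=(4, 1, 8, 8))
e_xs = e_x + rng.normal(0.0, 0.05, size=e_x.shape)
h_x = rng.normal(size=(4, 10))
h_xs = h_x + rng.normal(0.0, 0.05, size=h_x.shape)

ros_obj = ros_objective(h_x, h_xs, e_x, e_xs)  # shape (4,)
```

Calling such a function once per perturbed batch and taking the element-wise maximum of the results would give the per-input score described for __call__.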

score_direction: ClassVar[ScoreDirection] = 'lower'