quantus.helpers.perturbation_utils module

quantus.helpers.perturbation_utils.changed_prediction_indices(model: ModelInterface, x_batch: np.ndarray, x_perturbed: np.ndarray, return_nan_when_prediction_changes: bool) → List[int]

Find the indices in the batch for which the predicted label has changed after applying the perturbation. If the metric's return_nan_when_prediction_changes is False, an empty list is returned.

Parameters:
return_nan_when_prediction_changes:

Instance attribute of perturbation metrics.

model:

Model used to obtain the predicted labels for x_batch and x_perturbed.

x_batch:

Batch of original inputs provided by user.

x_perturbed:

Batch of inputs after applying perturbation.

Returns:
changed_idx:

List of indices in the batch for which the predicted label has changed after the perturbation.
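The contract described above can be sketched in plain NumPy. This is a minimal illustration, not Quantus' implementation: it assumes the model exposes a `predict` method returning class scores of shape `(batch, n_classes)` and that the predicted label is the argmax of those scores. `changed_prediction_indices_sketch`, `PredictsScores` and `LinearToy` are hypothetical names.

```python
from typing import List, Protocol

import numpy as np


class PredictsScores(Protocol):
    # Hypothetical stand-in for ModelInterface: only predict() is assumed.
    def predict(self, x: np.ndarray) -> np.ndarray: ...


def changed_prediction_indices_sketch(
    model: PredictsScores,
    x_batch: np.ndarray,
    x_perturbed: np.ndarray,
    return_nan_when_prediction_changes: bool,
) -> List[int]:
    # Documented behaviour: nothing to report when the flag is off.
    if not return_nan_when_prediction_changes:
        return []
    # Assumption: predicted label = argmax over the class-score axis.
    labels_before = np.argmax(model.predict(x_batch), axis=1)
    labels_after = np.argmax(model.predict(x_perturbed), axis=1)
    # Batch positions where the label differs, as plain Python ints.
    return np.flatnonzero(labels_before != labels_after).tolist()


class LinearToy:
    # Hypothetical two-class linear model: scores = x @ w.
    def __init__(self, w: np.ndarray) -> None:
        self.w = w

    def predict(self, x: np.ndarray) -> np.ndarray:
        return x @ self.w


toy = LinearToy(np.array([[1.0, -1.0]]))  # label 0 if x > 0, else label 1
x_batch = np.array([[1.0], [2.0], [-1.0]])
x_perturbed = np.array([[1.0], [-2.0], [1.0]])
print(changed_prediction_indices_sketch(toy, x_batch, x_perturbed, True))   # [1, 2]
print(changed_prediction_indices_sketch(toy, x_batch, x_perturbed, False))  # []
```

Samples 1 and 2 flip sign under the perturbation, so their argmax labels change; sample 0 is unchanged.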

quantus.helpers.perturbation_utils.make_changed_prediction_indices_func(return_nan_when_prediction_changes: bool) → Callable[[ModelInterface, np.ndarray, np.ndarray], List[int]]

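A utility function to improve static analysis. It returns a callable that takes (model, x_batch, x_perturbed) and returns a list of batch indices, with return_nan_when_prediction_changes fixed in advance.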
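One plausible reading of this signature is a factory that fixes `return_nan_when_prediction_changes` once and hands back a three-argument callable, so a metric can store it as an attribute and type checkers see a concrete `Callable` type. The sketch below is an assumption along those lines, not the library's code: `make_changed_prediction_indices_func_sketch` and `IdentityToy` are hypothetical, and the model is assumed to have a `predict` method returning class scores of shape `(batch, n_classes)`.

```python
from typing import Any, Callable, List

import numpy as np

# Hypothetical alias mirroring the documented return type.
ChangedIndicesFunc = Callable[[Any, np.ndarray, np.ndarray], List[int]]


def make_changed_prediction_indices_func_sketch(
    return_nan_when_prediction_changes: bool,
) -> ChangedIndicesFunc:
    def changed_indices(
        model: Any, x_batch: np.ndarray, x_perturbed: np.ndarray
    ) -> List[int]:
        # The flag is captured from the enclosing scope, not passed per call.
        if not return_nan_when_prediction_changes:
            return []
        before = np.argmax(model.predict(x_batch), axis=1)
        after = np.argmax(model.predict(x_perturbed), axis=1)
        return np.flatnonzero(before != after).tolist()

    return changed_indices


class IdentityToy:
    # Hypothetical model whose class scores are the inputs themselves.
    def predict(self, x: np.ndarray) -> np.ndarray:
        return x


x_batch = np.array([[0.9, 0.1], [0.2, 0.8]])
x_perturbed = np.array([[0.1, 0.9], [0.3, 0.7]])

# e.g. bound once during metric initialization, then called per batch.
find_changed = make_changed_prediction_indices_func_sketch(True)
print(find_changed(IdentityToy(), x_batch, x_perturbed))  # [0]
```

A nested function is used here rather than `functools.partial` because its declared parameter and return types are fully visible to type checkers at each call site; whether Quantus makes the same choice is not stated in this reference.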

quantus.helpers.perturbation_utils.make_perturb_func(perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs) → PerturbFunc | functools.partial

A utility function that saves a few lines of code when initializing perturbation metrics.
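The return type `PerturbFunc | functools.partial` suggests the helper binds keyword arguments to the perturbation function when there are any, and otherwise returns the function unchanged. A minimal sketch under that assumption follows; the merge order (extra `**kwargs` overriding `perturb_func_kwargs`), the `PerturbFunc` alias and the `add_constant` perturbation are all hypothetical, not taken from Quantus.

```python
import functools
from typing import Any, Callable, Mapping, Optional, Union

import numpy as np

# Assumption: any callable that returns the perturbed array.
PerturbFunc = Callable[..., np.ndarray]


def make_perturb_func_sketch(
    perturb_func: PerturbFunc,
    perturb_func_kwargs: Optional[Mapping[str, Any]],
    **kwargs: Any,
) -> Union[PerturbFunc, functools.partial]:
    # Merge user kwargs with extra ones; extras win on key conflicts (assumption).
    merged = {**(perturb_func_kwargs or {}), **kwargs}
    if not merged:
        # Nothing to bind: hand back the original function.
        return perturb_func
    return functools.partial(perturb_func, **merged)


def add_constant(arr: np.ndarray, value: float = 0.0) -> np.ndarray:
    # Hypothetical perturbation: shift every element by a constant.
    return arr + value


f = make_perturb_func_sketch(add_constant, {"value": 1.0})
print(f(np.zeros(3)))  # [1. 1. 1.]
g = make_perturb_func_sketch(add_constant, None)
print(g is add_constant)  # True
```

Binding the keyword arguments once at initialization means the metric can later call the perturbation with only the array, without re-threading its settings through every call.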