quantus.functions.explanation_func module

This module contains explainer functions that can be used in conjunction with the metrics in the library.

quantus.functions.explanation_func.explain(model, inputs, targets, **kwargs) → ndarray

Explain inputs given a model, targets and an explanation method. Inputs are expected to have shape (batch_size, nr_channels, …) or (batch_size, …, nr_channels).

Parameters:
model: torch.nn.Module, tf.keras.Model

A model that is used for explanation.

inputs: np.ndarray

The inputs that ought to be explained.

targets: np.ndarray

The target labels that should be used in the explanation.

kwargs: optional

Keyword arguments. Pass as an “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the standalone function.

xai_lib: string, optional

XAI library: captum, tf-explain or zennit.

method: string, optional

XAI method (used with captum and tf-explain libraries).

attributor: string, optional

XAI method (used with zennit).

xai_lib_kwargs: dictionary, optional

Keyword arguments to be passed to the attribution function.

softmax: boolean, optional

Indicates whether the softmax activation in the last layer should be removed.

channel_first: boolean, optional

Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.

reduce_axes: tuple

Indicates the indices of the dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.

keepdims: boolean

Indicates whether the reduced axes are preserved (True) or removed (False).

Returns:
explanation: np.ndarray

Returns np.ndarray of same shape as inputs.
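
The effect of reduce_axes and keepdims can be illustrated with plain NumPy. This is a minimal sketch that assumes the reduction behaves like numpy.sum over the given axes; the attributions array is a hypothetical placeholder, not output of the library.

```python
import numpy as np

# Hypothetical attribution array: a batch of 8 channel-last RGB images.
attributions = np.ones((8, 28, 28, 3))

reduce_axes = (-1,)  # sum over the channel axis

# keepdims=True keeps the summed axis with size 1.
kept = attributions.sum(axis=reduce_axes, keepdims=True)
# keepdims=False drops the summed axis entirely.
dropped = attributions.sum(axis=reduce_axes, keepdims=False)
# An empty tuple performs no reduction, so the shape is unchanged.
untouched = attributions.sum(axis=(), keepdims=True)

print(kept.shape)       # (8, 28, 28, 1)
print(dropped.shape)    # (8, 28, 28)
print(untouched.shape)  # (8, 28, 28, 3)
```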

quantus.functions.explanation_func.generate_captum_explanation(model, inputs: ndarray, targets: ndarray, device: str | None = None, **kwargs) → ndarray

Generate explanation for a torch model with captum.

Parameters:
model: torch.nn.Module

A model that is used for explanation.

inputs: np.ndarray

The inputs that ought to be explained.

targets: np.ndarray

The target labels that should be used in the explanation.

device: string

Indicates the device on which a torch.Tensor is or will be allocated: “cpu” or “gpu”.

kwargs: optional

Keyword arguments. Pass as an “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the standalone function. May include an xai_lib_kwargs dictionary, which holds the keyword arguments for the method call.

xai_lib: string

XAI library: captum, tf-explain or zennit.

method: string

XAI method.

xai_lib_kwargs: dict

Keyword arguments to be passed to the attribution function.

channel_first: boolean, optional

Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.

reduce_axes: tuple

Indicates the indices of the dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.

keepdims: boolean

Indicates whether the reduced axes are preserved (True) or removed (False).

Returns:
explanation: np.ndarray

Returns np.ndarray of same shape as inputs.
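
The two ways of passing keyword arguments described above carry the same settings. The sketch below uses a hypothetical stand-in function, fake_explain, that only echoes what it receives; it is not part of the library, and the values shown (method name, n_steps) are illustrative assumptions.

```python
# Hypothetical stand-in for an explanation function; it only echoes its kwargs.
def fake_explain(model, inputs, targets, **kwargs):
    return kwargs

# Settings collected in one dictionary, as they would be handed to a metric
# class under the name "explain_func_kwargs".
explain_func_kwargs = {
    "method": "IntegratedGradients",    # illustrative method name
    "xai_lib_kwargs": {"n_steps": 10},  # forwarded to the attribution call
    "reduce_axes": (1,),
    "keepdims": False,
}

# Standalone style: the same settings unpacked into regular keyword arguments.
received = fake_explain(None, None, None, **explain_func_kwargs)

print(received["method"])          # IntegratedGradients
print(received["xai_lib_kwargs"])  # {'n_steps': 10}
```

Keeping method-specific options inside xai_lib_kwargs separates them from the options that the explanation function itself consumes, such as reduce_axes and keepdims.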

quantus.functions.explanation_func.generate_tf_explanation(model, inputs: array, targets: array, **kwargs) → ndarray

Generate explanation for a tf model with tf_explain. Assumption: currently, only normalised absolute values of explanations are supported.

Parameters:
model: tf.keras.Model

A model that is used for explanation.

inputs: np.ndarray

The inputs that ought to be explained.

targets: np.ndarray

The target labels that should be used in the explanation.

kwargs: optional

Keyword arguments. Pass as an “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the standalone function.

method: string, optional

XAI method.

xai_lib_kwargs: dictionary, optional

Keyword arguments to be passed to the attribution function.

softmax: boolean, optional

Indicates whether the softmax activation in the last layer should be removed.

channel_first: boolean, optional

Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.

reduce_axes: tuple

Indicates the indices of the dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.

keepdims: boolean

Indicates whether the reduced axes are preserved (True) or removed (False).

Returns:
explanation: np.ndarray

Returns np.ndarray of same shape as inputs.
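
The phrase "normalised absolute values" above can be read in more than one way. The sketch below shows one plausible reading, assumed for illustration only: take absolute values and scale each sample by its own maximum. The helper normalise_abs is hypothetical and is not the library's implementation.

```python
import numpy as np

def normalise_abs(attributions: np.ndarray) -> np.ndarray:
    """Take absolute values and scale each sample to [0, 1] by its maximum."""
    a = np.abs(attributions)
    # Per-sample maximum over all non-batch axes.
    sample_max = a.reshape(a.shape[0], -1).max(axis=1)
    # Avoid division by zero for all-zero samples.
    sample_max = np.where(sample_max == 0, 1.0, sample_max)
    # Reshape so the maxima broadcast against the original array.
    broadcast_shape = (a.shape[0],) + (1,) * (a.ndim - 1)
    return a / sample_max.reshape(broadcast_shape)

example = np.array([[[-2.0, 1.0], [0.5, 0.0]]])  # shape (1, 2, 2)
result = normalise_abs(example)
print(result)  # values 1.0, 0.5, 0.25 and 0.0 in the same layout
```

Note that taking absolute values discards the sign of the attribution, so positive and negative evidence can no longer be told apart.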

quantus.functions.explanation_func.generate_zennit_explanation(model, inputs: ndarray, targets: ndarray, device: str | None = None, **kwargs) → ndarray

Generate explanation for a torch model with zennit.

Parameters:
model: torch.nn.Module

A model that is used for explanation.

inputs: np.ndarray

The inputs that ought to be explained.

targets: np.ndarray

The target labels that should be used in the explanation.

device: string

Indicates the device on which a torch.Tensor is or will be allocated: “cpu” or “gpu”.

kwargs: optional

Keyword arguments. Pass as an “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the standalone function.

attributor: string, optional

XAI method.

xai_lib_kwargs: dictionary, optional

Keyword arguments to be passed to the attribution function.

softmax: boolean, optional

Indicates whether the softmax activation in the last layer should be removed.

channel_first: boolean, optional

Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.

reduce_axes: tuple

Indicates the indices of the dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.

keepdims: boolean

Indicates whether the reduced axes are preserved (True) or removed (False).

Returns:
explanation: np.ndarray

Returns np.ndarray of same shape as inputs.
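
The channel_first parameter above is inferred from the input shape when it is None. How the library does this is not stated here; the following is a hypothetical heuristic, shown only to make the idea concrete. The function guess_channel_first and its max_channels threshold are assumptions.

```python
def guess_channel_first(shape, max_channels=4):
    """Guess the channel layout from a batch shape (hypothetical heuristic).

    Assumes the channel axis is the only non-batch axis whose size is at most
    ``max_channels``. Raises ValueError when the shape is ambiguous.
    """
    if len(shape) < 3:
        raise ValueError("Expected at least (batch, channels, length).")
    first_small = shape[1] <= max_channels
    last_small = shape[-1] <= max_channels
    if first_small and not last_small:
        return True   # (batch_size, nr_channels, ...)
    if last_small and not first_small:
        return False  # (batch_size, ..., nr_channels)
    raise ValueError(f"Cannot infer channel layout from shape {shape}.")

print(guess_channel_first((8, 3, 28, 28)))  # True
print(guess_channel_first((8, 28, 28, 3)))  # False
```

Because any such inference can fail on ambiguous shapes, for example (8, 3, 3, 3), setting channel_first explicitly is the safer choice when the layout is known.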

quantus.functions.explanation_func.get_explanation(model, inputs, targets, **kwargs)

Generate an explanation array based on the type of the input model and the user specifications. For TensorFlow models, tf-explain is used. For PyTorch models, either captum or zennit is used, depending on which module is installed. If both are installed, captum is used by default. Setting the xai_lib kwarg to “zennit” uses zennit instead.

Parameters:
model: torch.nn.Module, tf.keras.Model

A model that is used for explanation.

inputs: np.ndarray

The inputs that ought to be explained.

targets: np.ndarray

The target labels that should be used in the explanation.

kwargs: optional

Keyword arguments. Pass as an “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the standalone function.

xai_lib: string, optional

XAI library: captum, tf-explain or zennit.

method: string, optional

XAI method (used with captum and tf-explain libraries).

attributor: string, optional

XAI method (used with zennit).

xai_lib_kwargs: dictionary, optional

Keyword arguments to be passed to the attribution function.

softmax: boolean, optional

Indicates whether the softmax activation in the last layer should be removed.

channel_first: boolean, optional

Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.

reduce_axes: tuple

Indicates the indices of the dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.

keepdims: boolean

Indicates whether the reduced axes are preserved (True) or removed (False).

Returns:
explanation: np.ndarray

Returns np.ndarray of same shape as inputs.
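
The library selection rule described for get_explanation can be sketched as follows. This is an illustrative reconstruction of the rule as stated in the text, not the library's code; the function choose_library, its framework argument and the injectable is_installed check are all hypothetical.

```python
import importlib.util

def _is_installed(name):
    # True if a top-level module with this name can be found.
    return importlib.util.find_spec(name) is not None

def choose_library(framework, xai_lib=None, is_installed=_is_installed):
    """Pick an XAI library name following the rule described above."""
    if framework == "tensorflow":
        return "tf-explain"
    if framework != "torch":
        raise ValueError(f"Unsupported framework: {framework!r}")
    # An explicit request for zennit overrides the captum default.
    if xai_lib == "zennit":
        if not is_installed("zennit"):
            raise ImportError("zennit was requested but is not installed.")
        return "zennit"
    # captum is preferred whenever it is available.
    if is_installed("captum"):
        return "captum"
    if is_installed("zennit"):
        return "zennit"
    raise ImportError("Neither captum nor zennit is installed.")

# Simulate an environment where both torch libraries are available.
both = lambda name: name in {"captum", "zennit"}
print(choose_library("torch", is_installed=both))                    # captum
print(choose_library("torch", xai_lib="zennit", is_installed=both))  # zennit
```

Passing the installation check as an argument keeps the sketch testable without any of the three libraries being present.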