quantus.functions.explanation_func module
This module contains explanation functions that can be used in conjunction with the metrics in the library.
- quantus.functions.explanation_func.explain(model, inputs, targets, **kwargs) → ndarray
Explain inputs given a model, targets and an explanation method. Inputs are expected to have shape (batch_size, nr_channels, …) or (batch_size, …, nr_channels).
- Parameters:
- model: torch.nn.Module, tf.keras.Model
A model that is used for explanation.
- inputs: np.ndarray
The inputs that ought to be explained.
- targets: np.ndarray
The target labels that should be used in the explanation.
- kwargs: optional
Keyword arguments. Pass as “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the stand-alone function.
- xai_lib: string, optional
XAI library: captum, tf-explain or zennit.
- method: string, optional
XAI method (used with captum and tf-explain libraries).
- attributor: string, optional
XAI method (used with zennit).
- xai_lib_kwargs: dictionary, optional
Keyword arguments to be passed to the attribution function.
- softmax: boolean, optional
Indicates whether softmax activation in the last layer shall be removed.
- channel_first: boolean, optional
Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.
- reduce_axes: tuple
Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.
- keepdims: boolean
Indicates if the reduced axes shall be preserved (True) or removed (False).
- Returns:
- explanation: np.ndarray
Returns an np.ndarray of the same shape as the inputs.
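The effect of reduce_axes and keepdims can be sketched with plain NumPy summing, using the (8, 28, 28, 3) example from the parameter description. This illustrates the described semantics only and is not Quantus's own code; the random array is a stand-in for a real explanation.

```python
import numpy as np

# Stand-in for a batch of 8 channel-last explanations (not real attributions).
explanation = np.random.default_rng(0).random((8, 28, 28, 3))

# reduce_axes=(-1,), keepdims=True: sum over the channel axis, keep it with size 1.
reduced_keep = np.sum(explanation, axis=(-1,), keepdims=True)

# reduce_axes=(-1,), keepdims=False: the summed channel axis is dropped.
reduced_drop = np.sum(explanation, axis=(-1,), keepdims=False)

# reduce_axes=(): no axis is summed, so the original shape is kept.
unreduced = np.sum(explanation, axis=(), keepdims=True)

print(reduced_keep.shape)  # (8, 28, 28, 1)
print(reduced_drop.shape)  # (8, 28, 28)
print(unreduced.shape)     # (8, 28, 28, 3)
```

Keeping the reduced axis (keepdims=True) preserves the rank of the array, which keeps the explanation broadcastable against the original inputs.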
- quantus.functions.explanation_func.generate_captum_explanation(model, inputs: ndarray, targets: ndarray, device: str | None = None, **kwargs) → ndarray
Generate an explanation for a torch model with captum.
- Parameters:
- model: torch.nn.Module
A model that is used for explanation.
- inputs: np.ndarray
The inputs that ought to be explained.
- targets: np.ndarray
The target labels that should be used in the explanation.
- device: string
Indicates the device on which a torch.Tensor is or will be allocated, e.g. “cpu” or “cuda”.
- kwargs: optional
Keyword arguments. Pass as “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the stand-alone function. May include an xai_lib_kwargs dictionary with keyword arguments for the attribution method call.
- xai_lib: string
XAI library: captum, tf-explain or zennit.
- method: string
XAI method.
- xai_lib_kwargs: dict
Keyword arguments to be passed to the attribution function.
- channel_first: boolean, optional
Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.
- reduce_axes: tuple
Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.
- keepdims: boolean
Indicates if the reduced axes shall be preserved (True) or removed (False).
- Returns:
- explanation: np.ndarray
Returns an np.ndarray of the same shape as the inputs.
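The kwargs conventions above (regular kwargs for a stand-alone call, an explain_func_kwargs dictionary for a metric class, and a nested xai_lib_kwargs dictionary for the attribution call) can be sketched with a stand-in function that has the same (model, inputs, targets, **kwargs) signature. hypothetical_explain and its seed option are illustrative only, and the sketch assumes a metric unpacks explain_func_kwargs into the call as keyword arguments.

```python
from typing import Any, Dict

import numpy as np


def hypothetical_explain(model: Any, inputs: np.ndarray, targets: np.ndarray, **kwargs: Any) -> np.ndarray:
    """Hypothetical stand-in: ignores the model and returns random attributions."""
    method = kwargs.get("method", "Saliency")
    # Nested dictionary meant for the attribution call itself.
    xai_lib_kwargs: Dict[str, Any] = kwargs.get("xai_lib_kwargs", {})
    # "seed" is a made-up option of this stand-in, not a captum argument.
    rng = np.random.default_rng(xai_lib_kwargs.get("seed", 0))
    print(f"method={method!r}, attribution kwargs={xai_lib_kwargs}")
    return rng.random(inputs.shape)


x = np.zeros((4, 1, 28, 28))
y = np.array([0, 1, 2, 3])

# Stand-alone use: options are passed as regular keyword arguments.
a = hypothetical_explain(None, x, y, method="IntegratedGradients", xai_lib_kwargs={"seed": 1})

# Metric-class use (assumed): the same options bundled in one dictionary,
# unpacked into the call with **.
explain_func_kwargs = {"method": "IntegratedGradients", "xai_lib_kwargs": {"seed": 1}}
b = hypothetical_explain(None, x, y, **explain_func_kwargs)
```

Both calls resolve to the same options, so they produce identical attributions with the same shape as the inputs.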
- quantus.functions.explanation_func.generate_tf_explanation(model, inputs: array, targets: array, **kwargs) → ndarray
Generate an explanation for a TensorFlow model with tf-explain. Assumption: currently, only normalised absolute values of explanations are supported.
- Parameters:
- model: tf.keras.Model
A model that is used for explanation.
- inputs: np.ndarray
The inputs that ought to be explained.
- targets: np.ndarray
The target labels that should be used in the explanation.
- kwargs: optional
Keyword arguments. Pass as “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the stand-alone function.
- method: string, optional
XAI method.
- xai_lib_kwargs: dictionary, optional
Keyword arguments to be passed to the attribution function.
- softmax: boolean, optional
Indicates whether softmax activation in the last layer shall be removed.
- channel_first: boolean, optional
Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.
- reduce_axes: tuple
Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.
- keepdims: boolean
Indicates if the reduced axes shall be preserved (True) or removed (False).
- Returns:
- explanation: np.ndarray
Returns an np.ndarray of the same shape as the inputs.
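The tf-explain assumption above does not state which normalisation is applied. One plausible reading, sketched here purely as an assumption, is to take absolute values and divide each sample by its own maximum; normalise_abs is a hypothetical helper, and Quantus's actual normalisation may differ.

```python
import numpy as np


def normalise_abs(explanation: np.ndarray) -> np.ndarray:
    """Hypothetical helper: absolute values scaled to [0, 1] per sample."""
    a = np.abs(explanation)
    # Maximum over every axis except the batch axis, kept for broadcasting.
    axes = tuple(range(1, a.ndim))
    max_vals = a.max(axis=axes, keepdims=True)
    # All-zero samples are divided by 1 instead of 0 and stay all-zero.
    return a / np.where(max_vals == 0, 1.0, max_vals)


raw = np.array([[[-2.0, 1.0], [0.5, -0.5]],
                [[0.0, 0.0], [0.0, 0.0]]])
out = normalise_abs(raw)
# Sample 0 becomes [[1.0, 0.5], [0.25, 0.25]]; sample 1 stays all zeros.
```

Normalising per sample rather than over the whole batch keeps one strongly attributed input from flattening the scale of the others.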
- quantus.functions.explanation_func.generate_zennit_explanation(model, inputs: ndarray, targets: ndarray, device: str | None = None, **kwargs) → ndarray
Generate an explanation for a torch model with zennit.
- Parameters:
- model: torch.nn.Module
A model that is used for explanation.
- inputs: np.ndarray
The inputs that ought to be explained.
- targets: np.ndarray
The target labels that should be used in the explanation.
- device: string
Indicates the device on which a torch.Tensor is or will be allocated, e.g. “cpu” or “cuda”.
- kwargs: optional
Keyword arguments. Pass as “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the stand-alone function.
- attributor: string, optional
XAI method.
- xai_lib_kwargs: dictionary, optional
Keyword arguments to be passed to the attribution function.
- softmax: boolean, optional
Indicates whether softmax activation in the last layer shall be removed.
- channel_first: boolean, optional
Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.
- reduce_axes: tuple
Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.
- keepdims: boolean
Indicates if the reduced axes shall be preserved (True) or removed (False).
- Returns:
- explanation: np.ndarray
Returns an np.ndarray of the same shape as the inputs.
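When channel_first is None, the channel order is inferred from the input shape. The exact rule is not documented here; the sketch below assumes the channel axis is the smallest non-batch axis and raises an error when that is ambiguous. guess_channel_first is a hypothetical helper, not a Quantus function.

```python
import numpy as np


def guess_channel_first(x: np.ndarray) -> bool:
    """Hypothetical heuristic: the channel axis is the smallest non-batch axis.

    Handles batched 1D inputs (batch, a, b) and batched 2D inputs (batch, a, b, c).
    """
    shape = x.shape
    if x.ndim == 3:
        if shape[1] < shape[2]:
            return True   # (batch, channels, length)
        if shape[1] > shape[2]:
            return False  # (batch, length, channels)
    elif x.ndim == 4:
        if shape[1] < shape[2] and shape[1] < shape[3]:
            return True   # (batch, channels, height, width)
        if shape[3] < shape[1] and shape[3] < shape[2]:
            return False  # (batch, height, width, channels)
    raise ValueError(f"Cannot infer channel order from shape {shape}; set channel_first explicitly.")


print(guess_channel_first(np.zeros((8, 3, 28, 28))))  # True
print(guess_channel_first(np.zeros((8, 28, 28, 3))))  # False
```

Any shape-based guess breaks down when the channel count is not smaller than the spatial sizes, which is why passing channel_first explicitly is the safer choice.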
- quantus.functions.explanation_func.get_explanation(model, inputs, targets, **kwargs)
Generate an explanation array based on the type of the input model and user specifications. For TensorFlow models, tf-explain is used. For PyTorch models, either captum or zennit is used, depending on which module is installed; if both are installed, captum is used by default. Setting the xai_lib kwarg to “zennit” uses zennit instead.
- Parameters:
- model: torch.nn.Module, tf.keras.Model
A model that is used for explanation.
- inputs: np.ndarray
The inputs that ought to be explained.
- targets: np.ndarray
The target labels that should be used in the explanation.
- kwargs: optional
Keyword arguments. Pass as “explain_func_kwargs” dictionary when working with a metric class. Pass as regular kwargs when using the stand-alone function.
- xai_lib: string, optional
XAI library: captum, tf-explain or zennit.
- method: string, optional
XAI method (used with captum and tf-explain libraries).
- attributor: string, optional
XAI method (used with zennit).
- xai_lib_kwargs: dictionary, optional
Keyword arguments to be passed to the attribution function.
- softmax: boolean, optional
Indicates whether softmax activation in the last layer shall be removed.
- channel_first: boolean, optional
Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None.
- reduce_axes: tuple
Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, 1). Passing “()” will keep the original dimensions.
- keepdims: boolean
Indicates if the reduced axes shall be preserved (True) or removed (False).
- Returns:
- explanation: np.ndarray
Returns an np.ndarray of the same shape as the inputs.
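The library selection described for get_explanation can be sketched as a small decision function. choose_xai_lib and its arguments (a framework name and a set of installed library names) are hypothetical; the real function works from the model object and the packages actually installed.

```python
from typing import Optional, Set


def choose_xai_lib(framework: str, installed: Set[str], xai_lib: Optional[str] = None) -> str:
    """Hypothetical sketch of the documented library choice."""
    if framework == "tensorflow":
        return "tf-explain"
    if framework != "torch":
        raise ValueError(f"Unsupported framework: {framework!r}")
    # An explicit xai_lib (e.g. "zennit") overrides the default choice.
    if xai_lib is not None:
        if xai_lib not in installed:
            raise ImportError(f"{xai_lib} was requested but is not installed")
        return xai_lib
    # Default: captum when available, even if zennit is installed too.
    if "captum" in installed:
        return "captum"
    if "zennit" in installed:
        return "zennit"
    raise ImportError("Neither captum nor zennit is installed")


print(choose_xai_lib("torch", {"captum", "zennit"}))                   # captum
print(choose_xai_lib("torch", {"captum", "zennit"}, xai_lib="zennit"))  # zennit
print(choose_xai_lib("tensorflow", set()))                             # tf-explain
```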