
Counterfactuals guided by prototypes on MNIST

This method is described in the paper Interpretable Counterfactual Explanations Guided by Prototypes and generates counterfactual instances guided by class prototypes. For a given instance X, the method builds a prototype for each prediction class using either an autoencoder or k-d trees. The prototype of the nearest class other than the originally predicted class is then used to guide the counterfactual search. For example, in MNIST the closest class to a 7 could be a 9, in which case the prototype loss term tries to minimize the distance between the proposed counterfactual and the prototype of a 9. This speeds up the search for a satisfactory counterfactual by steering it towards an interpretable solution from the start of the optimization. Driving the perturbations towards the prototype of another class also helps to avoid out-of-distribution counterfactuals.
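The prototype selection can be sketched with NumPy alone. This is an illustration under assumptions, not alibi's implementation: it presumes the instances have already been mapped to encodings by some encoder, and the function name `nearest_prototype_class` and the toy 2-D encodings are hypothetical.

```python
import numpy as np

def nearest_prototype_class(enc_x, encodings, labels, predicted_class):
    """Return (class, prototype) of the nearest prototype of a class other than predicted_class.

    Each class prototype is taken to be the mean encoding of all instances of that class.
    """
    best_class, best_proto, best_dist = None, None, np.inf
    for c in np.unique(labels):
        if c == predicted_class:
            continue  # the counterfactual must belong to a different class
        proto = encodings[labels == c].mean(axis=0)
        dist = np.linalg.norm(enc_x - proto)
        if dist < best_dist:
            best_class, best_proto, best_dist = int(c), proto, dist
    return best_class, best_proto

# Toy encodings (hypothetical): class 7 near the origin, class 9 close by, class 0 far away.
encodings = np.array([[0.0, 0.0], [0.2, 0.0],   # class 7
                      [1.0, 0.0], [1.2, 0.0],   # class 9
                      [5.0, 5.0], [5.2, 5.0]])  # class 0
labels = np.array([7, 7, 9, 9, 0, 0])
enc_x = np.array([0.1, 0.0])  # encoding of the instance to explain, predicted as a 7

target_class, target_proto = nearest_prototype_class(enc_x, encodings, labels, predicted_class=7)
print(target_class, target_proto)
```

In this toy setup the prototype of class 9 is the closest one, so it would be the class guiding the search.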

The loss function to be optimized is the following:

$$Loss = cL_{pred} + \beta L_{1} + L_{2} + L_{AE} + L_{proto}$$

The first loss term relates to the model's prediction function, the following two terms define the elastic net regularization, and the last two terms are optional. The aim of $L_{AE}$ is to penalize out-of-distribution counterfactuals while $L_{proto}$ guides the counterfactual to a prototype. When we only have access to the model's prediction function and cannot rely on automatic differentiation, the prototypes allow us to drop the prediction function loss term $L_{pred}$ and still generate high quality counterfactuals. This drastically reduces the number of prediction calls made during the numerical gradient update step and again speeds up the search.
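To make the structure of the objective concrete, here is a hedged NumPy sketch that evaluates the five terms for a single perturbation $\delta$. The specific forms are assumptions based on the referenced paper rather than on this page: a hinge-style $L_{pred}$ with margin `kappa`, and squared L2 distances weighted by `gamma` for $L_{AE}$ and `theta` for $L_{proto}$. The function `proto_cf_loss` and the stand-in `predict`, `encoder` and `autoencoder` callables are hypothetical.

```python
import numpy as np

def proto_cf_loss(x0, delta, predict, orig_class, proto=None,
                  encoder=None, autoencoder=None,
                  c=1.0, beta=0.1, gamma=0.0, theta=0.0, kappa=0.0):
    """Evaluate c*L_pred + beta*L_1 + L_2 + L_AE + L_proto for one perturbation delta."""
    x_cf = x0 + delta
    probs = predict(x_cf)
    # L_pred (assumed hinge form): push the original class below the best other class.
    other = np.max(np.delete(probs, orig_class))
    l_pred = max(probs[orig_class] - other, -kappa)
    # Elastic net regularization on the perturbation.
    l1 = np.sum(np.abs(delta))
    l2 = np.sum(delta ** 2)
    # Optional terms: reconstruction error and distance to the prototype in encoding space.
    l_ae = gamma * np.sum((x_cf - autoencoder(x_cf)) ** 2) if autoencoder is not None else 0.0
    l_proto = theta * np.sum((encoder(x_cf) - proto) ** 2) if encoder is not None else 0.0
    return c * l_pred + beta * l1 + l2 + l_ae + l_proto

# Toy check with identity encoder/autoencoder and a constant classifier (all hypothetical).
x0 = np.zeros(2)
delta = np.array([0.5, -0.5])
loss = proto_cf_loss(
    x0, delta,
    predict=lambda x: np.array([0.7, 0.3]),
    orig_class=0,
    proto=np.array([0.5, 0.5]),
    encoder=lambda x: x,
    autoencoder=lambda x: x,
    c=1.0, beta=0.1, gamma=1.0, theta=1.0,
)
print(loss)
```

Leaving `autoencoder` or `encoder` as `None` drops the corresponding optional term, and setting `c=0` removes the prediction term, which mirrors the variants discussed later in this notebook.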

Other options include generating counterfactuals for specific classes or including trust score constraints to ensure that the counterfactual is close enough to the newly predicted class compared to the original class. Different use cases are illustrated throughout this notebook.

Note

To enable support for CounterfactualProto, you may need to run

pip install alibi[tensorflow]

import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs 
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input, UpSampling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
from time import time
from alibi.explainers import CounterfactualProto

print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

Load and prepare MNIST data


Prepare data: scale, reshape and categorize
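The three preparation steps named in this heading can be sketched as follows. This is an assumption-laden illustration: it uses random arrays in place of MNIST, the target range of [-0.5, 0.5] is assumed, and the helpers `prepare_images` and `one_hot` are hypothetical (in the notebook, `to_categorical` imported above would play the role of `one_hot`).

```python
import numpy as np

def prepare_images(x, xmin=-0.5, xmax=0.5):
    """Scale pixel values to [xmin, xmax] and add a trailing channel axis."""
    x = x.astype("float32")
    x = (x - x.min()) / (x.max() - x.min())  # rescale to [0, 1]
    x = x * (xmax - xmin) + xmin             # shift to [xmin, xmax]
    return x[..., np.newaxis]                # (n, 28, 28) -> (n, 28, 28, 1)

def one_hot(y, num_classes=10):
    """Turn integer labels into one-hot rows."""
    return np.eye(num_classes, dtype="float32")[y]

# Synthetic stand-in for a few MNIST images and labels.
rng = np.random.default_rng(0)
x_raw = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
y_raw = np.array([7, 2, 1, 0])

x = prepare_images(x_raw)
y = one_hot(y_raw)
print(x.shape, x.min(), x.max(), y.shape)
```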

Define and train CNN model

Evaluate the model on test set

Define and train auto-encoder

Compare original with decoded images


Generate counterfactual guided by the nearest class prototype

Original instance:


Counterfactual parameters:

Run counterfactual:
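As a rough guide to what the parameter and run steps involve, the sketch below collects keyword arguments accepted by the `CounterfactualProto` constructor. The values are illustrative assumptions, not the notebook's original settings, and the names `cnn`, `ae`, `enc`, `x_train` and `X` in the commented call stand for a trained classifier, autoencoder, encoder, training set and instance that are not defined here.

```python
# Illustrative settings only; the values are assumptions.
image_shape = (28, 28, 1)
cf_kwargs = dict(
    shape=(1,) + image_shape,   # batch containing the single instance to explain
    gamma=100.0,                # weight of the autoencoder loss term L_AE
    theta=100.0,                # weight of the prototype loss term L_proto
    c_init=1.0,                 # initial weight c of the prediction loss term L_pred
    c_steps=2,                  # number of updates of c
    max_iterations=1000,        # optimization steps per value of c
    feature_range=(-0.5, 0.5),  # allowed min and max feature values
)

# With trained models available, the search would then look roughly like:
#   cf = CounterfactualProto(cnn, ae_model=ae, enc_model=enc, **cf_kwargs)
#   cf.fit(x_train)              # builds the class prototypes from the training data
#   explanation = cf.explain(X)  # X has the shape given in cf_kwargs["shape"]
print(cf_kwargs["shape"])
```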

Results:


The counterfactual starting from a 7 moves towards its closest prototype class: a 9. The evolution of the counterfactual during the first iteration can be seen below:


Typically, the first few iterations already steer the 7 towards a 9, while the later iterations make the counterfactual more sparse.

Prototypes defined by the $k$ nearest encoded instances

In the above example, the class prototypes are defined by the average encoding of all instances belonging to the specific class. Instead, we can also select only the $k$ nearest encoded instances of a class to the encoded instance to be explained and use the average over those $k$ encodings as the prototype.
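A minimal NumPy sketch of this $k$-nearest prototype, assuming the encodings of one class are already available. The helper `knn_prototype` and the toy data are hypothetical; in alibi the corresponding choice is made through the `k` argument of `explain`.

```python
import numpy as np

def knn_prototype(enc_x, class_encodings, k):
    """Average the k encodings of a class that lie closest to enc_x."""
    dists = np.linalg.norm(class_encodings - enc_x, axis=1)
    nearest = np.argsort(dists)[:k]
    return class_encodings[nearest].mean(axis=0)

# Toy encodings of one class (hypothetical), at increasing distance from the instance.
class_encodings = np.array([[1.0, 0.0], [2.0, 0.0], [4.0, 0.0], [9.0, 0.0]])
enc_x = np.array([0.0, 0.0])

proto_k1 = knn_prototype(enc_x, class_encodings, k=1)
proto_k2 = knn_prototype(enc_x, class_encodings, k=2)
proto_all = knn_prototype(enc_x, class_encodings, k=len(class_encodings))
print(proto_k1, proto_k2, proto_all)
```

With $k$ equal to the class size, the prototype reduces to the class average used in the previous example; smaller values of $k$ pull the prototype towards the instance being explained.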

Results for $k$ equals 1:


Results for $k$ equals 20:


A lower value of $k$ typically leads to counterfactuals that look more like the original instance and less like an average instance of the counterfactual class.

Remove the autoencoder loss term $L_{AE}$

In the previous example, we used both an autoencoder loss term, which penalizes counterfactuals that fall outside the training data distribution, and an encoder loss term, which guides the counterfactual to the nearest prototype class. In the next example we drop the autoencoder loss term to speed up the counterfactual search and still generate decent counterfactuals:

Results:


Specify prototype classes

For multi-class predictions, we might be interested in generating counterfactuals for certain classes while avoiding others. The following example illustrates how to do this:


The closest class to the 9 is 4. This is evident from the first counterfactual below. For the second counterfactual, we specified that the prototype class used in the search should be a 7. As a result, a counterfactual 7 instead of a 4 is generated.
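The effect of restricting the prototype classes can be sketched in NumPy as follows. The helper `nearest_allowed_prototype` and the toy prototypes are hypothetical; in alibi the restriction is passed to `explain` through its `target_class` argument, for example `target_class=[7]`.

```python
import numpy as np

def nearest_allowed_prototype(enc_x, prototypes, predicted_class, target_class=None):
    """Pick the nearest prototype class, optionally restricted to the classes in target_class."""
    candidates = [c for c in prototypes if c != predicted_class]
    if target_class is not None:
        candidates = [c for c in candidates if c in target_class]
    if not candidates:
        raise ValueError("no admissible prototype class left")
    return min(candidates, key=lambda c: np.linalg.norm(enc_x - prototypes[c]))

# Toy class prototypes in a 2-D encoding space (hypothetical values).
prototypes = {
    4: np.array([1.0, 0.0]),
    7: np.array([0.0, 2.0]),
    9: np.array([0.0, 0.0]),
}
enc_x = np.array([0.1, 0.1])  # encoding of an instance predicted as a 9

default_class = nearest_allowed_prototype(enc_x, prototypes, predicted_class=9)
forced_class = nearest_allowed_prototype(enc_x, prototypes, predicted_class=9, target_class=[7])
print(default_class, forced_class)
```

Without a restriction the search is guided by the nearest class (4 in this toy setup); with the restriction it is guided by the requested class 7 even though that prototype is further away.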


Speed up the counterfactual search by removing the predict function loss term

We can also remove the prediction loss term and still obtain an interpretable counterfactual. This is especially relevant for fully black box models. When we provide the counterfactual search method with a Keras or TensorFlow model, it is incorporated in the TensorFlow graph and evaluated using automatic differentiation. However, if we only have access to the model's prediction function, the gradient updates are numerical and typically require a large number of prediction calls because of the prediction loss term $L_{pred}$. These prediction calls can slow the search down significantly and become a bottleneck. We can represent the gradient of the loss term as follows:

$$\frac{\partial L_{pred}}{\partial x} = \frac{\partial L_{pred}}{\partial p} \frac{\partial p}{\partial x}$$

where $L_{pred}$ is the prediction loss term, $p$ the prediction function and $x$ the input features to optimize. For a 28 by 28 MNIST image, the $\partial p / \partial x$ term alone would require a prediction call with batch size $28 \times 28 \times 2 = 1568$. By using the prototypes to guide the search, however, we can remove the prediction loss term and only make a single prediction at the end of each gradient update to check whether the predicted class of the proposed counterfactual differs from the original class. We do not necessarily need a Keras or TensorFlow autoencoder either and can use k-d trees to find the nearest class prototypes. Please check out this notebook for a practical example.
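The batch size of 1568 can be reproduced with a small NumPy sketch of a numerical gradient. It assumes central finite differences, which is consistent with the factor of 2 in the text but is not confirmed here as alibi's exact scheme; the function `numerical_gradient` and the linear toy `predict` are hypothetical.

```python
import numpy as np

def numerical_gradient(predict, x, eps=1e-3):
    """Central-difference estimate of dp/dx using a single batched prediction call.

    Returns the gradient with shape (n_classes,) + x.shape and the batch size used.
    """
    flat = x.ravel()
    d = flat.size
    # Row i perturbs feature i by +eps; row d + i perturbs the same feature by -eps.
    batch = np.tile(flat, (2 * d, 1))
    idx = np.arange(d)
    batch[idx, idx] += eps
    batch[d + idx, idx] -= eps
    preds = predict(batch.reshape((2 * d,) + x.shape))  # one call, batch size 2 * d
    grad = (preds[:d] - preds[d:]) / (2 * eps)          # shape (d, n_classes)
    return grad.T.reshape((-1,) + x.shape), 2 * d

# Toy black box model: a linear map from 784 pixels to 10 outputs.
rng = np.random.default_rng(0)
W = rng.normal(size=(28 * 28, 10))
seen_batch_sizes = []

def predict(batch):
    seen_batch_sizes.append(batch.shape[0])  # record how many instances were scored
    return batch.reshape(batch.shape[0], -1) @ W

x = rng.uniform(-0.5, 0.5, size=(28, 28))
grad, batch_size = numerical_gradient(predict, x)
print(batch_size)  # 1568
```

Every gradient step on $L_{pred}$ would need a batch like this, which is why dropping the term and making one prediction per update is so much cheaper.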

The first example below removes $L_{pred}$ from the loss function to bypass the bottleneck. It illustrates the drastic speed improvements over the black box alternative with numerical gradient evaluation while still producing interpretable counterfactual instances.


Let us now add the $L_{pred}$ loss term back into the objective function and observe how long it takes to generate a black box counterfactual:


Clean up:
