Counterfactual Instances on MNIST
Given a test instance $X$, this method generates counterfactual instances $X^\prime$ for a desired counterfactual class $t$, which can either be specified upfront or be any class other than the predicted class of $X$.
The loss function for finding counterfactuals is the following:

$$L(X^\prime|X) = (f_t(X^\prime) - p_t)^2 + \lambda L_1(X^\prime, X)$$
The first loss term guides the search towards instances $X^\prime$ for which the predicted class probability $f_t(X^\prime)$ is close to a pre-specified target class probability $p_t$ (typically $p_t=1$). The second loss term ensures that the counterfactual stays close to the original test instance in feature space, with $\lambda$ balancing the two terms.
In this notebook we illustrate the usage of the basic counterfactual algorithm on the MNIST dataset.
Note
To enable support for Counterfactual, you may need to run
pip install alibi[tensorflow]

import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os
from time import time
from alibi.explainers import Counterfactual
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

Load and prepare MNIST data
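A minimal sketch of this step using the Keras built-in loader, displaying one of the test images:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)
plt.gray()
plt.imshow(x_test[1]);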

Prepare data: scale, reshape and categorize
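For example (the exact feature range is a choice; here we rescale to [-0.5, 0.5] and reuse it later as the feature_range of the search):

# scale pixels, add a channel axis and one-hot encode the labels
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
xmin, xmax = -.5, .5  # rescale pixel values to [-0.5, 0.5]
x_train = x_train * (xmax - xmin) + xmin
x_test = x_test * (xmax - xmin) + xmin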
Define and train CNN model
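One possible architecture, built from the layers imported above; the hyperparameters are illustrative and the filename mnist_cnn.h5 is our choice (it is removed again in the clean-up step):

def cnn_model():
    x_in = Input(shape=(28, 28, 1))
    x = Conv2D(filters=64, kernel_size=2, padding='same', activation='relu')(x_in)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Conv2D(filters=32, kernel_size=2, padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_out = Dense(10, activation='softmax')(x)
    cnn = Model(inputs=x_in, outputs=x_out)
    cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return cnn

cnn = cnn_model()
cnn.fit(x_train, y_train, batch_size=64, epochs=3, verbose=0)
cnn.save('mnist_cnn.h5')  # serialize the trained model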
Evaluate the model on test set
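For example, reloading the saved model and scoring it on the test set:

cnn = load_model('mnist_cnn.h5')
score = cnn.evaluate(x_test, y_test, verbose=0)
print('Test accuracy: ', score[1])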
Generate counterfactuals
Original instance:
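We take the first test image, a 7, and add a batch dimension since the model expects batched input:

X = x_test[0].reshape((1,) + x_test[0].shape)
plt.imshow(X.reshape(28, 28));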

Counterfactual parameters:
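A sketch of a possible configuration; the keyword names follow the Counterfactual API and the parameter values are illustrative:

shape = (1,) + x_train.shape[1:]  # instance shape, including the batch dimension
target_proba = 1.0                # p_t in the loss above
tol = 0.01                        # accept counterfactuals with p(class) >= 1 - tol
target_class = 'other'            # any class other than the predicted class of X
max_iter = 1000                   # gradient updates per value of lambda
lam_init = 1e-1                   # initial value of lambda
max_lam_steps = 10                # number of iterations over lambda
learning_rate_init = 0.1
feature_range = (x_train.min(), x_train.max())  # keep pixel values within the data range

cf = Counterfactual(cnn, shape=shape, target_proba=target_proba, tol=tol,
                    target_class=target_class, max_iter=max_iter, lam_init=lam_init,
                    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
                    feature_range=feature_range)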
Run counterfactual:
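The search is a single explain call; we time it since it iterates over $\lambda$:

start_time = time()
explanation = cf.explain(X)
print('Explanation took {:.3f} sec'.format(time() - start_time))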
Results:
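The best counterfactual found is returned under explanation.cf; a sketch of how to inspect it:

pred_class = explanation.cf['class']
proba = explanation.cf['proba'][0][pred_class]
print(f'Counterfactual prediction: {pred_class} with probability {proba}')
plt.imshow(explanation.cf['X'].reshape(28, 28));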

The counterfactual starting from a 7 moves towards the closest class as determined by the model and the data - in this case a 9. The evolution of the counterfactual during the iterations over $\lambda$ can be seen below (note that all of the following examples satisfy the counterfactual condition):
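A sketch of how these intermediate results can be retrieved, assuming explanation.all stores, for each iteration over $\lambda$, the list of counterfactuals found at that step:

for i in range(max_lam_steps):
    cfs = explanation.all[i]
    if len(cfs) > 0:  # not every lambda step yields a valid counterfactual
        print('Iteration {}: lambda = {:.5f}'.format(i, cfs[0]['lambda']))
        plt.imshow(cfs[0]['X'].reshape(28, 28))
        plt.show()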

Typically, the first few iterations find counterfactuals that are out of distribution, while the later iterations make the counterfactual more sparse and interpretable.
Let's now try to steer the counterfactual to a specific class:
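To do so, we rebuild the explainer with target_class set to a label instead of 'other' (here 1; the remaining parameters are reused from above):

target_class = 1  # steer the counterfactual towards a 1

cf = Counterfactual(cnn, shape=shape, target_proba=target_proba, tol=tol,
                    target_class=target_class, max_iter=max_iter, lam_init=lam_init,
                    max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,
                    feature_range=feature_range)

explanation = cf.explain(X)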
Results:
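As before:

pred_class = explanation.cf['class']
proba = explanation.cf['proba'][0][pred_class]
print(f'Counterfactual prediction: {pred_class} with probability {proba}')
plt.imshow(explanation.cf['X'].reshape(28, 28));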

As you can see, when we specify the target class, the search can no longer move towards the closest class to the test instance (a 9, as we saw previously), so the resulting counterfactual may be less interpretable. We can gain more insight by looking at the difference between the counterfactual and the original instance:
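For example, as a pixel-wise difference plot:

plt.imshow((explanation.cf['X'] - X).reshape(28, 28));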

This shows that the counterfactual strips out the top part of the 7 to result in a prediction of 1 - not very surprising, as the dataset contains plenty of examples of diagonally slanted 1's.
Clean up:
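Finally, remove the serialized model (the filename mnist_cnn.h5 is the one chosen above):

os.remove('mnist_cnn.h5')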