
Anchor explanations for fashion MNIST

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from alibi.explainers import AnchorImage

Load and prepare fashion MNIST data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print('x_train shape:', x_train.shape, 'y_train shape:', y_train.shape)
x_train shape: (60000, 28, 28) y_train shape: (60000,)
idx = 0
plt.imshow(x_train[idx]);
[Figure: the first training image, x_train[0]]

Scale, reshape and categorize data
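The code for this step is not shown here, so the following is only a minimal sketch of what such preprocessing could look like. It assumes pixel values are scaled to [0, 1], a trailing channel axis is added for the convolutional layers, and labels are one-hot encoded. The helper name `prepare` and the random stand-in data are hypothetical; `to_categorical` from the imports above produces the same one-hot encoding as the `np.eye` indexing used here.

```python
import numpy as np


def prepare(x, y, n_classes=10):
    """Hypothetical helper: scale images to [0, 1], add a channel axis, one-hot the labels."""
    x = x.astype('float32') / 255.0              # uint8 pixels -> floats in [0, 1]
    x = x[..., np.newaxis]                       # (n, 28, 28) -> (n, 28, 28, 1)
    y_onehot = np.eye(n_classes, dtype='float32')[y]  # same result as to_categorical(y)
    return x, y_onehot


# Random stand-in data with the same layout as Fashion MNIST (assumption, not the real dataset).
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
y_demo = np.array([2, 0, 9, 3, 2])

x_prep, y_prep = prepare(x_demo, y_demo)
print(x_prep.shape, y_prep.shape)
```

With the real data, the same function would be applied to both the training and the test split.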

Define CNN model
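The model definition is not shown here. As an illustration only, a small convolutional classifier could be built from the layers imported at the top of the page. The architecture, layer sizes, dropout rates and the function name `build_cnn` below are assumptions, not necessarily the configuration used to produce the results on this page.

```python
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.models import Model


def build_cnn(input_shape=(28, 28, 1), n_classes=10):
    """Hypothetical CNN: two conv/pool stages followed by a dense softmax classifier."""
    x_in = Input(shape=input_shape)
    x = Conv2D(filters=32, kernel_size=3, padding='same', activation='relu')(x_in)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)
    x = Dropout(0.3)(x)
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    x_out = Dense(n_classes, activation='softmax')(x)  # one probability per class

    cnn = Model(inputs=x_in, outputs=x_out)
    # One-hot labels pair with categorical cross-entropy.
    cnn.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return cnn


cnn = build_cnn()
cnn.summary()
```

The softmax output matters for the later steps: the explainer only needs a function that maps a batch of images to class probabilities.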

Train model

Define superpixels

A custom function can generate rectangular superpixels for a given image. Alternatively, use one of the built-in methods. Meaningful superpixels are essential for a useful explanation. See scikit-image's segmentation methods (felzenszwalb, slic and quickshift are built into the explainer) for more information on the built-in methods.
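A rectangular segmentation can be sketched in a few lines of NumPy. The function name `rectangular_superpixels` and the default block size of 4 by 7 pixels are assumptions chosen for illustration; the only requirement is that the function returns an integer label per pixel, with pixels in the same superpixel sharing a label.

```python
import numpy as np


def rectangular_superpixels(image, size=(4, 7)):
    """Hypothetical helper: label each pixel with the index of its rectangular block.

    `size` is (block height, block width) in pixels. Blocks are numbered
    row by row, starting at 0 in the top-left corner.
    """
    n_rows, n_cols = image.shape[:2]
    blocks_per_row = -(-n_cols // size[1])     # ceiling division
    row_block = np.arange(n_rows) // size[0]   # block row index for every pixel row
    col_block = np.arange(n_cols) // size[1]   # block column index for every pixel column
    return row_block[:, None] * blocks_per_row + col_block[None, :]


segments = rectangular_superpixels(np.zeros((28, 28, 1)))
print(segments.shape, len(np.unique(segments)))
```

Smaller blocks give finer-grained anchors but increase the number of superpixels the search has to consider.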

[Figure: the generated superpixels]

Define prediction function

Initialize anchor image explainer

Explain a prediction

The explanation returns a mask with the superpixels that constitute the anchor.

Image to be explained:

[Figure: the image to be explained]

Model prediction:

The predicted category correctly corresponds to the class Pullover:

| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |

Generate explanation:

Show anchor:

[Figure: the anchor superpixels for the explained image]

From the example, it looks like the end of the sleeve alone is sufficient to predict a pullover.
