Similarity explanations for ImageNet

In this notebook, we apply the similarity explanation method to a ResNet model pre-trained on the ImageNet dataset. We use a subset of ImageNet consisting of 1000 random samples as the training set for the explainer, constructed by picking 100 random images for each of the following classes:

  • 'stingray'

  • 'trilobite'

  • 'centipede'

  • 'slug'

  • 'snail'

  • 'Rhodesian ridgeback'

  • 'beagle'

  • 'golden retriever'

  • 'sea lion'

  • 'espresso'

The test set contains 50 random samples, 5 for each of the classes above. The dataset is stored in a public Google Cloud Storage bucket and can be fetched using the utility function fetch_imagenet_10.

Given an input image of interest picked from the test set, the similarity explanation method used here aims to find images in the training set that are similar to the image of interest according to how the model "sees" them: the similarity metric is based on the gradients of the model's loss function with respect to the model's parameters.

The similarity explanation tool supports both PyTorch and TensorFlow backends. In this example, we use the TensorFlow backend. Running this notebook on CPU can be very slow, so a GPU is recommended.

A more detailed description of the method can be found here. The implementation follows Charpiat et al., 2019 and Hanawa et al., 2021.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import accuracy_score
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.losses import categorical_crossentropy
from alibi.datasets import fetch_imagenet_10
from alibi.explainers import GradientSimilarity
```

Utils

Load data

Fetching and preparing the reduced ImageNet dataset.
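The images themselves come from fetch_imagenet_10. As a minimal, self-contained sketch of the preparation step, the snippet below applies the standard ResNet50 preprocessing (RGB→BGR conversion and ImageNet channel-mean subtraction) to a dummy batch standing in for the fetched images:

```python
import numpy as np
from tensorflow.keras.applications.resnet50 import preprocess_input

# Dummy uint8-range RGB batch standing in for the fetched 224x224 images.
X = np.random.randint(0, 256, size=(4, 224, 224, 3)).astype(np.float32)

# preprocess_input converts RGB to BGR and subtracts the ImageNet
# channel means ('caffe' mode), as expected by the pre-trained ResNet50.
X_prep = preprocess_input(X.copy())
print(X_prep.shape)
```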


Load model

Load a pretrained tensorflow model with a ResNet architecture trained on the ImageNet dataset.
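Loading the model is a one-liner with tf.keras.applications (the weights are downloaded on first use). The classifier outputs a softmax over the 1000 ImageNet classes:

```python
from tensorflow.keras.applications import ResNet50

# ResNet50 pre-trained on ImageNet; final layer is a softmax over 1000 classes.
model = ResNet50(weights='imagenet')
print(model.output_shape)
```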

Find similar instances

Initializing a GradientSimilarity explainer instance.

Fitting the explainer on the training data.

Selecting 5 random classes out of 10 and 1 random instance per class from the test set (5 test instances in total).
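The selection step can be sketched with NumPy alone; the toy labels below stand in for the real test-set labels (10 classes, 5 instances each):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy test labels: 10 classes, 5 instances each, as in the notebook's test set.
y_test = np.repeat(np.arange(10), 5)

# Pick 5 distinct random classes, then one random test instance per class.
classes = rng.choice(10, size=5, replace=False)
idxs = [rng.choice(np.where(y_test == c)[0]) for c in classes]
print(idxs)
```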

Getting the most similar instances for each of the 5 test samples.

Visualizations

Building, for visualization purposes, a dictionary for each test sample. Each dictionary contains:

  • The original image x (with mean channels added back for visualization).

  • The corresponding label y.

  • The corresponding model's prediction pred.

  • The corresponding reference labels ordered by similarity y_sim.

  • The corresponding model's predictions for the reference set preds_sim.
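The structure above can be sketched as follows. All inputs here are toy stand-ins (random image, random reference labels/predictions), and the variable names are hypothetical; the BGR channel means are the standard ImageNet values subtracted by ResNet50 preprocessing:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins: one preprocessed test image, its label and prediction, and
# the explainer's similarity ordering over a 20-instance reference set.
x = rng.random((224, 224, 3)).astype('float32')
y, pred = 3, 3
ordered_idx = rng.permutation(20)
y_ref = rng.integers(0, 10, 20)
preds_ref = rng.integers(0, 10, 20)

# ImageNet BGR channel means, added back so the image displays correctly.
mean_channels = np.array([103.939, 116.779, 123.68], dtype='float32')

sample = {
    'x': x + mean_channels,               # image with mean channels restored
    'y': y,                               # true label
    'pred': pred,                         # model prediction
    'y_sim': y_ref[ordered_idx],          # reference labels, most to least similar
    'preds_sim': preds_ref[ordered_idx],  # reference predictions, same order
}
```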

Most similar instances

Showing the 5 most similar instances for each of the test instances, ordered from the most similar to the least similar.
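A minimal sketch of such a grid with matplotlib, using random toy images: one row per test instance, with the test image in the first column followed by its 5 most similar reference images.

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless rendering
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)

n_test, n_sim = 2, 5
fig, axes = plt.subplots(n_test, n_sim + 1, figsize=(12, 4))
for i in range(n_test):
    axes[i, 0].imshow(rng.random((32, 32, 3)))
    axes[i, 0].set_title('test')
    for j in range(n_sim):
        axes[i, j + 1].imshow(rng.random((32, 32, 3)))
        axes[i, j + 1].set_title(f'sim {j + 1}')
for ax in axes.ravel():
    ax.axis('off')
fig.tight_layout()
```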


Most similar labels distributions

Showing the average similarity scores for each group of reference-set instances sharing the same true class. It can be seen that the highest score corresponds to the class of the original instance, as expected.
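The per-class averaging can be done with a pandas groupby (pandas is already imported above). The scores and labels below are toy values standing in for the explainer output:

```python
import numpy as np
import pandas as pd

# Toy similarity scores and reference-set true labels; in the notebook these
# come from the GradientSimilarity explanation.
scores = np.array([0.9, 0.2, 0.8, 0.1, 0.3, 0.7])
y_ref = np.array([0, 1, 0, 2, 1, 0])

# Average similarity score per true class of the reference instances.
avg = pd.Series(scores).groupby(y_ref).mean()
print(avg)
```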

