
Similarity explanations for 20 newsgroups dataset

In this notebook, we apply the similarity explanation method to a feed-forward neural network (FFNN) trained on the 20 newsgroups dataset.

The 20 newsgroups dataset is a corpus of 18846 text documents (emails) divided into 20 newsgroups, which serve as the classes. The FFNN is trained to classify each document into the correct newsgroup. Instead of raw text, the model takes pre-trained sentence embeddings as input features; these are obtained from the raw text using a pre-trained transformer.

Given a document of interest, the similarity explanation method used here finds the training documents that are most similar to it according to "how the model sees them": the similarity metric is computed from the gradients of the model's loss function with respect to the model's parameters, evaluated at each of the two documents being compared.
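To make the idea concrete, here is a minimal NumPy sketch. It assumes a toy linear softmax classifier (not the FFNN used in this notebook), because for logits `z = W @ x` the gradient of the cross-entropy loss with respect to `W` has the closed form `(softmax(z) - onehot(y)) x^T`. Two samples are then compared through the dot product of their flattened loss gradients. All helper names are hypothetical.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def loss_grad(W: np.ndarray, x: np.ndarray, y: int) -> np.ndarray:
    """Gradient of the cross-entropy loss of a linear softmax model w.r.t. W, flattened."""
    p = softmax(W @ x)
    p[y] -= 1.0                    # dL/dz = softmax(z) - onehot(y)
    return np.outer(p, x).ravel()  # dL/dW = (dL/dz) x^T


def gradient_similarities(W, x, y, X_ref, y_ref):
    """Dot product between the test gradient and each reference gradient."""
    g = loss_grad(W, x, y)
    G_ref = np.stack([loss_grad(W, xr, yr) for xr, yr in zip(X_ref, y_ref)])
    return G_ref @ g


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))        # 3 classes, 4 input features
X_ref = rng.normal(size=(5, 4))    # 5 reference samples
y_ref = np.array([0, 1, 2, 0, 1])
x, y = X_ref[2] + 0.01, 2          # a test point close to reference sample 2

scores = gradient_similarities(W, x, y, X_ref, y_ref)
print(np.argsort(-scores))         # reference indices, most similar first
```

In a real network the gradients would be obtained by automatic differentiation rather than in closed form, but the comparison step is the same.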

The similarity explanation tool supports both PyTorch and TensorFlow backends. In this example, we use the PyTorch backend. Running this notebook on a CPU can be very slow, so a GPU is recommended.

A more detailed description of the method can be found in the method's documentation. The implementation follows Charpiat et al., 2019 and Hanawa et al., 2021.
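Given flattened loss gradients, the similarity functions discussed in these papers can be sketched as below. This is a hedged illustration, not the library's implementation: the dot product and the cosine similarity are standard, while the asymmetric variant is assumed here to divide the dot product by the squared norm of the reference gradient. The function names and the small `eps` guard against division by zero are our own choices.

```python
import numpy as np


def grad_dot(g_test: np.ndarray, G_ref: np.ndarray) -> np.ndarray:
    # Plain dot product between one test gradient and each reference gradient (rows of G_ref).
    return G_ref @ g_test


def grad_cos(g_test: np.ndarray, G_ref: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Cosine similarity: dot product normalized by both gradient norms.
    norms = np.linalg.norm(G_ref, axis=1) * np.linalg.norm(g_test)
    return (G_ref @ g_test) / (norms + eps)


def grad_asym_dot(g_test: np.ndarray, G_ref: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Asymmetric variant (assumed form): normalize only by the squared reference-gradient norm.
    return (G_ref @ g_test) / (np.linalg.norm(G_ref, axis=1) ** 2 + eps)


g_test = np.array([1.0, 0.0])
G_ref = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(grad_dot(g_test, G_ref))       # [2. 0. 1.]
print(grad_cos(g_test, G_ref))       # approx. [1, 0, 0.7071]
print(grad_asym_dot(g_test, G_ref))  # approx. [0.5, 0, 0.5]
```

The plain dot product rewards reference instances with large gradient norms, whereas the cosine version compares only gradient directions, which is why the choice of normalization can change which documents come out as most similar.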

```python
# Install the sentence-transformers package (notebook shell command)
!pip install sentence_transformers

import os
import string

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from termcolor import colored
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from alibi.explainers import GradientSimilarity
```

Utils

Load data

Loading and preparing the 20 newsgroups dataset.
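A minimal sketch of this step, assuming the whole corpus is fetched with scikit-learn's `fetch_20newsgroups(subset="all")` and then split with a stratified `train_test_split`. The 80/20 ratio, the seed and the helper names `load_20newsgroups` and `split_corpus` are hypothetical, not necessarily what this notebook uses. The split is exercised on a tiny toy corpus so the snippet runs offline.

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split


def load_20newsgroups():
    # Downloads the corpus on first use (18846 documents, 20 classes).
    data = fetch_20newsgroups(subset="all")
    return data.data, data.target, data.target_names


def split_corpus(texts, labels, test_size=0.2, seed=0):
    # Stratify on the labels so every newsgroup keeps its share in both splits.
    return train_test_split(texts, labels, test_size=test_size,
                            random_state=seed, stratify=labels)


# Offline toy example: 20 documents, two balanced classes.
toy_texts = [f"doc {i}" for i in range(20)]
toy_labels = [i % 2 for i in range(20)]
X_train, X_test, y_train, y_test = split_corpus(toy_texts, toy_labels)
```

On the real data, `texts, labels, names = load_20newsgroups()` followed by `split_corpus(texts, labels)` would produce the training and test sets.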

Warning:

Define and train model

We define and train a PyTorch classifier that takes sentence embeddings as inputs.

Define model
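The exact architecture is not shown in this section; a plausible minimal FFNN in PyTorch might look like the following. The embedding size (e.g. 384, the output size of the `all-MiniLM-L6-v2` sentence-transformer model), the hidden size and the dropout rate are all assumptions.

```python
import torch
import torch.nn as nn


class NewsgroupsFFNN(nn.Module):
    """Hypothetical feed-forward classifier on top of sentence embeddings."""

    def __init__(self, emb_dim: int = 384, hidden_dim: int = 256,
                 n_classes: int = 20, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_classes),  # raw logits; the softmax lives in the loss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = NewsgroupsFFNN()
logits = model(torch.randn(8, 384))  # a batch of 8 embeddings -> 8 x 20 logits
```

Returning raw logits pairs naturally with `nn.CrossEntropyLoss`, which is also the kind of loss function a gradient-based similarity method needs to differentiate.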

Get sentence embeddings and define dataloaders
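With the `sentence_transformers` package, embeddings are typically obtained via `SentenceTransformer(model_name).encode(texts)`, which returns a NumPy array of shape `(n_texts, emb_dim)`. The sketch below skips that model-downloading call and uses random stand-in embeddings to show how such an array could be wrapped into PyTorch `DataLoader`s; the batch size, the use of `TensorDataset` and the helper name `make_loader` are assumptions.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(embeddings: np.ndarray, labels: np.ndarray,
                batch_size: int = 64, shuffle: bool = False) -> DataLoader:
    # float32 features and int64 class indices, as expected by nn.CrossEntropyLoss.
    X = torch.as_tensor(embeddings, dtype=torch.float32)
    y = torch.as_tensor(labels, dtype=torch.long)
    return DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=shuffle)


# Stand-ins for SentenceTransformer(...).encode(texts) outputs and their labels.
rng = np.random.default_rng(0)
emb_train, y_train = rng.normal(size=(150, 384)), rng.integers(0, 20, size=150)
train_loader = make_loader(emb_train, y_train, shuffle=True)

xb, yb = next(iter(train_loader))  # first mini-batch
```

Shuffling is usually enabled only for the training loader; evaluation loaders can keep the original order.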

Train model
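A minimal training loop, assuming cross-entropy loss and the Adam optimizer; the learning rate, number of epochs and the helper name `train` are hypothetical. Here it runs on a tiny linearly separable toy problem so that it finishes quickly on a CPU.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, loader, epochs=5, lr=1e-3, device="cpu"):
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        model.train()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(xb)
            n += len(xb)
        history.append(total / n)  # mean training loss for this epoch
    return history


torch.manual_seed(0)
X = torch.randn(200, 16)
y = (X[:, 0] > 0).long()  # toy labels: sign of the first feature
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
history = train(model, loader, epochs=20, lr=1e-2)
```

Passing `device="cuda"` (when available) moves both the model and each batch to the GPU, which is where the speed-up recommended above comes from.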

Evaluate model

Evaluating the model on the training and test sets. Since the dataset is well balanced, we only consider accuracy as the evaluation metric.
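Accuracy is the fraction of documents whose arg-max prediction equals the true label. A small NumPy helper (hypothetical name) working on an array of logits or probabilities:

```python
import numpy as np


def accuracy(scores: np.ndarray, labels: np.ndarray) -> float:
    # scores: (n_samples, n_classes) logits or probabilities; labels: (n_samples,)
    preds = scores.argmax(axis=1)
    return float((preds == labels).mean())


scores = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.9, 0.8, 0.1],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 1, 1, 2])
print(accuracy(scores, labels))  # 0.75
```

Because the arg-max is unchanged by a softmax, the helper gives the same result on logits and on probabilities.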

Find similar instances

Selecting a reference set of 1000 random samples from the training set. The GradientSimilarity explainer will search for the most similar instances among these. This downsampling is done to speed up the fit step.
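Drawing the reference subset without replacement can be sketched as follows. The seed and variable names are assumptions, and random arrays stand in for the training embeddings and labels.

```python
import numpy as np

rng = np.random.default_rng(42)
n_train = 15000
X_train = rng.normal(size=(n_train, 8))      # stand-in for training embeddings
y_train = rng.integers(0, 20, size=n_train)  # stand-in for training labels

# Sample 1000 distinct row indices, so no document appears twice in the reference set.
ref_idx = rng.choice(n_train, size=1000, replace=False)
X_ref, y_ref = X_train[ref_idx], y_train[ref_idx]
```

Each explanation needs one loss gradient per reference instance, so the cost grows roughly linearly with the size of the reference set.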

Initializing a GradientSimilarity explainer instance.

Fitting the explainer on the reference data.
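Conceptually, fitting stores the reference data (here, together with their loss gradients), while explaining computes the gradient of the instance of interest and ranks the references by similarity. The class below is a hypothetical NumPy illustration of that flow for a linear softmax model with cosine similarity; it is not the `GradientSimilarity` API.

```python
import numpy as np


class ToyGradientSimilarity:
    """Hypothetical fit/explain flow for a linear softmax model (logits = W @ x)."""

    def __init__(self, W: np.ndarray):
        self.W = W

    def _grad(self, x, y):
        z = self.W @ x
        p = np.exp(z - z.max())
        p /= p.sum()
        p[y] -= 1.0                    # dL/dz for cross-entropy
        return np.outer(p, x).ravel()  # dL/dW, flattened

    def fit(self, X_ref, y_ref):
        # Precompute and L2-normalize one gradient per reference instance.
        G = np.stack([self._grad(x, y) for x, y in zip(X_ref, y_ref)])
        self.G_ref = G / (np.linalg.norm(G, axis=1, keepdims=True) + 1e-12)
        return self

    def explain(self, x, y):
        g = self._grad(x, y)
        scores = self.G_ref @ (g / (np.linalg.norm(g) + 1e-12))  # cosine similarities
        order = np.argsort(-scores)    # most similar first
        return scores[order], order


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
X_ref, y_ref = rng.normal(size=(6, 4)), np.array([0, 1, 2, 0, 1, 2])
explainer = ToyGradientSimilarity(W).fit(X_ref, y_ref)
sorted_scores, order = explainer.explain(X_ref[4], y_ref[4])
```

Explaining a reference instance against its own reference set is a handy sanity check: it should rank itself first with a cosine similarity of 1.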

Selecting 3 random instances from the test set. We only select documents with fewer than 1000 characters, for visualization purposes.
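The selection can be sketched as filtering on document length first and then sampling without replacement; the toy documents, the seed and the helper name are placeholders.

```python
import numpy as np


def sample_short_docs(docs, n=3, max_chars=1000, seed=0):
    # Indices of documents strictly shorter than max_chars characters.
    short = np.array([i for i, d in enumerate(docs) if len(d) < max_chars])
    rng = np.random.default_rng(seed)
    return rng.choice(short, size=n, replace=False)


docs = ["short text"] * 5 + ["x" * 1500] * 5  # indices 0-4 short, 5-9 long
idx = sample_short_docs(docs, n=3)
```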

Getting predictions and explanations for each of the selected test samples.

Visualizations

Building a dictionary for each sample for visualization purposes. Each dictionary contains:

  • The original text document x (not the embedding representation).

  • The corresponding label y.

  • The model's prediction pred.

  • The reference instances ordered by similarity X_sim.

  • The corresponding reference labels ordered by similarity y_sim.

  • The model's predictions for the reference instances, ordered by similarity, preds_sim.
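Assuming the explainer returns, for each test sample, the indices of the reference set ordered from most to least similar (`ordered_idx` below, a hypothetical name), one such dictionary could be assembled like this:

```python
import numpy as np


def build_sample_dict(text, label, pred, ordered_idx, ref_texts, ref_labels, ref_preds):
    # Reorder every reference-set array by decreasing similarity.
    return {
        "x": text,
        "y": label,
        "pred": pred,
        "X_sim": [ref_texts[i] for i in ordered_idx],
        "y_sim": np.asarray(ref_labels)[ordered_idx],
        "preds_sim": np.asarray(ref_preds)[ordered_idx],
    }


ref_texts = ["doc a", "doc b", "doc c"]
sample = build_sample_dict("query doc", 1, 1, np.array([2, 0, 1]),
                           ref_texts, [0, 1, 1], [0, 1, 0])
print(sample["X_sim"])  # ['doc c', 'doc a', 'doc b']
```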

Most similar instances

Showing the 3 most similar instances for each of the test instances.

Most similar labels distributions

Showing the average similarity score of the reference instances grouped by true class, and grouped by the class predicted by the model.
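Averaging the scores per class amounts to a group-by mean over the reference labels. A NumPy sketch using `np.bincount` (helper name hypothetical); calling it once with the true reference labels and once with the model's predictions gives the two distributions described above.

```python
import numpy as np


def mean_score_per_class(scores, classes, n_classes):
    # Sum of scores per class divided by the number of reference instances in that class.
    sums = np.bincount(classes, weights=scores, minlength=n_classes)
    counts = np.bincount(classes, minlength=n_classes)
    # Classes absent from the reference set get NaN instead of a division by zero.
    return np.divide(sums, counts, out=np.full(n_classes, np.nan), where=counts > 0)


scores = np.array([0.9, 0.7, 0.2, 0.1, 0.4])
y_ref = np.array([0, 0, 1, 1, 2])
per_class = mean_score_per_class(scores, y_ref, n_classes=4)
```

The same result could be obtained with `pd.Series(scores).groupby(y_ref).mean()`, except that pandas silently omits empty classes instead of returning NaN for them.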


The plots show that, as expected, the reference instances belonging to the same true class as the instance of interest (and those the model assigns to the same predicted class) have, on average, higher similarity scores.
