Model distillation is a technique used to transfer knowledge from a large network to a smaller one. Typically, it consists of training a second model with a simplified architecture on soft targets (the output distributions or the logits) obtained from the original model.
Here we apply model distillation to obtain harmfulness scores: by comparing the output distributions of the original model with those of the distilled model, we can detect adversarial data, malicious data drift or data corruption. We use the following definition of harmful and harmless data points:
Harmful data points are defined as inputs for which the model's predictions on the uncorrupted data are correct while the model's predictions on the corrupted data are wrong.
Harmless data points are defined as inputs for which the model's predictions on the uncorrupted data are correct and the model's predictions on the corrupted data remain correct.
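As a rough illustration of the idea, the snippet below sketches a KL-divergence-based comparison between the output distributions of the original model (the teacher) and the distilled model (the student). The function names and the choice of divergence are illustrative; the library's ModelDistillation detector exposes its own score method, which is used later in this example.
import numpy as np
import tensorflow as tf

def distillation_loss(p_teacher, p_student):
    # training objective for the distilled model: match the soft targets
    # (output probabilities) of the original model
    return tf.keras.losses.KLDivergence()(p_teacher, p_student)

def harmfulness_score(p_teacher, p_student, eps=1e-12):
    # per-instance mismatch between the two output distributions;
    # the larger the divergence, the more likely the instance is harmful
    p_teacher = np.clip(p_teacher, eps, 1.)
    p_student = np.clip(p_student, eps, 1.)
    return np.sum(p_teacher * (np.log(p_teacher) - np.log(p_student)), axis=-1)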
Analogously to the adversarial AE detector, which is also part of the library, the model distillation detector picks up drift that reduces the performance of the classification model.
In this example we additionally employ a drift detector that applies two-sample Kolmogorov-Smirnov (K-S) tests to these scores. The resulting p-values are used to assess the harmfulness of the data.
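Conceptually, this amounts to a univariate two-sample test between the harmfulness scores of a reference batch and those of a new batch. A minimal sketch with scipy, using dummy score arrays in place of the detector's output:
import numpy as np
from scipy.stats import ks_2samp

score_ref = np.random.normal(0., 1., 1000)   # harmfulness scores on the reference batch
score_batch = np.random.normal(.5, 1., 100)  # harmfulness scores on a new batch
stat, p_value = ks_2samp(score_ref, score_batch)
drift_detected = p_value < .05  # a small p-value suggests the batch has drifted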
Dataset
CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness changes etc. at different levels of severity, leading to a gradual decline in classification model performance.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import tensorflow as tf
from alibi_detect.cd import KSDrift
from alibi_detect.ad import ModelDistillation
from alibi_detect.models.tensorflow import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model, fetch_detector
from alibi_detect.utils.tensorflow import predict_batch
from alibi_detect.saving import save_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c
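The snippets below assume the clean CIFAR-10 data has been loaded and preprocessed in the same way as the corrupted data later on; a sketch of that preprocessing (the training set is only needed if the detector is trained from scratch):
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)
X_train = scale_by_instance(X_train)
X_test = scale_by_instance(X_test)
print(corruption_types_cifar10c())  # list the available CIFAR-10-C corruption types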
Analogously to the adversarial AE detector, which uses an autoencoder to reproduce the output distribution of a classifier and produce adversarial scores, the model distillation detector achieves the same goal with a simple classifier in place of the autoencoder. This approach is more flexible since it bypasses the instance generation step, and it can be applied in a straightforward way to other data modalities such as text or time series.
We can use the adversarial scores produced by the model distillation detector in the context of drift detection: the score function of the detector becomes the preprocessing function for the drift detector, and the K-S test reduces to a simple univariate test between the adversarial scores of the reference batch and those of the test data. Higher adversarial scores indicate more harmful drift, so the detector specifically flags malicious data drift rather than benign changes. We can fetch the pretrained model distillation detector from a Google Cloud Bucket or train one from scratch:
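The snippet below is a sketch of this setup: the file path, the distilled model architecture and the training hyperparameters are illustrative and not necessarily those used to produce the pretrained detector.
from functools import partial

dataset = 'cifar10'
model = 'resnet32'
clf = fetch_tf_model(dataset, model)  # pretrained ResNet-32 classifier

load_pretrained = True
filepath = os.path.join(os.getcwd(), 'model_distillation')
if load_pretrained:
    ad = fetch_detector(filepath, 'adversarial', dataset, 'model_distillation', model=model)
else:
    # small distilled model trained on the classifier's output distributions
    distilled_model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(32, 4, strides=2, padding='same', activation='relu'),
        tf.keras.layers.Conv2D(64, 4, strides=2, padding='same', activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    ad = ModelDistillation(distilled_model=distilled_model, model=clf)
    ad.fit(X_train, epochs=50, batch_size=128, verbose=True)
    save_detector(ad, filepath)

# initialize the K-S drift detector on a reference batch of clean test instances,
# using the model distillation score as preprocessing function
np.random.seed(0)
idx_ref = np.random.choice(X_test.shape[0], size=1000, replace=False)
X_ref = X_test[idx_ref]
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=partial(ad.score, batch_size=256))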
Calculate scores
We split the corrupted data into harmful and harmless data and visualize the harmfulness scores for various levels of corruption severity.
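The loop below also relies on a few settings and helper functions that are not defined elsewhere in this example; the definitions here are a sketch consistent with how they are used (the corruption subset is illustrative):
def accuracy(y_true, y_pred):
    # fraction of correct predictions
    return (y_true == y_pred).mean()

def sample_batch(x_clean, x_corr, batch_size, p):
    # sample a batch with a fraction p of corrupted instances and 1 - p clean instances;
    # return the batch and the actual fraction of corrupted data
    n_corr = int(batch_size * p)
    n_clean = batch_size - n_corr
    idx_clean = np.random.choice(x_clean.shape[0], n_clean, replace=False)
    idx_corr = np.random.choice(x_corr.shape[0], n_corr, replace=False)
    x_batch = np.concatenate([x_clean[idx_clean], x_corr[idx_corr]])
    return x_batch, n_corr / batch_size

corruptions = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']  # illustrative subset
severities = [1, 2, 3, 4, 5]  # CIFAR-10-C severity levels
nb_batches = 100              # number of randomly sampled batches per contamination level
batch_size = 100              # size of each contaminated batch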
dfs = {}
score_drift = {
1: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
2: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
3: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
4: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
5: {'all': [], 'harm': [], 'noharm': [], 'acc': 0},
}
y_pred = predict_batch(X_test, clf, batch_size=256).argmax(axis=1)
# accuracy of the classifier on the uncorrupted test set, used in the plots below
clf_accuracy = {'original': accuracy(y_test, y_pred)}
score_x = ad.score(X_test, batch_size=256)
for s in severities:
    print('Loading corrupted data. Severity = {}'.format(s))
    X_corr, y_corr = fetch_cifar10c(corruption=corruptions, severity=s, return_X_y=True)
    print('Preprocess data...')
    X_corr = X_corr.astype('float32') / 255
    X_corr = scale_by_instance(X_corr)
    print('Make predictions on corrupted dataset...')
    y_pred_corr = predict_batch(X_corr, clf, batch_size=1000).argmax(axis=1)
    print('Compute adversarial scores on corrupted dataset...')
    score_corr = ad.score(X_corr, batch_size=256)
    labels_corr = np.zeros(score_corr.shape[0])
    repeat = y_corr.shape[0] // y_test.shape[0]
    y_pred_repeat = np.tile(y_pred, (repeat,))
    # malicious/harmful corruption: original prediction correct but
    # prediction on corrupted data incorrect
    idx_orig_right = np.where(y_pred_repeat == y_corr)[0]
    idx_corr_wrong = np.where(y_pred_corr != y_corr)[0]
    idx_harmful = np.intersect1d(idx_orig_right, idx_corr_wrong)
    labels_corr[idx_harmful] = 1
    labels = np.concatenate([np.zeros(X_test.shape[0]), labels_corr]).astype(int)
    # harmless corruption: original prediction correct and prediction
    # on corrupted data also correct
    idx_corr_right = np.where(y_pred_corr == y_corr)[0]
    idx_harmless = np.intersect1d(idx_orig_right, idx_corr_right)
    # split corrupted inputs into harmful and harmless instances
    X_corr_harm = X_corr[idx_harmful]
    X_corr_noharm = X_corr[idx_harmless]
    # store adversarial scores and accuracy for harmful and harmless data
    score_drift[s]['all'] = score_corr
    score_drift[s]['harm'] = score_corr[idx_harmful]
    score_drift[s]['noharm'] = score_corr[idx_harmless]
    score_drift[s]['acc'] = accuracy(y_corr, y_pred_corr)
    print('Compute p-values')
    for j in range(nb_batches):
        ps = []
        pvs_harm = []
        pvs_noharm = []
        for p in np.arange(0, 1, 0.1):
            # sample a batch of size `batch_size` where a fraction p of the data
            # is corrupted harmful data and a fraction 1 - p is non-corrupted data
            X_batch_harm, _ = sample_batch(X_test, X_corr_harm, batch_size, p)
            # sample a batch of size `batch_size` where a fraction p of the data
            # is corrupted harmless data and a fraction 1 - p is non-corrupted data
            X_batch_noharm, perc = sample_batch(X_test, X_corr_noharm, batch_size, p)
            # calculate p-values for the harmful and harmless data by applying
            # the K-S test to the adversarial scores
            pv_harm = cd.score(X_batch_harm)
            pv_noharm = cd.score(X_batch_noharm)
            ps.append(perc * 100)
            pvs_harm.append(pv_harm[0])
            pvs_noharm.append(pv_noharm[0])
        if j == 0:
            df = pd.DataFrame({'p': ps})
        df['pvalue_harm_{}'.format(j)] = pvs_harm
        df['pvalue_noharm_{}'.format(j)] = pvs_noharm
    for name in ['pvalue_harm', 'pvalue_noharm']:
        # select the per-batch p-value columns before adding the aggregate columns,
        # so the aggregates are not computed over previously added aggregates
        cols = [col for col in df.columns if name in col]
        df[name + '_mean'] = df[cols].mean(axis=1)
        df[name + '_std'] = df[cols].std(axis=1)
        df[name + '_max'] = df[cols].max(axis=1)
        df[name + '_min'] = df[cols].min(axis=1)
    df.set_index('p', inplace=True)
    dfs[s] = df
Plot scores
We now plot the mean scores and standard deviations per severity level. The plot shows the mean harmfulness scores (left) and the ResNet-32 accuracies (right) for increasing data corruption severity levels. Level 0 corresponds to the original test set. Harmful scores are scores from instances which have been flipped from a correct to an incorrect prediction because of the corruption. Not harmful means that a correct prediction was unchanged after the corruption.
mu_noharm, std_noharm = [], []
mu_harm, std_harm = [], []
acc = [clf_accuracy['original']]
for k, v in score_drift.items():
    mu_noharm.append(v['noharm'].mean())
    std_noharm.append(v['noharm'].std())
    mu_harm.append(v['harm'].mean())
    std_harm.append(v['harm'].std())
    acc.append(v['acc'])
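A minimal plotting sketch for these aggregates; the exact styling of the original figure may differ:
severity_levels = [1, 2, 3, 4, 5]
fig, ax = plt.subplots(1, 2, figsize=(16, 6))
# left: mean harmfulness scores (+/- one standard deviation) per severity level
ax[0].errorbar(severity_levels, mu_harm, yerr=std_harm, label='harmful', capsize=3)
ax[0].errorbar(severity_levels, mu_noharm, yerr=std_noharm, label='not harmful', capsize=3)
ax[0].axhline(score_x.mean(), linestyle='--', color='gray', label='original test set (mean)')
ax[0].set_xlabel('Corruption severity')
ax[0].set_ylabel('Harmfulness score')
ax[0].legend()
# right: ResNet-32 accuracy per severity level (0 = original test set)
ax[1].plot([0] + severity_levels, acc, marker='o')
ax[1].set_xlabel('Corruption severity')
ax[1].set_ylabel('Accuracy')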
In order to simulate a realistic scenario, we perform a K-S test on batches of instances which are increasingly contaminated with corrupted data. The following steps are implemented:
We randomly pick n_ref=1000 samples from the non-corrupted test set to be used as a reference set in the initialization of the K-S drift detector.
We sample batches of data of size batch_size=100 contaminated with an increasing amount of harmful corrupted data and harmless corrupted data.
The K-S detector predicts whether drift occurs between the contaminated batches and the reference data and returns the p-values of the test.
We observe that contamination of the batches with harmful data reduces the p-values much faster than contamination with harmless data. In the latter case, the p-values remain above the detection threshold even when the batch is heavily contaminated.
We repeat the test for 100 randomly sampled batches and plot the mean and maximum p-values for each level of severity and contamination below. The plots show that the detector can clearly distinguish a batch contaminated with harmful data from a batch contaminated with harmless data once the percentage of corrupted data reaches 20%-30%.
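Note that in a deployed setting one would typically call the drift detector's predict method rather than working with the raw p-values returned by score; a minimal usage sketch on one of the contaminated batches defined above (the returned dictionary follows the alibi-detect drift detector convention):
preds = cd.predict(X_batch_harm)
print('Drift detected: {}'.format(preds['data']['is_drift']))
print('p-value: {}'.format(preds['data']['p_val']))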
for s in severities:
    nrows = 1
    ncols = 2
    figsize = (15, 8)
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    title0 = ('Mean p-values for various percentages of corrupted data. \n'
              ' Nb of batches = {}, batch size = {}, severity = {}'.format(
                  nb_batches, batch_size, s))
    title1 = ('Maximum p-values for various percentages of corrupted data. \n'
              ' Nb of batches = {}, batch size = {}, severity = {}'.format(
                  nb_batches, batch_size, s))
    dfs[s][['pvalue_harm_mean', 'pvalue_noharm_mean']].plot(ax=ax[0], title=title0)
    dfs[s][['pvalue_harm_max', 'pvalue_noharm_max']].plot(ax=ax[1], title=title1)
    for a in ax:
        a.set_xlabel('Percentage of corrupted data')
        a.set_ylabel('p-value')