The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate two-sample testing. The MMD is a distance-based measure between two distributions $p$ and $q$ based on the mean embeddings $\mu_{p}$ and $\mu_{q}$ in a reproducing kernel Hilbert space $F$:
$$MMD^2(F, p, q) = \| \mu_{p} - \mu_{q} \|^2_{F}$$
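Writing $\mu_{p} = \mathbb{E}_{x \sim p}[\phi(x)]$, where $\phi$ is the feature map of the kernel $k(x, x') = \langle \phi(x), \phi(x') \rangle_F$, and expanding the squared norm shows why the kernel trick applies: $MMD^2$ depends on the data only through kernel evaluations. Here $x, x'$ are independent draws from $p$ and $y, y'$ independent draws from $q$; this is the standard expansion, stated for orientation rather than taken from the detector's code:

$$MMD^2(F, p, q) = \mathbb{E}_{x, x' \sim p}\left[k(x, x')\right] + \mathbb{E}_{y, y' \sim q}\left[k(y, y')\right] - 2\,\mathbb{E}_{x \sim p,\, y \sim q}\left[k(x, y)\right]$$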
We can compute unbiased estimates of $MMD^2$ from the samples of the two distributions after applying the kernel trick. By default a Gaussian RBF kernel is used, but users are free to pass their own kernel of preference to the detector. We obtain a $p$-value via a permutation test on the values of $MMD^2$. This method is also described in .
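To make the estimator and the test concrete, here is a minimal NumPy sketch. It is not Alibi Detect's implementation: the helper names (`gaussian_rbf`, `mmd2_unbiased`, `mmd_permutation_test`) are hypothetical, and setting the kernel bandwidth `sigma` to the median pairwise distance is our own assumption. The unbiased estimate averages the within-sample kernel matrices with their diagonal terms $k(x_i, x_i)$ left out and subtracts twice the mean cross-sample kernel value; the p-value is the fraction of random re-splits of the pooled data whose $MMD^2$ is at least as large as the observed one.

```python
import numpy as np


# Hypothetical helpers for illustration; not Alibi Detect's implementation.
def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of a and b."""
    d2 = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.maximum(d2, 0.0)  # clip small negatives caused by rounding


def gaussian_rbf(a: np.ndarray, b: np.ndarray, sigma: float) -> np.ndarray:
    """Kernel matrix with entries exp(-||a_i - b_j||^2 / (2 * sigma^2))."""
    return np.exp(-sq_dists(a, b) / (2.0 * sigma**2))


def mmd2_unbiased(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Unbiased estimate of MMD^2 between samples x of shape (m, d) and y of shape (n, d)."""
    m, n = len(x), len(y)
    k_xx = gaussian_rbf(x, x, sigma)
    k_yy = gaussian_rbf(y, y, sigma)
    k_xy = gaussian_rbf(x, y, sigma)
    # leave out the i == j terms of the within-sample sums
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


def mmd_permutation_test(x, y, n_permutations=100, sigma=None, seed=0):
    """Return (observed MMD^2 estimate, permutation p-value)."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x, y], axis=0)
    if sigma is None:
        # assumption: bandwidth = median pairwise distance in the pooled sample
        iu = np.triu_indices(len(z), k=1)
        sigma = float(np.median(np.sqrt(sq_dists(z, z)[iu])))
    m = len(x)
    observed = mmd2_unbiased(x, y, sigma)
    permuted = np.empty(n_permutations)
    for i in range(n_permutations):
        idx = rng.permutation(len(z))  # shuffle the pooled sample, then re-split it
        permuted[i] = mmd2_unbiased(z[idx[:m]], z[idx[m:]], sigma)
    # fraction of permuted statistics at least as large as the observed one
    p_val = float(np.mean(permuted >= observed))
    return observed, p_val


rng = np.random.default_rng(42)
x = rng.normal(size=(100, 2))
print(mmd_permutation_test(x, rng.normal(size=(100, 2))))           # same distribution
print(mmd_permutation_test(x, rng.normal(loc=3.0, size=(100, 2))))  # strong mean shift: p-value near 0
```

Re-estimating $MMD^2$ on shuffled splits simulates the null hypothesis that both samples come from the same distribution, so `n_permutations` trades runtime for a finer-grained p-value.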
Backend
The method is implemented in both the PyTorch and TensorFlow frameworks with support for CPU and GPU. Various preprocessing steps are also supported out of the box in Alibi Detect for both frameworks and are illustrated throughout the notebook. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch documentation for how to do this.
Dataset
CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset. The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness, etc. at different levels of severity, leading to a gradual decline in the classification model's performance. We also check for drift against the original test set with class imbalances.
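The text does not show how the class-imbalanced test set is built. One simple way, sketched below with a hypothetical helper `imbalanced_subset` and made-up fractions, is to keep only a random fraction of the instances of selected classes. This shifts the label distribution while leaving the images themselves untouched.

```python
import numpy as np


# Hypothetical helper for illustration; not part of Alibi Detect.
def imbalanced_subset(x, y, class_fractions, seed=0):
    """Keep a random fraction of the instances of each class.

    class_fractions maps a class label to the fraction (0 to 1) of its
    instances to retain; classes not listed are kept in full.
    """
    rng = np.random.default_rng(seed)
    keep = []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)
        frac = class_fractions.get(int(label), 1.0)
        n_keep = int(round(frac * len(idx)))
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    keep = np.sort(np.concatenate(keep))  # preserve the original order of the rows
    return x[keep], y[keep]


y = np.repeat(np.arange(10), 100)  # 100 labels per class, like a balanced test set
x = np.stack([y, y], axis=1)       # dummy features that carry their label
x_imb, y_imb = imbalanced_subset(x, y, {0: 0.1, 1: 0.5})
print(np.bincount(y_imb))  # 10 and 50 instances of classes 0 and 1, 100 of every other class
```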
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from alibi_detect.cd import MMDDrift
from alibi_detect.models.tensorflow import scale_by_instance
from alibi_detect.utils.fetching import fetch_tf_model
from alibi_detect.saving import save_detector, load_detector
from alibi_detect.datasets import fetch_cifar10c, corruption_types_cifar10c
We split the original test set into a reference dataset and a dataset which should not be rejected under the null hypothesis (H0) of the MMD test. We also split the corrupted data by corruption type.
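A minimal sketch of both splits, under our own assumptions: the helpers `split_ref_h0` and `split_by_corruption` are hypothetical, the test set is split uniformly at random, and the corrupted data is stacked in equal-sized blocks, one per corruption type. Keeping the reference and H0 sets disjoint matters, because the H0 set should behave like fresh data from the reference distribution rather than a copy of it.

```python
import numpy as np


# Hypothetical helpers; the notebook's own splitting code is not shown here.
def split_ref_h0(x, n_ref, seed=0):
    """Randomly split x into a reference set (n_ref instances) and an H0 set (the rest)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(x))
    return x[idx[:n_ref]], x[idx[n_ref:]]


def split_by_corruption(x_corr, corruptions):
    """Split stacked corrupted data into one array per corruption type.

    Assumes the types are stored back to back, in the order of `corruptions`,
    with the same number of instances each (np.split raises otherwise).
    """
    return np.split(x_corr, len(corruptions))


x = np.arange(20).reshape(10, 2)
x_ref, x_h0 = split_ref_h0(x, n_ref=6)
print(x_ref.shape, x_h0.shape)  # (6, 2) (4, 2)
parts = split_by_corruption(np.zeros((9, 2)), ['gaussian_noise', 'motion_blur', 'brightness'])
print([p.shape for p in parts])  # [(3, 2), (3, 2), (3, 2)]
```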
As expected, drift was only detected on the corrupted datasets.
BBSDs
X_ref_bbsds = scale_by_instance(X_ref)
X_h0_bbsds = scale_by_instance(X_h0)
X_c_bbsds = [scale_by_instance(X_c[i]) for i in range(n_corr)]
We now initialise the drift detector. Here we use the output of the softmax layer to detect drift, but other hidden layers can be extracted as well by setting 'layer' to the index of the desired hidden layer in the model:
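from alibi_detect.cd.tensorflow import HiddenOutput, preprocess_drift
# define preprocessing function: softmax outputs (layer=-1) of the pretrained ResNet classifier clf
# (clf is loaded elsewhere in the notebook and is not defined in this excerpt)
preprocess_fn = partial(preprocess_drift, model=HiddenOutput(clf, layer=-1), batch_size=128)
# initialise drift detector
cd = MMDDrift(X_ref_bbsds, backend='tensorflow', p_val=.05,
              preprocess_fn=preprocess_fn, n_permutations=100)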
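For BBSDs, the features handed to the MMD test are the classifier's softmax probabilities, computed in batches. The toy sketch below only illustrates that mechanism and what binding `model` and `batch_size` with `functools.partial` achieves; `batched_predict`, `softmax` and the random linear `toy_clf` standing in for the ResNet are hypothetical, not Alibi Detect's `preprocess_drift` or `HiddenOutput`.

```python
from functools import partial

import numpy as np


# Hypothetical helpers for illustration only; not Alibi Detect's preprocess_drift.
def batched_predict(x, model, batch_size=128):
    """Apply `model` to x in mini-batches and concatenate the outputs."""
    outputs = [model(x[i:i + batch_size]) for i in range(0, len(x), batch_size)]
    return np.concatenate(outputs, axis=0)


def softmax(logits):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Toy stand-in for the classifier: a fixed random linear map to 10 classes, then softmax.
rng = np.random.default_rng(0)
weights = rng.normal(size=(8, 10))


def toy_clf(batch):
    return softmax(batch @ weights)


# Bind model and batch size up front, mirroring partial(preprocess_drift, model=..., batch_size=...).
toy_preprocess_fn = partial(batched_predict, model=toy_clf, batch_size=128)
features = toy_preprocess_fn(rng.normal(size=(300, 8)))  # one probability vector per instance
print(features.shape)  # (300, 10)
```

Because 10 softmax outputs are far fewer than the 3,072 pixel values of a 32x32x3 image, the MMD test then operates in a much lower-dimensional space.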
Again drift is only flagged on the perturbed data.
Detect drift with PyTorch backend
We can do the same thing using the PyTorch backend. We illustrate this using the randomly initialized encoder as the preprocessing step:
import torch
import torch.nn as nn
# set random seed and device
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
Since our PyTorch encoder expects the images in a (batch size, channels, height, width) format, we transpose the data:
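The transpose code itself is not included here. A minimal sketch with a hypothetical helper `to_nchw` moves the channel axis with `np.transpose` and casts to float32, which we assume matches the dtype of the encoder's weights.

```python
import numpy as np


# Hypothetical helper; not part of Alibi Detect.
def to_nchw(x):
    """Convert (batch, height, width, channels) to (batch, channels, height, width) as float32."""
    return np.transpose(x.astype(np.float32), (0, 3, 1, 2))


x = np.zeros((4, 32, 32, 3), dtype=np.uint8)
print(to_nchw(x).shape)  # (4, 3, 32, 32)
```

Applied to the arrays used below, this would give, for example, `X_ref_pt = to_nchw(X_ref)`, `X_h0_pt = to_nchw(X_h0)` and `X_c_pt = [to_nchw(xc) for xc in X_c]`.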
from alibi_detect.cd.pytorch import preprocess_drift
# define encoder (encoding_dim, the size of the output embedding, is defined elsewhere in the notebook)
encoder_net = nn.Sequential(
    nn.Conv2d(3, 64, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(64, 128, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(128, 512, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(2048, encoding_dim)
).to(device).eval()
# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=encoder_net, device=device, batch_size=512)
# initialise drift detector
cd = MMDDrift(X_ref_pt, backend='pytorch', p_val=.05,
              preprocess_fn=preprocess_fn, n_permutations=100)
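The input size 2048 of the final `nn.Linear` layer is fixed by the convolution arithmetic. For a convolution with dilation 1, PyTorch's `nn.Conv2d` documentation gives the spatial output size $\lfloor (H + 2p - k)/s \rfloor + 1$. A quick check with a hypothetical helper `conv2d_out`:

```python
# Hypothetical helper reproducing nn.Conv2d's output-size formula for dilation 1.
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32  # CIFAR-10 images are 32 x 32
for _ in range(3):  # three Conv2d layers with kernel 4, stride 2, padding 0
    size = conv2d_out(size, kernel=4, stride=2, padding=0)
print(size)               # 2   (32 -> 15 -> 6 -> 2)
print(512 * size * size)  # 2048 features after nn.Flatten()
```

If the input resolution or the convolution settings change, the `nn.Linear` input size has to change with them, otherwise the forward pass fails with a shape mismatch.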
# we can also save/load an initialised PyTorch based detector
filepath = 'detector_pt' # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)
make_predictions(cd, X_h0_pt, X_c_pt, corruption)
The drift detector will attempt to use the GPU if available and otherwise falls back on the CPU. We can also specify the device explicitly. Comparing the speed of the GPU and CPU implementations shows a speed-up of over 30x on the GPU.
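The timing code is not included in the text. A minimal, framework-agnostic sketch (the helper `best_time` is hypothetical) records the fastest of several runs, which reduces noise from other processes and from one-off start-up costs such as GPU initialisation on the first call.

```python
import time


# Hypothetical timing helper; not part of Alibi Detect.
def best_time(fn, *args, n_repeats=3):
    """Return the fastest wall-clock time in seconds of fn(*args) over n_repeats runs."""
    timings = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return min(timings)


print(best_time(sorted, list(range(100_000)), n_repeats=5))
```

For the comparison above, one would time `predict` on two otherwise identical detectors, e.g. `best_time(cd_cpu.predict, X) / best_time(cd_gpu.predict, X)`. The names `cd_cpu`, `cd_gpu` and `X` are hypothetical, and we assume the device is chosen through `MMDDrift`'s `device` argument.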
Similar to the TensorFlow implementation, PyTorch can also use the hidden layer output from a pretrained model for the preprocessing step via:
from alibi_detect.cd.pytorch import HiddenOutput
We are trying to detect data drift on high-dimensional (32x32x3) data using a multivariate MMD permutation test. It therefore makes sense to apply dimensionality reduction first. Some dimensionality reduction methods also used in are readily available: a randomly initialized encoder (UAE or Untrained AutoEncoder in the paper), BBSDs (black-box shift detection using the classifier's softmax outputs) and PCA (using scikit-learn).
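The text relies on scikit-learn for PCA. To make the reduction step concrete, here is a minimal NumPy sketch with hypothetical helpers `fit_pca` and `pca_transform`. It fits the projection on the reference data only and reuses it for every test set, so all samples are compared in the same low-dimensional space.

```python
import numpy as np


# Hypothetical helpers; the text itself uses scikit-learn for PCA.
def fit_pca(x_ref, n_components):
    """Fit PCA on reference data via SVD and return (mean, components)."""
    x = x_ref.reshape(len(x_ref), -1).astype(np.float64)  # flatten each image to a vector
    mean = x.mean(axis=0)
    # rows of vt are orthonormal principal directions, ordered by decreasing variance
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]


def pca_transform(x, mean, components):
    """Project flattened data onto the principal components fitted on the reference set."""
    return (x.reshape(len(x), -1) - mean) @ components.T


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 8, 8, 3))  # small stand-in for 32x32x3 images
mean, components = fit_pca(x_ref, n_components=32)
print(pca_transform(rng.normal(size=(50, 8, 8, 3)), mean, components).shape)  # (50, 32)
```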
For BBSDs, we use the classifier's softmax outputs for black-box shift detection. This method is based on . The ResNet classifier is trained on data standardised by instance, so we need to rescale the data.
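The notebook uses Alibi Detect's `scale_by_instance` for this step. The NumPy sketch below, with a hypothetical `standardise_per_instance` and an `eps` guard of our own, shows the operation we take "standardised by instance" to mean: for each image, subtract the mean and divide by the standard deviation computed over all its pixels and channels. The library function may differ in details.

```python
import numpy as np


# Hypothetical helper; not Alibi Detect's scale_by_instance.
def standardise_per_instance(x, eps=1e-12):
    """Standardise each instance to zero mean and unit std over all its values."""
    axes = tuple(range(1, x.ndim))  # every axis except the batch axis
    mean = x.mean(axis=axes, keepdims=True)
    std = x.std(axis=axes, keepdims=True)
    return (x - mean) / np.maximum(std, eps)  # eps avoids dividing by zero for constant images


rng = np.random.default_rng(0)
images = rng.uniform(0.0, 1.0, size=(4, 32, 32, 3))
scaled = standardise_per_instance(images)
print(scaled.mean(axis=(1, 2, 3)), scaled.std(axis=(1, 2, 3)))  # per-image means near 0, stds near 1
```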