Learned drift detectors on CIFAR-10
Under the hood, drift detectors leverage a function (also known as a test statistic) that is expected to take a large value if drift has occurred and a low value if not. The power of the detector is partly determined by how well the function satisfies this property. However, specifying such a function in advance can be very difficult. In this example notebook we consider two ways in which a portion of the available data can be used to learn such a function, which is then applied to the held-out portion of the data to test for drift.
Detecting drift with a learned classifier
The classifier-based drift detector simply tries to correctly distinguish instances from the reference data vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen test instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances, then the test set must differ from the reference set and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set, or even for each fold within the test set.
Backend
The method works with both the PyTorch and TensorFlow frameworks. Note, however, that Alibi Detect does not install PyTorch for you. Check the PyTorch docs for how to do this.
Dataset
CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness changes, etc. at different levels of severity, leading to a gradual decline in classification model performance. We also check for drift against the original test set with class imbalances.
Load data
Original CIFAR-10 data:
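A minimal sketch of this step, assuming the standard tf.keras CIFAR-10 loader and pixel values rescaled to [0, 1]:

```python
import tensorflow as tf

# Load the original CIFAR-10 train/test splits and rescale pixels to [0, 1].
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
y_train = y_train.astype('int64').reshape(-1,)
y_test = y_test.astype('int64').reshape(-1,)
```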
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
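A sketch of listing the available corruption types, assuming Alibi Detect's dataset utilities:

```python
from alibi_detect.datasets import corruption_types_cifar10c

# List the corruption types available in CIFAR-10-C.
corruptions = corruption_types_cifar10c()
print(corruptions)
```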
Let's pick a subset of the corruptions at corruption level 5. Each corruption type is applied to all of the original test set images.
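A sketch of fetching the corrupted data; the corruption names below are an illustrative subset, not prescribed by the notebook:

```python
from alibi_detect.datasets import fetch_cifar10c

# Illustrative subset of corruption types, fetched at the highest severity level.
corruption = ['gaussian_noise', 'motion_blur', 'brightness', 'pixelate']
X_corr, y_corr = fetch_cifar10c(corruption=corruption, severity=5, return_X_y=True)
X_corr = X_corr.astype('float32') / 255.
```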
We split the original test set into a reference dataset and a dataset which should not be flagged as drift. We also split the corrupted data by corruption type:
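One way to do this split (a sketch; the split sizes and random seed are arbitrary choices):

```python
import numpy as np

np.random.seed(0)
n_test = X_test.shape[0]

# Random half of the original test set as reference data, the other half as
# held-out data that should not be flagged as drift.
idx = np.random.choice(n_test, size=n_test // 2, replace=False)
idx_h0 = np.delete(np.arange(n_test), idx, axis=0)
X_ref, y_ref = X_test[idx], y_test[idx]
X_h0, y_h0 = X_test[idx_h0], y_test[idx_h0]

# The corrupted instances are stacked per corruption type; split them back out.
X_c = [X_corr[i * n_test:(i + 1) * n_test] for i in range(len(corruption))]
```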
We can visualise the same instance for each corruption type:
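For example, with matplotlib (continuing with the X_c and corruption names from the sketches above):

```python
import matplotlib.pyplot as plt

# Plot the same test instance under each corruption type.
i = 0
fig, axes = plt.subplots(1, len(corruption), figsize=(12, 3))
for ax, name, x in zip(axes, corruption, X_c):
    ax.imshow(x[i])
    ax.set_title(name)
    ax.axis('off')
plt.show()
```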
Detect drift with a TensorFlow classifier
Single fold
We use a simple classification model and try to distinguish between the reference data and the corrupted test sets. The detector defaults to binarize_preds=False, which means a Kolmogorov-Smirnov test will be used to test for significant disparity between continuous model predictions (e.g. probabilities or logits). Initially we'll test at a significance level of $p=0.05$, use $75$% of the shuffled reference and test data for training, and evaluate the detector on the remaining $25$%. We only train for 1 epoch.
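A sketch of the detector initialisation, assuming a small untrained convolutional classifier (the exact architecture is an arbitrary choice):

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from alibi_detect.cd import ClassifierDrift

# Simple softmax classifier distinguishing reference from test instances.
model = tf.keras.Sequential([
    Input(shape=(32, 32, 3)),
    Conv2D(8, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(16, 4, strides=2, padding='same', activation=tf.nn.relu),
    Conv2D(32, 4, strides=2, padding='same', activation=tf.nn.relu),
    Flatten(),
    Dense(2, activation='softmax')
])

# Train on 75% of the shuffled reference + test data, evaluate on the rest.
cd = ClassifierDrift(X_ref, model, backend='tensorflow', p_val=.05,
                     train_size=.75, epochs=1)
```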
If needed, the detector can be saved and loaded with save_detector and load_detector:
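A sketch, assuming a recent Alibi Detect release where these utilities live in alibi_detect.saving (older releases expose them under alibi_detect.utils.saving); the filepath is a hypothetical directory:

```python
from alibi_detect.saving import save_detector, load_detector

filepath = 'my_detector'  # hypothetical directory
save_detector(cd, filepath)
cd = load_detector(filepath)
```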
Let's check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
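A possible helper for this (a sketch; the 'is_drift' and 'p_val' keys follow Alibi Detect's usual 'data' output format):

```python
from timeit import default_timer as timer

labels = ['No!', 'Yes!']

def make_predictions(cd, x_h0, x_corr, corruption):
    # Held-out, uncorrupted data: no drift should be detected.
    t = timer()
    preds = cd.predict(x_h0)
    dt = timer() - t
    print('No corruption')
    print(f'Drift? {labels[preds["data"]["is_drift"]]}')
    print(f'p-value: {preds["data"]["p_val"]:.3f}')
    print(f'Time (s): {dt:.3f}')

    # Corrupted test sets: drift should be flagged.
    for x, c in zip(x_corr, corruption):
        t = timer()
        preds = cd.predict(x)
        dt = timer() - t
        print(f'\nCorruption type: {c}')
        print(f'Drift? {labels[preds["data"]["is_drift"]]}')
        print(f'p-value: {preds["data"]["p_val"]:.3f}')
        print(f'Time (s): {dt:.3f}')

make_predictions(cd, X_h0, X_c, corruption)
```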
As expected, drift was only detected on the corrupted datasets and the classifier could easily distinguish the corrupted from the reference data.
Use all the available data via cross-validation
So far we've only used $25$% of the data to detect the drift, since $75$% is used for training purposes. At the cost of additional training time we can however leverage all the data via stratified cross-validation. We just need to set the number of folds and keep everything else the same. So for each test set, n_folds models are trained, and the out-of-fold predictions are combined for the significance test:
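For instance, reusing the model and helper defined above (a sketch; train_size is set to None since n_folds is used instead):

```python
# Cross-validated classifier drift detector: 5 folds, out-of-fold predictions
# are combined for the Kolmogorov-Smirnov test.
cd = ClassifierDrift(X_ref, model, backend='tensorflow', p_val=.05,
                     train_size=None, n_folds=5, epochs=1)
make_predictions(cd, X_h0, X_c, corruption)
```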
Detecting drift with a learned kernel
An alternative to training a classifier to output high probabilities for instances from the test window and low probabilities for instances from the reference window is to learn a kernel that outputs high similarities between instances from the same window and low similarities between instances from different windows. The kernel may then be used within an MMD-test for drift. Liu et al. (2020) propose this learned approach and note that it is in fact a generalisation of the above classifier-based method. However, in this case we can train the kernel to directly optimise an estimate of the detector's power, which can result in superior performance.
Detect drift with a learned PyTorch kernel
This can be implemented as shown below. We use PyTorch instead of TensorFlow this time for the sake of variety. Because we are dealing with images, we give our projection $\Phi$ a convolutional architecture.
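A sketch of such a projection in PyTorch (the architecture and embedding size are arbitrary choices):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convolutional projection mapping (3, 32, 32) images to a flat embedding.
proj = nn.Sequential(
    nn.Conv2d(3, 8, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(8, 16, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Conv2d(16, 32, 4, stride=2, padding=0),
    nn.ReLU(),
    nn.Flatten(),
).to(device)
```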
We may then specify a DeepKernel in the following manner. By default, GaussianRBF kernels are used for $k_a$ and $k_b$, and here we specify $\epsilon=0.01$, but we could alternatively set eps='trainable'.
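A sketch of the kernel definition; DeepKernel combines a kernel $k_a$ on the projections with a kernel $k_b$ on the raw inputs as $k(x,y) = (1-\epsilon)\,k_a(\Phi(x),\Phi(y)) + \epsilon\,k_b(x,y)$:

```python
from alibi_detect.utils.pytorch import DeepKernel

# Deep kernel with the default GaussianRBF components for k_a and k_b
# and a fixed eps of 0.01.
kernel = DeepKernel(proj, eps=0.01)
```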
Since our PyTorch encoder expects the images in a (batch size, channels, height, width) format, we transpose the data. Note that this step could also be passed to the drift detector via the preprocess_fn kwarg:
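For example:

```python
# Channels-last (N, H, W, C) -> channels-first (N, C, H, W) for PyTorch.
X_ref_pt = X_ref.transpose(0, 3, 1, 2)
X_h0_pt = X_h0.transpose(0, 3, 1, 2)
X_c_pt = [x.transpose(0, 3, 1, 2) for x in X_c]
```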
We then pass the kernel to the LearnedKernelDrift detector. By default, $75$% of the data is used to train the kernel and the MMD test is performed on the remaining $25$%.
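A sketch of the detector initialisation (the number of training epochs is an arbitrary choice):

```python
from alibi_detect.cd import LearnedKernelDrift

# Learned-kernel MMD drift detector on the transposed reference data.
cd = LearnedKernelDrift(X_ref_pt, kernel, backend='pytorch', p_val=.05, epochs=1)
```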
Again, the detector can be saved and loaded:
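Analogously to the classifier-based detector above (again assuming the alibi_detect.saving utilities and a hypothetical directory):

```python
filepath = 'my_kernel_detector'  # hypothetical directory
save_detector(cd, filepath)
cd = load_detector(filepath)
```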
Finally, let's make some predictions with the detector:
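Reusing the make_predictions helper from the classifier section with the transposed data:

```python
make_predictions(cd, X_h0_pt, X_c_pt, corruption)
```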