Supervised drift detection on the penguins dataset
When true outputs/labels are available, we can perform supervised drift detection: monitoring the model's performance directly in order to check for harmful drift. Two detectors ideal for this application are the Fisher’s Exact Test (FET) detector and the Cramér-von Mises (CVM) detector.
The FET detector is designed for use on binary data, such as the instance level performance indicators of a classifier (i.e. 0/1 for each incorrect/correct classification). The CVM detector is designed for use on continuous data, such as a regressor's instance level loss or error scores.
In this example we will use the offline versions of these detectors, which are suitable for use on batches of data. In many cases data may arrive sequentially, and the user may wish to perform drift detection as the data arrives to ensure it is detected as soon as possible. In this case, the online versions of the FET and CVM detectors can be used, as will be explored in a future example.
The palmerpenguins dataset consists of data on 344 penguins from 3 islands in the Palmer Archipelago, Antarctica. There are 3 different species of penguin in the dataset, and a common task is to classify the species of each penguin based upon two features: the length and depth of the penguin's bill, or beak.
Artwork by Allison Horst
This notebook requires the `seaborn` package for visualization and the `palmerpenguins` package to load data. These can be installed via `pip`:
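For example, from the command line:

```shell
pip install seaborn palmerpenguins
```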
To download the dataset we use the palmerpenguins package:
The data consists of 333 rows (rows containing NaN values are removed), one for each penguin. There are 8 features describing the penguins' physical characteristics, their species and sex, the island each resides on, and the year measurements were taken.
For our first example use case, we will perform the popular species classification task. Here we wish to classify the `species` based only on `bill_length_mm` and `bill_depth_mm`. To start, we remove the other features and visualise those that remain.
The above plot shows that the Adélie species can primarily be identified by looking at bill length. Then to further distinguish between Gentoo and Chinstrap, we can look at the bill depth.
Next we separate the data into inputs and outputs, and encode the species data as integers. Finally, we split into three data sets: one to train the classifier, one to act as a reference set when testing for drift, and one to test for drift on.
For this dataset, a relatively shallow decision tree classifier should be sufficient, and so we train an `sklearn` one on the training data.
As expected, the decision tree is able to give acceptable classification accuracy on the train and test sets.
In order to demonstrate use of the drift detectors, we first need to add some artificial drift to the test data `X_test`. We add two types of drift here: to create covariate drift we subtract 5mm from the bill length of all the Gentoo penguins. $P(y|\mathbf{X})$ is unchanged here, but clearly we have introduced a change $\Delta P(\mathbf{X})$ in the input distribution. To create concept drift, we switch the labels of the Gentoo and Chinstrap penguins, so that the underlying process $P(y|\mathbf{X})$ is changed.
We now define a utility function to plot the classifier's decision boundaries, and we use this to visualise the reference data set, the test set, and the two new data sets where drift is present.
These plots serve as a visualisation of the differences between covariate drift and concept drift. Importantly, the model accuracies shown above also highlight the fact that not all drift is necessarily malicious, in the sense that even relatively significant drift does not always lead to degradation in a model's performance indicators. For example, the model actually gives a slightly higher accuracy on the covariate drift data set than on the no drift set in this case. Conversely, the concept drift unsurprisingly leads to severely degraded model performance.
Before getting to the main task in this example, monitoring malicious drift with a supervised drift detector, we will first use the MMD detector to check for covariate drift. To do this we initialise it in an unsupervised manner by passing it the input data `X_ref`.
Applying this detector on the no drift, covariate drift and concept drift data sets, we see that the detector only detects drift in the covariate drift case. Not detecting drift in the no drift case is desirable, but not detecting drift in the concept drift case is potentially problematic.
The fact that the unsupervised detector above does not detect the severe concept drift demonstrates the motivation for using supervised drift detectors that directly check for malicious drift, which can include malicious concept drift.
To perform supervised drift detection we first need to compute the model's performance indicators. Since this is a classification task, a suitable performance indicator is the instance level binary losses, which are computed below.
As seen above, these losses are binary data, where 0 represents an incorrect classification for each instance, and 1 represents a correct classification.
Since this is binary data, the FET detector is chosen, and initialised on the reference loss data. The `alternative` hypothesis is set to `less`, meaning we will only flag drift if the proportion of 1s to 0s is reduced compared to the reference data. In other words, we only flag drift if the model's performance has degraded.
Applying this detector to the same three data sets, we see that malicious drift isn't detected in the no drift or covariate drift cases, which is unsurprising since the model performance isn't degraded in these cases. However, with this supervised detector, we now detect the malicious concept drift as desired.
To provide a short example of supervised detection in a regression setting, we now rework the dataset into a regression task, and use the CVM detector on the model's squared error.
Warning: Must have scipy >= 1.7.0 installed for this example.
For a regression task, we take the penguins' flipper length and sex as inputs, and aim to predict the penguins' body mass. Looking at a scatter plot of these features, we can see there is substantial correlation between the chosen inputs and outputs.
Again, we split the dataset into the same three sets; a training set, reference set and test set.
This time we train a linear regressor on the training data, and find that it gives acceptable training and test accuracy.
To generate a copy of the test data with concept drift added, we use the model to create new output data, with a multiplicative factor and some Gaussian noise added. The quality of our synthetic output data is of course affected by the accuracy of the model, but it serves to demonstrate the behavior of the model (and detector) when $P(y|\mathbf{X})$ is changed.
Unsurprisingly, the concept drift leads to degradation in the model accuracy.
As in the classification example, in order to perform supervised drift detection we need to compute the model's performance indicators. For this regression example, the instance level squared errors are used.
The CVM detector is trained on the reference losses:
As desired, the CVM detector does not detect drift on the no drift data, but does on the concept drift data.