In this notebook we show how to detect drift on ECG data given a specific context using the context-aware MMD detector (Cobb and Van Looveren, 2022). Consider the following simple example: we have a heartbeat monitoring system which is trained on a wide variety of heartbeats sampled from people of all ages across a variety of activities (e.g. rest or running). Then we deploy the system to monitor individual people during certain activities. The distribution of the heartbeats monitored during deployment will then drift against the reference data, which resembles the full training distribution, simply because only individual people in a specific setting are being tracked. However, this does not mean that the system is not working and requires re-training. We are instead interested in flagging drift given the relevant context such as the person's characteristics (e.g. age or medical history) and the activity. Traditional drift detectors cannot flexibly deal with this setting since they rely on the i.i.d. assumption when sampling the reference and test sets. The context-aware detector however allows us to pass this context to the detector and flag drift appropriately. More generally, the context-aware drift detector detects changes in the data distribution which cannot be attributed to a permissible change in the context variable. On top of that, the detector allows you to understand which subpopulations are present in both the reference and test data, which provides deeper insights into the distribution underlying the test data.
Useful context (or conditioning) variables for the context-aware drift detector include but are not limited to:
Domain or application specific contexts such as the time of day or the activity (e.g. running or resting).
Conditioning on the relative prevalences of known subpopulations, such as the frequency of different types of heartbeats. It is important to note that while the relative frequency of each subpopulation (e.g. the different heartbeat types) might change, the distribution underlying each individual subpopulation (e.g. each specific type of heartbeat) cannot change.
Conditioning on model predictions. Assume we trained a classifier which detects arrhythmia, then we can provide the classifier model predictions as context and understand if, given the model prediction, the data comes from the same underlying distribution as the reference data or not.
Conditioning on model uncertainties which would allow increases in model uncertainty due to drift into familiar regions of high aleatoric uncertainty (often fine) to be distinguished from that into unfamiliar regions of high epistemic uncertainty (often problematic).
The following settings will be showcased throughout the notebook:
1. A change in the prevalences of subpopulations (i.e. different types of heartbeats as determined by an unsupervised clustering model or an ECG classifier) which are also present in the reference data is observed. Contrary to traditional drift detection approaches, the context-aware detector does not flag drift as this change in frequency of various heartbeats is permissible given the context provided.
2. A change in the distribution underlying one or more subpopulations takes place. While we allow changes in the prevalences of the subpopulations accounted for by the context variable, we do not allow changes of the subpopulations themselves. If for instance the ECGs are corrupted by noise on the sensor measurements, we want to flag drift.
We also show how to condition the detector on different context variables such as the ECG classifier model predictions, cluster membership by an unsupervised clustering algorithm and timestamps.
Under setting 1. we want our detector to be well-calibrated (a controlled False Positive Rate (FPR) and more generally a p-value which is uniformly distributed between 0 and 1) while under setting 2. we want our detector to be powerful and flag drift. Lastly, we show how the detector can help you to understand the connection between the reference and test data distributions better.
The dataset contains 5000 ECGs, originally obtained from Physionet from the BIDMC Congestive Heart Failure Database, record chf07. The data has been pre-processed in 2 steps: first each heartbeat is extracted, and then each beat is made equal length via interpolation. The data is labeled and contains 5 classes. The first class $N$, which contains almost 60% of the observations, is seen as normal while the others are supraventricular ectopic beats ($S$), ventricular ectopic beats ($V$), fusion beats ($F$) and unknown beats ($Q$).
The notebook requires the torch and statsmodels packages to be installed, which can be done via pip:
Before we start let's fix the random seeds for reproducibility:
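A minimal sketch of what fixing the seeds could look like. The helper name `set_seed` and the seed value are assumptions for illustration, not taken from the notebook; NumPy and PyTorch are only seeded if they happen to be installed.

```python
import random


def set_seed(seed: int) -> None:
    """Seed the random number generators that are available (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


set_seed(0)  # the seed value is an arbitrary choice for illustration
```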
First we load the data, show the distribution across the ECG classes and visualise some ECGs from each class.
We can see that most heartbeats can be classified as normal, followed by the unknown class. We will now sample 500 heartbeats to train a simple ECG classifier. Importantly, we leave out the $F$ and $V$ classes which are used to detect drift. First we define a helper function to sample data.
We use a prop_train fraction of all samples to train the classifier and then remove instances from the $F$ and $V$ classes. The rest of the data is used by our drift detectors.
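The split described above might be sketched as follows. The function name, the toy labels and the choice to simply discard the held-out classes from the training portion are assumptions made for this example; only `prop_train` and the held-out $F$ and $V$ classes come from the text.

```python
import random


def split_train_drift(labels, prop_train, held_out_classes, seed=0):
    """Return (train_idx, drift_idx) index lists (hypothetical helper)."""
    rng = random.Random(seed)
    idx = list(range(len(labels)))
    rng.shuffle(idx)
    n_train = int(prop_train * len(idx))
    # Held-out classes are dropped from the classifier's training portion only.
    train_idx = [i for i in idx[:n_train] if labels[i] not in held_out_classes]
    # Everything outside the training portion is left for the drift detectors.
    drift_idx = idx[n_train:]
    return train_idx, drift_idx


# Toy labels standing in for the 5 ECG classes.
labels = ["N"] * 60 + ["S"] * 10 + ["V"] * 10 + ["F"] * 10 + ["Q"] * 10
train_idx, drift_idx = split_train_drift(
    labels, prop_train=0.5, held_out_classes={"F", "V"}
)
```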
Now we define and train our classifier on the training set.
Let's evaluate our classifier on both the training and drift portions of the dataset.
We start with an example where no drift occurs and the reference and test data are both sampled randomly from all classes present in the reference data (classes 0, 1 and 3). Under this scenario, we expect no drift to be detected by either a normal MMD detector or by the context-aware MMD detector.
Before we can start using the context-aware drift detector, first we need to define our context variable. In our experiments we allow the relative prevalences of subpopulations (i.e. the relative frequency of different types of heartbeats also present in the reference data) to vary while the distributions underlying each of the subpopulations remain unchanged. To achieve this we condition on the prediction probabilities of the classifier we trained earlier to distinguish the different types of ECGs. We can do this because the prediction probabilities can account for the frequency of occurrence of each of the heartbeat types (be it imperfectly given our classifier makes the occasional mistake).
The figure below shows Q-Q (Quantile-Quantile) plots of a random sample from the uniform distribution U[0,1] against the p-values obtained from the vanilla and context-aware MMD detectors, and illustrates how well both detectors are calibrated. A perfectly calibrated detector should have a Q-Q plot which closely follows the diagonal. Only the middle plot in the grid shows the detector's p-values. The other plots correspond to n_runs p-values actually sampled from U[0,1] to contextualise how well the central plot follows the diagonal given the limited number of samples.
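As a rough illustration of what such a Q-Q comparison measures, the sketch below pairs sorted p-values with the quantiles of U[0,1] and reports the largest distance from the diagonal. The function names and the synthetic p-values are assumptions for this example; the notebook's own plotting code is not reproduced here.

```python
def qq_points(p_vals):
    """Pair each sorted p-value with the matching quantile of U[0, 1]."""
    n = len(p_vals)
    theoretical = [(i + 0.5) / n for i in range(n)]
    return list(zip(theoretical, sorted(p_vals)))


def max_deviation(p_vals):
    """Largest vertical distance between the Q-Q points and the diagonal."""
    return max(abs(q - p) for q, p in qq_points(p_vals))


calibrated = [(i + 0.5) / 100 for i in range(100)]  # synthetic, ideal p-values
too_small = [p ** 3 for p in calibrated]            # synthetic, skewed towards 0

dev_calibrated = max_deviation(calibrated)  # lies on the diagonal
dev_too_small = max_deviation(too_small)    # clearly off the diagonal
```

A small deviation suggests a calibrated detector; with a limited number of runs some deviation is expected even for truly uniform p-values, which is why the grid of reference plots is useful.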
As expected we can see that both the normal MMD and the context-aware MMD detectors are well-calibrated.
We now focus our attention on a more realistic problem where the relative frequency of one or more subpopulations (i.e. types of heartbeats) is changing while the underlying subpopulation distribution stays the same. This would be the expected setting when we monitor the heartbeat of a specific person (e.g. only normal heartbeats) and we don't want to flag drift.
While the usual MMD detector only returns very low p-values (mostly 0), the context-aware MMD detector remains calibrated.
In the following example we change the distribution of one or more of the underlying subpopulations (i.e. the different types of heartbeats). Notice that now we do want to flag drift since our context variable, which permits changes in relative subpopulation prevalences, can no longer explain the change in distribution.
We will again sample from the normal heartbeats, but now we will add random noise to a fraction of the extracted heartbeats to change the distribution. This could be the result of an error with some of the sensors. The perturbation is illustrated below:
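The perturbation can be sketched as below. The helper name, the noise level `sigma`, the fraction `frac` and the all-zero placeholder beats are assumptions for this example and not the notebook's actual settings.

```python
import random


def add_noise(beats, frac, sigma, seed=0):
    """Add Gaussian noise to a random fraction of the beats (hypothetical helper)."""
    rng = random.Random(seed)
    n_noisy = round(frac * len(beats))
    noisy_idx = set(rng.sample(range(len(beats)), n_noisy))
    out = []
    for i, beat in enumerate(beats):
        if i in noisy_idx:
            # Perturb every data point of the selected beat.
            out.append([v + rng.gauss(0.0, sigma) for v in beat])
        else:
            out.append(list(beat))  # untouched copy
    return out, sorted(noisy_idx)


beats = [[0.0] * 140 for _ in range(10)]  # placeholder beats, 140 points each
noisy, noisy_idx = add_noise(beats, frac=0.2, sigma=0.5)
```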
As we can see from the Q-Q plot and the power of the detector, the changes in the subpopulation are easily detected:
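Both the false positive rate and the power can be estimated as the fraction of runs in which the p-value falls below the significance level. A minimal sketch, assuming a significance level of 0.05 and synthetic p-values chosen purely for illustration:

```python
def rejection_rate(p_vals, alpha=0.05):
    """Fraction of runs in which the detector flags drift at level alpha."""
    return sum(p < alpha for p in p_vals) / len(p_vals)


# Synthetic p-values, for illustration only.
p_no_drift = [(i + 0.5) / 200 for i in range(200)]  # evenly spread over [0, 1]
p_drift = [0.001] * 95 + [0.3] * 5                  # mostly tiny p-values

fpr = rejection_rate(p_no_drift)  # estimated false positive rate
power = rejection_rate(p_drift)   # estimated power
```

Computed on runs without drift, the rejection rate estimates the FPR and should be close to the significance level; computed on runs with drift, it estimates the power.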
We now use the cluster membership probabilities of a Gaussian mixture model which is fit on the training instances as context variables instead of the model predictions. We will test both the calibration when the frequency of the subpopulations (the cluster memberships) changes as well as the power when the $F$ and $V$ heartbeats are included.
The test statistic $\hat{t}$ of the context-aware MMD detector can be formulated as follows: $\hat{t} = \langle K_{0,0}, W_{0,0} \rangle + \langle K_{1,1}, W_{1,1} \rangle -2\langle K_{0,1}, W_{0,1}\rangle$ where $0$ refers to the reference data, $1$ to the test data, and $W_{.,.}$ and $K_{.,.}$ are the weight and kernel matrices, respectively. The weight matrices $W_{.,.}$ allow us to focus on the distribution's subpopulations of interest. Reference instances which have similar contexts as the test data will have higher values for their entries in $W_{0,1}$ than instances with dissimilar contexts. We can therefore interpret $W_{0,1}$ as the coupling matrix between instances in the reference and the test sets. This allows us to investigate which subpopulations from the reference set are present and which are missing in the test data. If we also have a good understanding of the model performance on various subpopulations of the reference data, we could even try and use this coupling matrix to roughly proxy model performance on the unlabeled test instances. Note that in this case we would require labels from the reference data and make sure the reference instances come from the validation, not the training set.
In the following example we only pick 1 type of heartbeat (the normal one) to be present in the test set while 3 types are present in the reference set. We can then investigate via the coupling matrix whether the test statistic $\hat{t}$ focused on the right types of heartbeats in the reference data via $W_{0,1}$. More concretely, we can sum over the columns (the test instances) of $W_{0,1}$ and check which reference instances obtained the highest weights.
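Summing $W_{0,1}$ over the test instances and ranking the reference instances might look like this. The coupling matrix is hand-written for illustration (rows are reference instances, columns are test instances); how the matrix is obtained from the detector is not shown.

```python
def reference_weights(coupling):
    """Sum each row of W_{0,1} (rows: reference, columns: test instances)."""
    return [sum(row) for row in coupling]


# Hand-written coupling matrix with 4 reference rows and 2 test columns.
coupling = [
    [0.40, 0.35],
    [0.02, 0.01],
    [0.38, 0.42],
    [0.00, 0.01],
]
w_ref = reference_weights(coupling)
# Reference indices sorted from highest to lowest total weight.
ranking = sorted(range(len(w_ref)), key=lambda i: w_ref[i], reverse=True)
```

Reference instances at the top of the ranking are the ones the test statistic focuses on.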
As expected no drift was detected since the test set only contains normal heartbeats. We now sort the weights of w_ref in descending order. We expect the top 400 entries to be fairly high and consistent since these represent the normal heartbeats in the reference set. Afterwards, the weight attribution to the other instances in the reference set should be low. The plot below confirms that this is indeed what happens.
The dataset consists of nicely extracted and aligned ECGs of 140 data points for each observation. However in reality it is likely that we will continuously or periodically observe instances which are not nicely aligned. We could however assign a timestamp to the data (e.g. starting from a peak) and use time as the context variable. This is illustrated in the example below.
First we create a new dataset where we split each instance into non-overlapping ECG segments. Each segment has an associated timestamp as its context variable. Then we can check the calibration under no change (besides the time-varying behaviour which is accounted for) as well as the power for ECG segments where we assign incorrect timestamps to some of the segments.
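The slicing step can be sketched as follows. The segment length of 35 and the use of the start index as timestamp are assumptions for this example; only the 140 data points per ECG come from the text.

```python
def slice_with_timestamps(signal, segment_len):
    """Split a signal into non-overlapping segments and record each start index."""
    n_segments = len(signal) // segment_len  # a trailing remainder is dropped
    segments, timestamps = [], []
    for k in range(n_segments):
        start = k * segment_len
        segments.append(signal[start:start + segment_len])
        timestamps.append(start)  # the start index acts as the context variable
    return segments, timestamps


ecg = list(range(140))  # placeholder for one ECG of 140 data points
segments, timestamps = slice_with_timestamps(ecg, segment_len=35)
```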
In this notebook we show how to detect drift on text data given a specific context using the context-aware MMD detector (Cobb and Van Looveren, 2022). Consider the following simple example: the upcoming elections result in an increase of political news articles compared to other topics such as sports or science. Given the context (the elections), it is however not surprising that we observe this uptick. Moreover, assume we have a machine learning model which is trained to classify news topics, and this model performs well on political articles. So given that we fully expect this uptick to occur given the context, and that our model performs fine on the political news articles, we do not want to flag this type of drift in the data. This setting corresponds more closely to many real-life settings than traditional drift detection where we make the assumption that both the reference and test data are i.i.d. samples from their underlying distributions.
In our news topics example, each different topic such as politics, sports or weather represents a subpopulation of the data. Our context-aware drift detector can then detect changes in the data distribution which cannot be attributed to a change in the relative prevalences of these subpopulations, which we deem permissible. As a cherry on the cake, the context-aware detector allows you to understand which subpopulations are present in both the reference and test data. This allows you to obtain deep insights into the distribution underlying the test data.
Useful context (or conditioning) variables for the context-aware drift detector include but are not limited to:
Domain or application specific contexts such as the time of day or the weather.
Conditioning on the relative prevalences of known subpopulations, such as the frequency of political articles. It is important to note that while the relative frequency of each subpopulation might change, the distribution underlying each subpopulation cannot change.
Conditioning on model predictions. Assume we trained a classifier which tries to figure out which news topic an article belongs to. Given our model predictions we then want to understand whether our test data follows the same underlying distribution as reference instances with similar model predictions. This conditioning would also be useful in case of trending news topics which cause the model prediction distribution to shift but not necessarily the distribution within each of the news topics.
Conditioning on model uncertainties which would allow increases in model uncertainty due to drift into familiar regions of high aleatoric uncertainty (often fine) to be distinguished from that into unfamiliar regions of high epistemic uncertainty (often problematic).
The following settings will be illustrated throughout the notebook:
1. A change in the prevalences of subpopulations (i.e. news topics) relative to their prevalences in the training data. Contrary to traditional drift detection approaches, the context-aware detector does not flag drift as this change in frequency of news topics is permissible given the context provided (e.g. more political news articles around elections).
2. A change in the underlying distribution of one or more subpopulations takes place. While we allow changes in the prevalence of the subpopulations accounted for by the context variable, we do not allow changes of the subpopulations themselves. Let's assume that a newspaper usually has a certain tone (e.g. more conservative) when it comes to politics. If this tone changes (to less conservative) around elections (increased frequency of political news articles), then we want to flag it as drift since the change cannot be attributed to the context given to the detector.
3. A change in the distribution as we observe a previously unseen news topic. A newspaper might for instance add a classified ads section, which was not present in the reference data.
Under setting 1. we want our detector to be well-calibrated (a controlled False Positive Rate (FPR) and more generally a p-value which is uniformly distributed between 0 and 1) while under settings 2. and 3. we want our detector to be powerful and flag the drift. Lastly, we show how the detector can help you to understand the connection between the reference and test data distributions better.
We use the 20 newsgroups dataset which contains about 18,000 newsgroup posts across 20 topics, including politics, science, sports and religion.
The notebook requires the umap-learn, torch, sentence-transformers, statsmodels, seaborn and datasets packages to be installed, which can be done via pip:
Before we start let's fix the random seeds for reproducibility:
First we load the data, show which classes (news topics) are present and what an instance looks like.
Let's take a look at an instance from the dataset:
We define respectively a generic clustering model using UMAP, a model to embed the text input using pre-trained SentenceTransformers embeddings, a text classifier and a utility function to place the data on the right device.
First we train a classifier on a small subset of the data. The aim of the classifier is to predict the news topic of each instance. Below we define a few simple training and evaluation functions.
We now split the data into 2 sets. The first set (x_train) we will use to train our text classifier, and the second set (x_drift) is held out to test our drift detector on.
Let's train our classifier. The classifier consists of a simple MLP head on top of a pre-trained SentenceTransformer model as the backbone. The SentenceTransformer remains frozen during training and only the MLP head is finetuned.
We start with an example where no drift occurs and the reference and test data are both sampled randomly from all news topics. Under this scenario, we expect no drift to be detected by either a normal MMD detector or by the context-aware MMD detector.
First we define some helper functions. The first one visualises the clustered text data while the second function samples disjoint reference and test sets with a specified number of instances per class (i.e. per news topic).
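The second helper could be sketched along these lines. The function name, its signature and the toy labels are assumptions for this example; it works on indices so that it stays independent of how the text itself is stored.

```python
import random


def sample_disjoint(labels, n_ref, n_test, seed=0):
    """Return disjoint (ref_idx, test_idx) with the requested per-class counts."""
    rng = random.Random(seed)
    ref_idx, test_idx = [], []
    for c in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == c]
        rng.shuffle(idx)
        k_ref, k_test = n_ref.get(c, 0), n_test.get(c, 0)
        if k_ref + k_test > len(idx):
            raise ValueError(f"not enough instances of class {c!r}")
        # Consecutive, non-overlapping slices keep the two sets disjoint.
        ref_idx += idx[:k_ref]
        test_idx += idx[k_ref:k_ref + k_test]
    return ref_idx, test_idx


labels = ["hockey"] * 50 + ["baseball"] * 50 + ["mac"] * 50  # toy topics
ref_idx, test_idx = sample_disjoint(
    labels,
    n_ref={"hockey": 10, "baseball": 10, "mac": 10},
    n_test={"hockey": 25, "baseball": 25, "mac": 2},  # changed prevalences
)
```

Passing different per-class counts for the test set is one way to change the relative prevalences of the topics without touching the distribution within each topic.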
We first define the embedding model using the pre-trained SentenceTransformer embeddings and then embed both the reference and test sets.
By applying UMAP clustering on the SentenceTransformer embeddings, we can visually inspect the various news topic clusters. Note that we fit the clustering model on the held out data first, and then make predictions on the reference and test sets.
We can visually see that the reference and test set are made up of similar clusters of data, grouped by news topic. As a result, we would not expect drift to be flagged. If the data distribution did not change, we can expect the p-value distribution of our statistical test to be uniformly distributed between 0 and 1. So let's see if this assumption holds.
Importantly, first we need to define our context variable for the context-aware MMD detector. In our experiments we allow the relative prevalences of subpopulations to vary while the distributions underlying each of the subpopulations remain unchanged. To achieve this we condition on the prediction probabilities of the classifier we trained earlier to distinguish each of the 20 different news topics. We can do this because the prediction probabilities can account for the frequency of occurrence of each of the topics (be it imperfectly given our classifier makes the occasional mistake).
Before we set off our experiments, we embed all the instances in x_drift and compute all contexts c_drift so we don't have to call our transformer model every single pass in the for loop.
As expected we can see that both the normal MMD and the context-aware MMD detectors are well-calibrated.
We now focus our attention on a more realistic problem where the relative frequency of one or more subpopulations (i.e. news topics) is changing in a way which can be attributed to external events. Importantly, the distribution underlying each subpopulation (e.g. the distribution of hockey news itself) remains unchanged, only its frequency changes.
In our example we assume that the World Series and Stanley Cup coincide on the calendar leading to a spike in news articles on respectively baseball and hockey. Furthermore, there is not too much news on Mac or Windows since there are no new releases or products planned anytime soon.
While the context-aware detector remains well calibrated, the MMD detector consistently flags drift (low p-values). Note that this is the expected behaviour since the vanilla MMD detector cannot take any external context into account and correctly detects that the reference and test data do not follow the same underlying distribution.
We can also easily see this on the plot below where the p-values of the context-aware detector are uniformly distributed while the MMD detector's p-values are consistently close to 0. Note that we limited the y-axis range to make the plot easier to read.
In the following example we change the distribution of one or more of the underlying subpopulations. Notice that now we do want to flag drift since our context variable, which permits changes in relative subpopulation prevalences, can no longer explain the change in distribution.
Imagine our news topic classification model is not as granular as before and instead of the 20 categories only predicts the 6 super classes, organised by subject matter:
Computers: comp.graphics; comp.os.ms-windows.misc; comp.sys.ibm.pc.hardware; comp.sys.mac.hardware; comp.windows.x
Recreation: rec.autos; rec.motorcycles; rec.sport.baseball; rec.sport.hockey
Science: sci.crypt; sci.electronics; sci.med; sci.space
Miscellaneous: misc.forsale
Politics: talk.politics.misc; talk.politics.guns; talk.politics.mideast
Religion: talk.religion.misc; alt.atheism; soc.religion.christian
What if baseball and hockey become less popular and the distribution underlying the Recreation class changes? We will want to detect this as the change in distributions of the subpopulations (the 6 super classes) cannot be explained anymore by the context variable.
In order to reuse our pretrained classifier for the super classes, we add the following helper function to map the predictions on the super classes and return one-hot encoded predictions over the 6 super classes. Note that our context variable now changes from a probability distribution over the 20 news topics to a one-hot encoded representation over the 6 super classes.
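A sketch of such a mapping is shown below. The function names are hypothetical and the mapping works on predicted topic names rather than on the classifier's raw outputs, which is a simplification for this example.

```python
SUPER_CLASSES = [
    "Computers", "Recreation", "Science", "Miscellaneous", "Politics", "Religion",
]


def super_class_index(topic):
    """Map a newsgroup name to the index of its super class."""
    if topic.startswith("comp."):
        return 0
    if topic.startswith("rec."):
        return 1
    if topic.startswith("sci."):
        return 2
    if topic.startswith("misc."):
        return 3
    if topic.startswith("talk.politics."):
        return 4
    if "religion" in topic or "atheism" in topic:
        return 5
    raise ValueError(f"unknown topic: {topic}")


def one_hot_context(predicted_topics):
    """One-hot encode the super class of each predicted topic."""
    context = []
    for topic in predicted_topics:
        row = [0.0] * len(SUPER_CLASSES)
        row[super_class_index(topic)] = 1.0
        context.append(row)
    return context


context = one_hot_context(["rec.sport.hockey", "sci.space", "talk.politics.guns"])
```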
We can see that the context-aware detector has high power to detect changes in the distributions of the subpopulations.
Next we illustrate the effectiveness of the context-aware detector to detect new topics which are not present in the reference data. Obviously we also want to flag drift in this case. As an example we introduce movie reviews in the test data.
So far we have conditioned the context-aware detector on the model predictions. There are however many other useful contexts possible. One such example would be to condition on the predictions of an unsupervised clustering algorithm. To facilitate this, we first apply kernel PCA on the embedding vectors, followed by a Gaussian mixture model which clusters the data into 6 classes (same as the super classes). We will test both the calibration under the null hypothesis (no distribution change) as well as the power when a new topic (movie reviews) is injected.
Next we change the number of instances in each cluster between the reference and test sets. Note that we do not alter the underlying distribution of each of the clusters, just the frequency.
Now we run the experiment and show the context-aware detector's calibration when changing the cluster frequencies. We also show how the usual MMD detector will consistently flag drift. Furthermore, we inject instances from the movie reviews dataset and illustrate that the context-aware detector remains powerful when the underlying cluster distribution changes (by including a previously unseen topic).
The test statistic $\hat{t}$ of the context-aware MMD detector can be formulated as follows: $\hat{t} = \langle K_{0,0}, W_{0,0} \rangle + \langle K_{1,1}, W_{1,1} \rangle -2\langle K_{0,1}, W_{0,1}\rangle$ where $0$ refers to the reference data, $1$ to the test data, and $W_{.,.}$ and $K_{.,.}$ are the weight and kernel matrices, respectively. The weight matrices $W_{.,.}$ allow us to focus on the distribution's subpopulations of interest. Reference instances which have similar contexts as the test data will have higher values for their entries in $W_{0,1}$ than instances with dissimilar contexts. We can therefore interpret $W_{0,1}$ as the coupling matrix between instances in the reference and the test sets. This allows us to investigate which subpopulations from the reference set are present and which are missing in the test data. If we also have a good understanding of the model performance on various subpopulations of the reference data, we could even try and use this coupling matrix to roughly proxy model performance on the unlabeled test instances. Note that in this case we would require labels from the reference data and make sure the reference instances come from the validation, not the training set.
In the following example we only pick 2 classes to be present in the test set while all 20 are present in the reference set. We can then investigate via the coupling matrix whether the test statistic $\hat{t}$ focused on the right classes in the reference data via $W_{0,1}$. More concretely, we can sum over the columns (the test instances) of $W_{0,1}$ and check which reference instances obtained the highest weights.
We embed the news posts using pre-trained embeddings and optionally add a dimensionality reduction step with UMAP. UMAP can also leverage reference data labels.
The figure below shows Q-Q (Quantile-Quantile) plots of a random sample from the uniform distribution U[0,1] against the p-values obtained from the vanilla and context-aware MMD detectors, and illustrates how well both detectors are calibrated. A perfectly calibrated detector should have a Q-Q plot which closely follows the diagonal. Only the middle plot in the grid shows the detector's p-values. The other plots correspond to n_runs p-values actually sampled from U[0,1] to contextualise how well the central plot follows the diagonal given the limited number of samples.
The drift detector applies feature-wise two-sample Kolmogorov-Smirnov (K-S) tests. For multivariate data, the obtained p-values for each feature are aggregated either via the Bonferroni or the False Discovery Rate (FDR) correction. The Bonferroni correction is more conservative and controls for the probability of at least one false positive. The FDR correction on the other hand allows for an expected fraction of false positives to occur.
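The difference between the two corrections can be illustrated with a small sketch. This is a textbook version (Bonferroni and the Benjamini-Hochberg procedure) applied to made-up p-values, not the library's implementation; the function names are hypothetical.

```python
def bonferroni_drift(p_vals, alpha=0.05):
    """Flag drift if any feature's p-value is below alpha / n_features."""
    return min(p_vals) < alpha / len(p_vals)


def fdr_drift(p_vals, q=0.05):
    """Benjamini-Hochberg: flag drift if any sorted p_(i) <= q * i / n_features."""
    n = len(p_vals)
    return any(p <= q * (i + 1) / n for i, p in enumerate(sorted(p_vals)))


p_vals = [0.02, 0.02, 0.03, 0.9]  # made-up feature-wise p-values

drift_bonferroni = bonferroni_drift(p_vals)  # 0.02 is not below 0.05 / 4
drift_fdr = fdr_drift(p_vals)                # second smallest: 0.02 <= 0.05 * 2 / 4
```

On this example the FDR correction flags drift while the more conservative Bonferroni correction does not.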
For high-dimensional data, we typically want to reduce the dimensionality before computing the feature-wise univariate K-S tests and aggregating those via the chosen correction method. Following suggestions in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift, we incorporate Untrained AutoEncoders (UAE) and black-box shift detection using the classifier's softmax outputs (BBSDs) as out-of-the box preprocessing methods and note that PCA can also be easily implemented using scikit-learn. Preprocessing methods which do not rely on the classifier will usually pick up drift in the input data, while BBSDs focuses on label shift. The adversarial detector which is part of the library can also be transformed into a drift detector picking up drift that reduces the performance of the classification model. We can therefore combine different preprocessing techniques to figure out if there is drift which hurts the model performance, and whether this drift can be classified as input drift or label shift.
The method works with both the PyTorch and TensorFlow frameworks for the optional preprocessing step. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for installation instructions.
CIFAR-10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.
Original CIFAR-10 data:
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
Let's pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.
We split the original test set in a reference dataset and a dataset which should not be rejected under the H0 of the K-S test. We also split the corrupted data by corruption type:
We can visualise the same instance for each corruption type:
We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:
Given the drop in performance, it is important that we detect the harmful data drift!
First we try a drift detector using the TensorFlow framework for the preprocessing step. We are trying to detect data drift on high-dimensional (32x32x3) data using feature-wise univariate tests. It therefore makes sense to apply dimensionality reduction first. Some dimensionality reduction methods also used in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift are readily available: a randomly initialized encoder (UAE or Untrained AutoEncoder in the paper), BBSDs (black-box shift detection using the classifier's softmax outputs) and PCA.
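For intuition on what a feature-wise univariate test computes, here is a sketch of the two-sample K-S statistic applied to each feature (column) separately. It covers only the statistic, not the p-value, and is not the library's implementation; the function names and toy data are assumptions.

```python
from bisect import bisect_right


def ks_statistic(a, b):
    """Two-sample K-S statistic: largest gap between the two empirical CDFs."""
    a, b = sorted(a), sorted(b)
    d = 0.0
    for v in a + b:
        # Empirical CDF of each sample evaluated at v.
        gap = abs(bisect_right(a, v) / len(a) - bisect_right(b, v) / len(b))
        d = max(d, gap)
    return d


def featurewise_ks(x_ref, x_test):
    """K-S statistic for every feature (column) of two lists of rows."""
    return [
        ks_statistic(ref_col, test_col)
        for ref_col, test_col in zip(zip(*x_ref), zip(*x_test))
    ]


x_ref = [[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]
x_test = [[0.5, 11.0], [1.5, 12.0], [2.5, 13.0]]  # second feature is shifted
stats = featurewise_ks(x_ref, x_test)
```

The second feature, whose test values lie entirely outside the reference range, reaches the maximum statistic of 1.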
Random encoder
First we try the randomly initialized encoder:
The threshold applied by the detector to the feature-wise p-values for multivariate data with encoding_dim features is equal to p_val / encoding_dim because of the Bonferroni correction.
Let's check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
As expected, drift was only detected on the corrupted datasets. The feature-wise p-values of the univariate K-S tests on each (encoded) feature, before the multivariate correction, are mostly well above the $0.05$ threshold for the held-out original test set (H0) and below it for the corrupted datasets.
BBSDs
For BBSDs, we use the classifier's softmax outputs for black-box shift detection. This method is based on Detecting and Correcting for Label Shift with Black Box Predictors. The ResNet classifier is trained on data standardised by instance so we need to rescale the data.
Now we initialize the detector. Here we use the output of the softmax layer to detect the drift, but other hidden layers can be extracted as well by setting 'layer' to the index of the desired hidden layer in the model:
Again we can see that the threshold applied by the detector to the feature-wise p-values for multivariate data with 10 features (number of CIFAR-10 classes) is equal to p_val / 10 because of the Bonferroni correction.
There is no drift on the original held out test set:
We can also check what happens when we introduce class imbalances between the reference data X_ref and the tested data X_imb. The reference data will use $75$% of the instances of the first 5 classes and only $25$% of the last 5. The data used for drift testing then uses respectively $25$% and $75$% of the test instances for the first and last 5 classes.
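The imbalanced split can be sketched as below. The helper name and the toy labels (8 instances per class) are assumptions for this example; the 75%/25% fractions follow the description above.

```python
def imbalanced_split(labels, n_classes=10, frac_first=0.75):
    """Split indices into (ref_idx, imb_idx) with opposite class imbalances."""
    ref_idx, imb_idx = [], []
    for c in range(n_classes):
        idx = [i for i, y in enumerate(labels) if y == c]
        # First half of the classes: frac_first goes to the reference set;
        # second half: the complementary fraction.
        frac = frac_first if c < n_classes // 2 else 1.0 - frac_first
        n_ref = int(frac * len(idx))
        ref_idx += idx[:n_ref]
        imb_idx += idx[n_ref:]
    return ref_idx, imb_idx


labels = [c for c in range(10) for _ in range(8)]  # 8 toy instances per class
ref_idx, imb_idx = imbalanced_split(labels)
```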
Update the reference dataset for the detector and make predictions. Note that we store the preprocessed reference data since the preprocess_at_init kwarg is by default True:
So far we have kept the reference data the same throughout the experiments. It is possible however that we want to test a new batch against the last N instances or against a batch of instances of fixed size where we give each instance we have seen up until now the same chance of being in the reference batch (reservoir sampling). The update_x_ref argument allows you to change the reference data update rule. It is a Dict which takes as key the update rule ('last' for last N instances or 'reservoir_sampling') and as value the batch size N of the reference data. You can also save the detector after the prediction calls to save the updated reference data.
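The two update rules can be illustrated with a plain-Python sketch. This is the textbook reservoir sampling scheme (Algorithm R) and a simple sliding window, written under the assumption that they capture the idea behind the 'reservoir_sampling' and 'last' rules; it is not the library's implementation, and the function names are hypothetical. Following the description above, the argument itself would have the form {'reservoir_sampling': N} or {'last': N}.

```python
import random


def update_last(x_ref, batch, n):
    """'last' rule: keep the n most recent instances."""
    return (x_ref + batch)[-n:]


def update_reservoir(x_ref, batch, n, n_seen, rng):
    """Reservoir sampling (Algorithm R): every instance seen so far has the
    same chance of being in the size-n reference set."""
    x_ref = list(x_ref)
    for item in batch:
        n_seen += 1
        if len(x_ref) < n:
            x_ref.append(item)  # reservoir not full yet
        else:
            j = rng.randrange(n_seen)  # uniform in [0, n_seen)
            if j < n:
                x_ref[j] = item  # replace a random slot
    return x_ref, n_seen


rng = random.Random(0)
x_ref, n_seen = list(range(5)), 5  # initial reference set of size N = 5
x_ref, n_seen = update_reservoir(
    x_ref, list(range(5, 100)), n=5, n_seen=n_seen, rng=rng
)
```

The reference set keeps its size N while the counter of seen instances grows, which is what gives older and newer instances the same inclusion probability.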
The reference data is now updated with each predict call. Say we start with our imbalanced reference set and make a prediction on the remaining test set data X_imb, then the drift detector will figure out data drift has occurred.
We can now see that the reference data consists of N instances, obtained through reservoir sampling.
We then draw a random sample from the training set and compare it with the updated reference data. This still highlights that there is data drift but will update the reference data again:
When we draw a new sample from the training set, it highlights that it is not drifting anymore against the reservoir in X_ref.
Instead of the Bonferroni correction for multivariate data, we can also use the less conservative False Discovery Rate (FDR) correction. See here or here for nice explanations. While the Bonferroni correction controls the probability of at least one false positive, the FDR correction controls the expected proportion of false positives among the rejected hypotheses. The p_val argument at initialisation time can be interpreted as the acceptable q-value when the FDR correction is applied.
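The difference between the two corrections can be illustrated with a small sketch. The functions below are hypothetical helpers written for this explanation (using the Benjamini-Hochberg procedure for the FDR case), and the feature-wise p-values are made up.

```python
def bonferroni_drift(p_vals, p_val=0.05):
    """Flag drift if any feature-wise p-value is below p_val / n_features."""
    threshold = p_val / len(p_vals)
    return min(p_vals) < threshold, threshold


def fdr_drift(p_vals, q_val=0.05):
    """Benjamini-Hochberg: flag drift if any sorted p-value p_(i) <= q_val * i / n."""
    n = len(p_vals)
    ranked = sorted(p_vals)
    passing = [p for i, p in enumerate(ranked, start=1) if p <= q_val * i / n]
    threshold = max(passing) if passing else None
    return bool(passing), threshold


p_vals = [0.012, 0.015, 0.020, 0.400, 0.700]  # made-up feature-wise p-values
print(bonferroni_drift(p_vals))  # no p-value is below 0.05 / 5 = 0.01
print(fdr_drift(p_vals))         # the second and third smallest p-values pass
```

On this made-up example the Bonferroni correction does not flag drift while the FDR correction does, which is what "less conservative" means in practice.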
We can leverage the adversarial scores obtained from an adversarial autoencoder trained on normal data and transform it into a data drift detector. The score function of the adversarial autoencoder becomes the preprocessing function for the drift detector. The K-S test is then a simple univariate test on the adversarial scores. Importantly, an adversarial drift detector flags malicious data drift. We can fetch the pretrained adversarial detector from a Google Cloud Bucket or train one from scratch:
Initialise the drift detector:
Make drift predictions on the original test set and corrupted data:
While X_imb clearly exhibits input data drift due to the introduced class imbalances, it is not flagged by the adversarial drift detector since the performance of the classifier is not affected and the drift is not malicious. We can visualise this by plotting the adversarial scores together with the harmfulness of the data corruption as reflected by the drop in classifier accuracy:
We can therefore use the scores of the detector itself to quantify the harmfulness of the drift! We can generalise this to all the corruptions at each severity level in CIFAR-10-C:
We now compute mean scores and standard deviations per severity level and plot the results. The plot shows the mean adversarial scores (lhs) and ResNet-32 accuracies (rhs) for increasing data corruption severity levels. Level 0 corresponds to the original test set. Harmful scores are scores from instances which have been flipped from the correct to an incorrect prediction because of the corruption. Not harmful means that the prediction was unchanged after the corruption.
Model distillation is a technique that is used to transfer knowledge from a large network to a smaller network. Typically, it consists of training a second model with a simplified architecture on soft targets (the output distributions or the logits) obtained from the original model.
Here, we apply model distillation to obtain harmfulness scores, by comparing the output distributions of the original model with the output distributions of the distilled model, in order to detect adversarial data, malicious data drift or data corruption. We use the following definition of harmful and harmless data points:
Harmful data points are defined as inputs for which the model's predictions on the uncorrupted data are correct while the model's predictions on the corrupted data are wrong.
Harmless data points are defined as inputs for which the model's predictions on the uncorrupted data are correct and the model's predictions on the corrupted data remain correct.
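These two definitions translate directly into code. The following is a minimal sketch with a hypothetical helper `split_harmful` and toy labels; instances that the model already misclassifies on the uncorrupted data fall into neither group.

```python
def split_harmful(y_true, preds_clean, preds_corrupted):
    """Return the indices of harmful and harmless instances.

    Only instances predicted correctly on the uncorrupted data are considered:
    harmful  -> prediction becomes wrong after corruption
    harmless -> prediction stays correct after corruption
    """
    harmful, harmless = [], []
    for i, (y, p_clean, p_corr) in enumerate(zip(y_true, preds_clean, preds_corrupted)):
        if p_clean != y:
            continue  # wrong before corruption: neither harmful nor harmless
        if p_corr == y:
            harmless.append(i)
        else:
            harmful.append(i)
    return harmful, harmless


y_true = [0, 1, 2, 3]           # toy labels
preds_clean = [0, 1, 2, 0]      # last instance is already misclassified
preds_corrupted = [0, 2, 2, 3]  # second instance is flipped by the corruption
harmful, harmless = split_harmful(y_true, preds_clean, preds_corrupted)
print(harmful, harmless)
```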
Analogously to the adversarial AE detector, which is also part of the library, the model distillation detector picks up drift that reduces the performance of the classification model.
Moreover, in this example a drift detector that applies two-sample Kolmogorov-Smirnov (K-S) tests to the scores is employed. The p-values obtained are used to assess the harmfulness of the data.
CIFAR10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance.
Original CIFAR-10 data:
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
Let's pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.
We split the corrupted data by corruption type:
We can visualise the same instance for each corruption type:
We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:
Analogously to the adversarial AE detector, which uses an autoencoder to reproduce the output distribution of a classifier and produce adversarial scores, the model distillation detector achieves the same goal by using a simple classifier in place of the autoencoder. This approach is more flexible since it bypasses the instance's generation step, and it can be applied in a straightforward way to a variety of data sets such as text or time series.
We can use the adversarial scores produced by the Model Distillation detector in the context of drift detection. The score function of the detector becomes the preprocessing function for the drift detector. The K-S test is then a simple univariate test between the adversarial scores of the reference batch and the test data. Higher adversarial scores indicate more harmful drift. Importantly, a harmfulness detector flags malicious data drift. We can fetch the pretrained model distillation detector from a Google Cloud Bucket or train one from scratch:
Definition and training of the distilled model
Scores and p-values calculation
Here we initialize the K-S drift detector using the harmfulness scores as a preprocessing function. The KS test is performed on these scores.
Initialise the drift detector:
Calculate scores. We split the corrupted data into harmful and harmless data and visualize the harmfulness scores for various values of corruption severity.
Plot scores
We now plot the mean scores and standard deviations per severity level. The plot shows the mean harmfulness scores (lhs) and ResNet-32 accuracies (rhs) for increasing data corruption severity levels. Level 0 corresponds to the original test set. Harmful scores are scores from instances which have been flipped from the correct to an incorrect prediction because of the corruption. Not harmful means that a correct prediction was unchanged after the corruption.
Plot p-values for contaminated batches
In order to simulate a realistic scenario, we perform a K-S test on batches of instances which are increasingly contaminated with corrupted data. The following steps are implemented:
We randomly pick n_ref=1000 samples from the non-corrupted test set to be used as a reference set in the initialization of the K-S drift detector.
We sample batches of data of size batch_size=100 contaminated with an increasing number of harmful corrupted data and harmless corrupted data.
The K-S detector predicts whether drift occurs between the contaminated batches and the reference data and returns the p-values of the test.
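The batch construction in the steps above can be sketched as follows. The helper `contaminated_batch` is hypothetical and the instances are simple placeholders; in the notebook the corrupted pool would contain either harmful or harmless images.

```python
import random


def contaminated_batch(clean, corrupted, batch_size, frac_corrupted, rng):
    """Sample a batch in which a given fraction of the instances is corrupted."""
    n_corrupted = round(batch_size * frac_corrupted)
    batch = rng.sample(corrupted, n_corrupted) + rng.sample(clean, batch_size - n_corrupted)
    rng.shuffle(batch)
    return batch


rng = random.Random(0)
clean = [("clean", i) for i in range(1000)]          # stand-ins for uncorrupted instances
corrupted = [("corrupted", i) for i in range(1000)]  # stand-ins for corrupted instances

for frac in (0.0, 0.1, 0.3, 0.5):
    batch = contaminated_batch(clean, corrupted, batch_size=100, frac_corrupted=frac, rng=rng)
    n_corr = sum(1 for tag, _ in batch if tag == "corrupted")
    print(frac, n_corr)
```

Each batch would then be passed to the drift detector and the returned p-value recorded per contamination level.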
We observe that contamination of the batches with harmful data reduces the p-values much faster than contamination with harmless data. In the latter case, the p-values remain above the detection threshold even when the batch is heavily contaminated.
We repeat the test for 100 randomly sampled batches and we plot the mean and the maximum p-values for each level of severity and contamination below. We can see from the plot that the detector is able to clearly distinguish a batch contaminated with harmful data from a batch contaminated with harmless data when the percentage of corrupted data reaches 20%-30%.
The drift detector applies feature-wise two-sample Kolmogorov-Smirnov (K-S) tests for the continuous numerical features and Chi-Squared tests for the categorical features. For multivariate data, the obtained p-values for each feature are aggregated either via the Bonferroni or the False Discovery Rate (FDR) correction. The Bonferroni correction is more conservative and controls for the probability of at least one false positive. The FDR correction on the other hand allows for an expected fraction of false positives to occur.
The instances contain a person's characteristics like age, marital status or education while the label represents whether the person makes more or less than $50k per year. The dataset consists of a mixture of numerical and categorical features. It is fetched using the Alibi library, which can be installed with pip:
The fetch_adult function returns a Bunch object containing the instances, the targets, the feature names and a dictionary with as keys the column indices of the categorical features and as values the possible categories for each categorical variable.
We split the data in a reference set and 2 test sets on which we test the data drift:
We need to provide the drift detector with the columns which contain categorical features so it knows which features require the Chi-Squared and which ones require the K-S univariate test. We can either provide a dict with as keys the column indices and as values the number of possible categories or just set the values to None and let the detector infer the number of categories from the reference data as in the example below:
Initialize the detector:
We can also save/load an initialised detector:
Now we can check whether the 2 test sets are drifting from the reference data:
Let's take a closer look at each of the features. The preds dictionary also returns the K-S or Chi-Squared test statistics and p-value for each feature:
None of the feature-level p-values are below the threshold:
If you are interested in individual feature-wise drift, this is also possible:
What about the second test set?
We can again investigate the individual features:
It seems like there is little divergence in the distributions of the features between the reference and test set. Let's visualize this:
While the TabularDrift detector works fine with numerical or categorical features only, we can also directly use a categorical drift detector. In this case, we don't need to specify the categorical feature columns. First we construct a categorical-only dataset and then use the ChiSquareDrift detector:
A number of convenient and powerful kernel-based drift detectors such as the MMD detector (Gretton et al., 2012) or the learned kernel MMD detector (Liu et al., 2020) do not scale favourably with increasing dataset size $n$, leading to quadratic complexity $\mathcal{O}(n^2)$ for naive implementations. As a result, we can quickly run into memory issues by having to store the $[N_\text{ref} + N_\text{test}, N_\text{ref} + N_\text{test}]$ kernel matrix (on the GPU if applicable) used for an efficient implementation of the permutation test. Note that $N_\text{ref}$ is the reference data size and $N_\text{test}$ the test data size.
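A quick back-of-the-envelope calculation shows how fast the dense kernel matrix grows. The sketch below assumes 4 bytes per entry (float32), which is an assumption made for illustration; the helper name is hypothetical.

```python
def kernel_matrix_gib(n_ref, n_test, bytes_per_value=4):
    """Memory in GiB of a dense [n_ref + n_test, n_ref + n_test] kernel matrix."""
    n = n_ref + n_test
    return n * n * bytes_per_value / 2**30


for n in (2_000, 10_000, 100_000):
    print(f"N_ref = N_test = {n:>7,}: {kernel_matrix_gib(n, n):10.2f} GiB")
```

Because the size scales with $(N_\text{ref} + N_\text{test})^2$, a tenfold increase in batch size requires a hundredfold increase in memory.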
We can however drastically speed up and scale up kernel-based drift detectors to large dataset sizes by working with symbolic kernel matrices instead and leverage the KeOps library to do so. For the user of $\texttt{Alibi Detect}$ the only thing that changes is the specification of the detector's backend, e.g. for the MMD detector:
In this notebook we will run a few simple benchmarks to illustrate the speed and memory improvements from using KeOps over vanilla PyTorch on the GPU (1x RTX 2080 Ti) for both the standard MMD and learned kernel MMD detectors.
We randomly sample points from the standard normal distribution and run the detectors with PyTorch and KeOps backends for the following settings:
$N_\text{ref}, N_\text{test} = [2, 5, 10, 20, 50, 100]$ (batch sizes in '000s)
$D = [2, 10, 50]$
Where $D$ denotes the number of features.
The notebook requires PyTorch and KeOps to be installed. Once PyTorch is installed, KeOps can be installed via pip:
Before we start let’s fix the random seeds for reproducibility:
First we define some utility functions to run the experiments:
As detailed earlier, we will compare the PyTorch with the KeOps implementation of the MMD and learned kernel MMD detectors for a variety of reference and test data batch sizes as well as different feature dimensions. Note that for the PyTorch implementation, the portion of the kernel matrix for the reference data itself can already be computed at initialisation of the detector. This computation will not be included when we record the detector's prediction time. Since use cases where $N_\text{ref} >> N_\text{test}$ are quite common, we will also test for this specific setting. The key reason is that we cannot amortise this computation for the KeOps detector since we are working with lazily evaluated symbolic matrices.
1. $N_\text{ref} = N_\text{test}$
Note that for KeOps we could further increase the number of instances in the reference and test sets (e.g. to 500,000) without running into memory issues.
Below we visualise the runtimes of the different experiments. We can make the following observations:
The relative speed improvements of KeOps over vanilla PyTorch increase with increasing batch size.
Due to the explicit kernel computation and storage, the PyTorch detector runs out-of-memory after a little over 10,000 instances in each of the reference and test sets while KeOps keeps scaling up without any issues.
The relative speed improvements decline with growing feature dimension. Note however that we would not recommend using a (untrained) MMD detector on very high-dimensional data in the first place.
The plots show both the absolute and relative (PyTorch / KeOps) mean prediction times for the MMD drift detector for different feature dimensions $[2, 10, 50]$.
The difference between KeOps and PyTorch is even more striking when we only look at $[2, 10]$ features:
2. $N_\text{ref} >> N_\text{test}$
Now we check whether the speed improvements still hold when $N_\text{ref} >> N_\text{test}$ ($N_\text{ref} / N_\text{test} = 10$) and a large part of the kernel can already be computed at initialisation time of the PyTorch (but not the KeOps) detector.
The below plots illustrate that KeOps indeed still provides large speed ups over PyTorch. The x-axis shows the reference batch size $N_\text{ref}$. Note that $N_\text{ref} / N_\text{test} = 10$.
We conduct similar experiments as for the MMD detector for $N_\text{ref} = N_\text{test}$ and n_features=50. We use a deep learned kernel with an MLP followed by Gaussian RBF kernels and project the input features on a d_out=2-dimensional space. Since the learned kernel detector computes the kernel matrix in a batch-wise manner, we can also scale up the number of instances for the PyTorch backend without running out-of-memory.
We again plot the absolute and relative (PyTorch / KeOps) mean prediction times for the learned kernel MMD drift detector for different feature dimensions:
As illustrated in the experiments, KeOps allows you to drastically speed up and scale up drift detection to larger datasets without running into memory issues. The speed benefit of KeOps over the PyTorch (or TensorFlow) MMD detectors decreases as the number of features increases. Note though that it is not advised to apply the (untrained) MMD detector to very high-dimensional data in the first place and that we can apply dimensionality reduction via the deep kernel for the learned kernel MMD detector.
Under the hood, drift detectors leverage a function (also known as a test-statistic) that is expected to take a large value if drift has occurred and a low value if not. The power of the detector is partly determined by how well the function satisfies this property. However, specifying such a function in advance can be very difficult.
The classifier-based drift detector simply tries to correctly distinguish instances from the reference data vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen test instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances then the test set must differ from the reference set and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set or even each fold within the test set.
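To give a feel for the quantity being tested, the sketch below computes the two-sample Kolmogorov-Smirnov statistic (the largest gap between the two empirical CDFs) in plain Python. It is an illustration only: the helper `ks_statistic` is hypothetical, the probabilities are made up, and no p-value is computed.

```python
def ks_statistic(a, b):
    """Two-sample K-S statistic: maximum distance between the empirical CDFs."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    d = 0.0
    while i < len(a) and j < len(b):
        x = min(a[i], b[j])
        # Advance both samples past all values <= x, then compare the CDFs at x.
        while i < len(a) and a[i] <= x:
            i += 1
        while j < len(b) and b[j] <= x:
            j += 1
        d = max(d, abs(i / len(a) - j / len(b)))
    return d


# Made-up out-of-fold probabilities of "belongs to the test set".
probs_ref = [0.10, 0.20, 0.30, 0.35, 0.40]
probs_test = [0.60, 0.70, 0.75, 0.80, 0.90]
print(ks_statistic(probs_ref, probs_test))  # fully separated samples give 1.0
print(ks_statistic(probs_ref, probs_ref))   # identical samples give 0.0
```

A statistic close to 1 means the classifier separates reference and test instances almost perfectly, while a value near 0 means it cannot tell them apart.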
The method works with the PyTorch, TensorFlow and Sklearn frameworks. We will focus exclusively on the Sklearn backend in this notebook.
The Adult dataset consists of 32,561 instances distributed over 2 classes based on whether the annual income is >50K. We evaluate drift on particular subsets of the data which are constructed based on the education level. As we will further discuss, our reference dataset will consist of people having a low education level, while our test dataset will consist of people having a high education level.
Note: we need to install alibi to fetch the adult dataset.
We split the dataset in two based on the education level. We define a low_education level consisting of: 'Dropout', 'High School grad', 'Bachelors', and a high_education level consisting of: 'Bachelors', 'Masters', 'Doctorate'. Intentionally we included an overlap between the two distributions consisting of people that have a Bachelors degree. Our goal is to detect that the two distributions are different.
We sample our reference dataset from the low_education level. In addition, we sample two other datasets:
x_h0 - sampled from the low_education level to support the null hypothesis (i.e., the two distributions are identical);
x_h1 - sampled from the high_education level to support the alternative hypothesis (i.e., the two distributions are different).
We perform a binomial test using a RandomForestClassifier.
As expected, when testing against x_h0, we fail to reject $H_0$, while for the second case there is enough evidence to reject $H_0$ and flag that the data has drifted.
For the classifiers that do not support predict_proba but offer support for decision_function, we can perform a K-S test on the scores by setting preds_type='scores'.
Some models can return a poor estimate of the class label probability or some might not even support probability predictions. We can add calibration on top of each classifier to obtain better probability estimates and perform a K-S test. For demonstrative purposes, we will calibrate a LinearSVC which does not support predict_proba, but any other classifier would work.
In order to use the entire dataset and obtain unbiased predictions required to perform the statistical test, the ClassifierDrift detector has the option to perform an n_folds split. Although appealing due to its data efficiency, this method can be slow since it requires training n_folds classifiers.
For the RandomForestClassifier we can avoid retraining n_folds classifiers by using the out-of-bag predictions. In a RandomForestClassifier each tree is trained on a separate dataset obtained by sampling the original training set with replacement, a method known as bagging. On average, only about 63% of the unique samples from the original dataset are used to train each tree (Bostrom). Thus, for each tree, we can obtain predictions for the remaining out-of-bag samples (i.e., the remaining 37%). By accumulating the out-of-bag predictions across all the trees we can eventually obtain a prediction for each sample in the original dataset. Note that we used the word 'eventually' because if the number of trees is too small, covering the entire original dataset might be unlikely.
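The 63% figure follows from a short calculation: when drawing $n$ samples with replacement from $n$ instances, a given instance is missed by every draw with probability $(1 - 1/n)^n \approx e^{-1} \approx 0.37$. The sketch below checks this numerically and also estimates how likely it is that a given instance never ends up out-of-bag for a small forest, assuming the trees are bootstrapped independently. Both helper names are hypothetical.

```python
import math


def expected_unique_fraction(n):
    """Expected fraction of distinct instances in a bootstrap sample of size n."""
    return 1 - (1 - 1 / n) ** n


def prob_never_out_of_bag(n_trees, n):
    """Probability that a given instance is in-bag for every one of n_trees trees."""
    return expected_unique_fraction(n) ** n_trees


for n in (10, 100, 10_000):
    print(n, round(expected_unique_fraction(n), 4))
print("limit:", round(1 - math.exp(-1), 4))

for n_trees in (1, 5, 10, 50):
    print(n_trees, prob_never_out_of_bag(n_trees, n=10_000))
```

The second loop shows why a forest with only a handful of trees may leave some instances without any out-of-bag prediction.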
For demonstrative purposes, we will compare the running time of the ClassifierDrift detector when using a RandomForestClassifier in two setups: n_folds=5, use_oob=False and use_oob=True.
We can observe that in this particular setting, using the out-of-bag predictions can speed up the procedure by almost a factor of 4.
We illustrate drift detection on molecular graphs using a variety of detectors:
Kolmogorov-Smirnov detector on the output of the binary classification Graph Isomorphism Network to detect prediction distribution shift.
Model Uncertainty detector which leverages a measure of uncertainty on the model predictions (in this case MC dropout) to detect drift which could lead to degradation of model performance.
Maximum Mean Discrepancy detector on graph embeddings to flag drift in the input data.
Learned Kernel detector which flags drift in the input data using a (deep) learned kernel. The method trains a (deep) kernel on part of the data to maximise an estimate of the test power. Once the kernel is learned a permutation test is performed in the usual way on the value of the Maximum Mean Discrepancy (MMD) on the held out test set.
Kolmogorov-Smirnov detector to see if drift occurred on graph level statistics such as the number of nodes, edges and the average clustering coefficient.
We will train a classification model and detect drift on the ogbg-molhiv dataset. The dataset contains molecular graphs with both atom features (atomic number-1, chirality, node degree, formal charge, number of H bonds, number of radical electrons, hybridization, aromatic?, in a ring?) and bond level properties (bond type (e.g. single or double), bond stereo code, conjugated?). The goal is to predict whether a molecule inhibits HIV virus replication or not, so the task is binary classification.
The dataset is split using the scaffold splitting procedure. This means that the molecules are split based on their 2D structural framework. Structurally different molecules are grouped into different subsets (train, validation, test) which could mean that there is drift between the splits.
The dataset is retrieved from the Open Graph Benchmark dataset collection.
Besides alibi-detect, this example notebook also uses PyTorch Geometric and OGB, both of which can be installed via pip/conda.
We set some samples apart to serve as the reference data for our drift detectors. Note that the allowed format of the reference data is very flexible and can be np.ndarray or List[Any]:
Let's plot some graph summary statistics such as the distribution of the node degrees, number of nodes and edges as well as the clustering coefficients:
While the average number of nodes and edges are similar across the splits, the histograms show that the tails are slightly heavier for the training graphs.
We borrow code from the PyTorch Geometric GNN explanation example to visualize molecules from the graph objects.
As our classifier we use a variation of a Graph Isomorphism Network incorporating edge (bond) as well as node (atom) features.
Train and evaluate the model. Evaluation is done using ROC-AUC. If you already have a trained model saved, you can directly load it by specifying the load_path:
We will first detect drift on the prediction distribution of the GIN model. Since the binary classification model returns continuous numerical univariate predictions, we use the Kolmogorov-Smirnov drift detector. First we define some utility functions:
Because we pass lists with torch_geometric.data.Data objects to the detector, we need to preprocess the data using the batch_fn into torch_geometric.data.Batch objects which can be fed to the model. Then we detect drift on the model prediction distribution.
Since the dataset is heavily imbalanced, we will test the detectors on a sample which oversamples from the minority class (molecules which inhibit HIV virus replication):
As expected, prediction distribution shift is detected for the imbalanced sample but not for the random test sample with similar label distribution as the reference data.
The model uncertainty drift detector can pick up when the model predictions drift into areas of changed uncertainty compared to the reference data. This can be a good proxy for drift which results in model performance degradation. The uncertainty is estimated via a Monte Carlo estimate (MC dropout). We use the RegressorUncertaintyDrift detector since our binary classification model returns 1D logits.
Although we didn't pick up drift in the GIN model prediction distribution for the test sample, we can see that the model is less certain about the predictions on the test set, illustrated by the lower ROC-AUC.
We can also detect drift on the input data by encoding the data with a randomly initialized GNN to extract graph embeddings. Then we apply our detector of choice, e.g. the MMD detector on the extracted embeddings.
Instead of applying the MMD detector on the pooling output of a randomly initialized GNN encoder, we use the Learned Kernel detector which trains the encoder and kernel on part of the data to maximise an estimate of the detector's test power. Once the kernel is learned a permutation test is performed in the usual way on the value of the MMD on the held out test set.
Since the molecular scaffolds are different across the train, validation and test sets, we expect that this type of data shift is picked up in the input data (technically not the input but the graph embedding).
We could also compute graph-level statistics such as the number of nodes, edges and clustering coefficient and detect drift on those statistics using the Kolmogorov-Smirnov test with multivariate correction (e.g. Bonferroni). First we define a preprocessing step to extract the summary statistics from the graphs:
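As an illustration of what such summary statistics look like, the sketch below computes the number of nodes, the number of edges and the average clustering coefficient for a small undirected graph in plain Python. The helper `graph_statistics` is hypothetical and assumes that each undirected edge is listed once and that the graph has at least one node.

```python
from itertools import combinations


def graph_statistics(n_nodes, edges):
    """Return (number of nodes, number of edges, average clustering coefficient).

    edges: list of (u, v) node-index pairs, each undirected edge listed once.
    Nodes with fewer than two neighbours get a clustering coefficient of 0.
    """
    neighbours = {i: set() for i in range(n_nodes)}
    for u, v in edges:
        neighbours[u].add(v)
        neighbours[v].add(u)
    coeffs = []
    for nbrs in neighbours.values():
        k = len(nbrs)
        if k < 2:
            coeffs.append(0.0)
            continue
        # Count how many pairs of neighbours are themselves connected.
        links = sum(1 for a, b in combinations(nbrs, 2) if b in neighbours[a])
        coeffs.append(links / (k * (k - 1) / 2))
    return n_nodes, len(edges), sum(coeffs) / n_nodes


# Toy graph: a triangle (0, 1, 2) with one extra node attached to node 2.
stats = graph_statistics(4, [(0, 1), (1, 2), (0, 2), (2, 3)])
print(stats)
```

Each graph is thereby reduced to a 3-dimensional feature vector, to which feature-wise K-S tests can be applied.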
The 3 returned p-values correspond to respectively the p-values for the number of nodes, edges and clustering coefficient. We already saw in the EDA that the distributions of the node, edge and clustering coefficients look similar across the train, validation and test sets except for the tails. This is confirmed by running the drift detector on the graph statistics which cannot seem to pick up on the differences in molecular scaffolds between the datasets, unless we heavily oversample from the minority class where the number of nodes and edges but not the clustering coefficient significantly differ.
We illustrate drift detection on text data using the following detectors:
Maximum Mean Discrepancy (MMD) detector using pre-trained transformers to flag drift in the embedding space.
Classifier drift detector to detect drift in the input space.
The Amazon dataset contains product reviews with a star rating. We will test whether drift can be detected if the ratings start to drift. For more information, check the WILDS documentation page.
Besides alibi-detect, this example notebook also uses the Amazon dataset through the WILDS package. WILDS is a curated collection of benchmark datasets that represent distribution shifts faced in the wild and can be installed via pip:
Throughout the notebook we use detectors with both PyTorch and TensorFlow backends.
We first load the dataset and create reference data, data which should not be rejected under the null of the test (H0) and data which should exhibit drift (H1). The drift is introduced later by specifying a specific star rating for the test instances.
The following cell will download the Amazon dataset (if DOWNLOAD=True). The download size is ~7GB and size on disk is ~7GB.
First we embed instances using a pretrained transformer model and detect data drift using the MMD detector on the embeddings.
Helper functions:
Define the transformer embedding preprocessing step:
Define a function which will, for a specified number of iterations (n_sample):
Configure the MMDDrift detector with a new reference data sample
Detect drift on the H0 and H1 splits
Now we will use the ClassifierDrift detector which uses a binary classification model to try and distinguish the reference from the test (H0 or H1) data. Drift is then detected on the difference between the prediction distributions on out-of-fold reference vs. test instances using a Kolmogorov-Smirnov 2 sample test on the prediction probabilities or via a binomial test on the binarized predictions. We use a pretrained transformer model but freeze its weights and only train the head which consists of 2 dense layers with a leaky ReLU non-linearity:
We can do the same using TensorFlow instead of PyTorch as backend. We first define the classifier again and then simply run the detector:
Model-uncertainty drift detectors aim to directly detect drift that's likely to affect the performance of a model of interest. The approach is to test for change in the number of instances falling into regions of the input space on which the model is uncertain in its predictions. For each instance in the reference set the detector obtains the model's prediction and some associated notion of uncertainty. For example, for a classifier this may be the entropy of the predicted label probabilities, while for a regressor with dropout layers Monte Carlo dropout can be used to provide a notion of uncertainty. The same is done for the test set and if significant differences in uncertainty are detected (via a Kolmogorov-Smirnov test) then drift is flagged.
It is important that the detector uses a reference set that is disjoint from the model's training set (on which the model's confidence may be higher).
For models that require batch evaluation both PyTorch and TensorFlow frameworks are supported. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for how to do this.
We start by demonstrating how to leverage model uncertainty to detect malicious drift when the model of interest is a classifier.
Dataset
CIFAR10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.
Original CIFAR-10 data:
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
Let's pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.
We split the original test set in a reference dataset and a dataset which should not be rejected under the no-change null H0. We also split the corrupted data by corruption type:
We can visualise the same instance for each corruption type:
We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:
Given the drop in performance, it is important that we detect the harmful data drift!
Detect drift
Unlike many other approaches we needn't specify a dimension-reducing preprocessing step as the detector operates directly on the data as it is input to the model of interest. In fact, the two-stage projection input -> prediction -> uncertainty can be thought of as the projection from the input space onto the real line, ready to perform the test.
We simply pass the model to the detector and inform it that the predictions should be interpreted as 'probs' rather than 'logits' (i.e. a softmax has already been applied). By default uncertainty_type='entropy' is used as the notion of uncertainty for classifier predictions, however uncertainty_type='margin' can be specified to deem the classifier's predictions uncertain if they fall within a margin (e.g. in [0.45,0.55] for binary classifier probabilities) (similar to Sethi and Kantardzic (2017)).
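The two notions of uncertainty can be sketched for a single prediction as follows. These helpers are hypothetical illustrations rather than the library's implementation; in particular, the `margin_width` argument and its default of 0.1 are assumptions chosen so that binary probabilities strictly between 0.45 and 0.55 count as uncertain.

```python
import math


def entropy_uncertainty(probs):
    """Entropy of the predicted class probabilities (higher means less certain)."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def margin_uncertainty(probs, margin_width=0.1):
    """1.0 if the two largest probabilities differ by less than margin_width, else 0.0."""
    top, second = sorted(probs, reverse=True)[:2]
    return float(top - second < margin_width)


print(entropy_uncertainty([0.5, 0.5]))    # maximal for 2 classes: log(2)
print(entropy_uncertainty([1.0, 0.0]))    # a fully confident prediction
print(margin_uncertainty([0.52, 0.48]))   # inside the margin
print(margin_uncertainty([0.90, 0.10]))   # outside the margin
```

Either function maps every instance to a single number, so the reference and test uncertainties can be compared with a univariate test.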
Let's check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
Note here how drift is only detected for the corrupted datasets on which the model's performance is significantly degraded. For the 'brightness' corruption, for which the model maintains 89% classification accuracy, the change in model uncertainty is not deemed significant (p-value 0.11, above the 0.05 threshold). For the other corruptions, which significantly hamper model performance, the malicious drift is detected.
We now demonstrate how to leverage model uncertainty to detect malicious drift when the model of interest is a regressor. This is a less general approach as regressors often make point-predictions with no associated notion of uncertainty. However, if the model makes its predictions by ensembling the predictions of sub-models then we can consider the variation in the sub-model predictions as a notion of uncertainty. RegressorUncertaintyDetector supports models that output a vector of such sub-model predictions (uncertainty_type='ensemble') or deep learning models that include dropout layers and can therefore (as noted by Gal and Ghahramani 2016) be considered as an ensemble (uncertainty_type='mc_dropout', the default option).
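To make "variation in the sub-model predictions" concrete, here is a minimal sketch that uses the population variance of each instance's sub-model predictions as its uncertainty. The choice of variance and the function name are assumptions for illustration; the same idea applies to repeated stochastic forward passes with dropout enabled.

```python
from statistics import pvariance


def ensemble_uncertainty(sub_model_preds):
    """One uncertainty value per instance: the variance of the
    predictions made by the sub-models for that instance."""
    return [pvariance(preds) for preds in sub_model_preds]


# Made-up predictions: 2 instances, 3 sub-models each.
preds = [
    [5.0, 5.1, 4.9],  # sub-models agree: low uncertainty
    [3.0, 6.0, 9.0],  # sub-models disagree: high uncertainty
]
print(ensemble_uncertainty(preds))
```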
Dataset
The Wine Quality Data Set consists of 1599 and 4898 samples of red and white wine respectively. Each sample has an associated quality (as determined by experts) and 11 numeric features indicating its acidity, density, pH etc. We consider the regression problem of trying to predict the quality of red wine samples given these features. We will then consider whether the model remains suitable for predicting the quality of white wine samples or whether the associated change in the underlying distribution should be considered as malicious drift.
First we load in the data.
We can see that the data for both red and white wine samples take the same format.
We shuffle and normalise the data such that each feature takes a value in [0,1], as does the quality we seek to predict.
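A min-max rescaling of this kind can be sketched in plain Python as follows. The rows are made-up numbers, and the notebook's own normalisation code (including which samples the minima and maxima are taken over) is not shown in the text, so this is only an assumed version.

```python
def min_max_scale(rows):
    """Rescale every column of a list of rows to the range [0, 1]."""
    columns = list(zip(*rows))
    lows = [min(c) for c in columns]
    highs = [max(c) for c in columns]
    return [
        [(v - lo) / (hi - lo) if hi > lo else 0.0
         for v, lo, hi in zip(row, lows, highs)]
        for row in rows
    ]


# Made-up rows with two features on very different scales.
data = [[1.0, 10.0], [3.0, 20.0], [2.0, 30.0]]
print(min_max_scale(data))
```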
We split the red wine data into a set on which to train the model, a reference set with which to instantiate the detector and a set on which the detector should not flag drift. We then instantiate a DataLoader to pass the training data to a PyTorch model in batches.
Regression model
We now define the regression model that we'll train to predict the quality from the features. The exact details aren't important other than the presence of at least one dropout layer. We then train the model for 20 epochs to optimise the mean square error on the training data.
We now evaluate the trained model on both unseen samples of red wine and white wine. We see that, unsurprisingly, the model is better able to predict the quality of unseen red wine samples.
Detect drift
We now look at whether a regressor-uncertainty detector would have picked up on this malicious drift. We instantiate the detector and obtain drift predictions on both the held-out red-wine samples and the white-wine samples. We specify uncertainty_type='mc_dropout' in this case, but alternatively we could have trained an ensemble model that for each instance outputs a vector of multiple independent predictions and specified uncertainty_type='ensemble'.
Under the hood drift detectors leverage a function (also known as a test-statistic) that is expected to take a large value if drift has occurred and a low value if not. The power of the detector is partly determined by how well the function satisfies this property. However, specifying such a function in advance can be very difficult. In this example notebook we consider two ways in which a portion of the available data may be used to learn such a function before then applying it on the held out portion of the data to test for drift.
The classifier-based drift detector simply tries to correctly distinguish instances from the reference data vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen test instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances then the test set must differ from the reference set and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set or even each fold within the test set.
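The comparison of predicted probabilities can be illustrated with a small two-sample Kolmogorov-Smirnov statistic. This is a plain-Python sketch of the two-sided statistic only (no p-value), with made-up probabilities; it is not the detector's actual implementation.

```python
from bisect import bisect_right


def ks_statistic(a, b):
    """Largest gap between the empirical CDFs of samples a and b."""
    a_sorted, b_sorted = sorted(a), sorted(b)
    gap = 0.0
    for v in sorted(set(a_sorted) | set(b_sorted)):
        cdf_a = bisect_right(a_sorted, v) / len(a_sorted)
        cdf_b = bisect_right(b_sorted, v) / len(b_sorted)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap


# Made-up classifier outputs: probability that an instance is from the test set.
probs_on_ref = [0.10, 0.20, 0.15, 0.30]
probs_on_test = [0.80, 0.90, 0.70, 0.95]
print(ks_statistic(probs_on_ref, probs_on_test))
```

A statistic close to 1 means the classifier separates the two sets almost perfectly, while a value near 0 means its outputs look the same on both.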
The method works with both the PyTorch and TensorFlow frameworks. However, Alibi Detect does not install PyTorch for you. Check the PyTorch docs for how to do this.
CIFAR10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.
Original CIFAR-10 data:
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
Let's pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.
We split the original test set into a reference dataset and a dataset which should not be flagged as drift. We also split the corrupted data by corruption type:
We can visualise the same instance for each corruption type:
Single fold
We use a simple classification model and try to distinguish between the reference data and the corrupted test sets. The detector defaults to binarize=False which means a Kolmogorov-Smirnov test will be used to test for significant disparity between continuous model predictions (e.g. probabilities or logits). Initially we'll test at a significance level of $p=0.05$, use $75$% of the shuffled reference and test data for training and evaluate the detector on the remaining $25$%. We only train for 1 epoch.
If needed, the detector can be saved and loaded with save_detector and load_detector:
Let's check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
As expected, drift was only detected on the corrupted datasets and the classifier could easily distinguish the corrupted from the reference data.
Use all the available data via cross-validation
So far we've only used $25$% of the data to detect the drift since $75$% is used for training purposes. At the cost of additional training time we can however leverage all the data via stratified cross-validation. We just need to set the number of folds and keep everything else the same. So for each test set n_folds models are trained, and the out-of-fold predictions combined for the significance test:
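To show what "stratified" means here, the following sketch assigns instance indices to folds so that each fold keeps the same mix of reference (label 0) and test (label 1) instances. The helper name and the round-robin assignment are assumptions for illustration and do not reproduce the library's splitting code.

```python
import random


def stratified_folds(labels, n_folds, seed=0):
    """Split indices into n_folds lists, balancing each label across folds."""
    rng = random.Random(seed)
    folds = [[] for _ in range(n_folds)]
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        # Deal the shuffled indices of this class out to the folds in turn.
        for pos, i in enumerate(idx):
            folds[pos % n_folds].append(i)
    return folds


# 0 = reference instance, 1 = test instance (toy sizes).
labels = [0] * 10 + [1] * 10
folds = stratified_folds(labels, n_folds=5)
for k, fold in enumerate(folds):
    print(k, sorted(fold))
```

Each fold is held out once while a model is trained on the remaining folds, so every instance receives exactly one out-of-fold prediction.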
An alternative to training a classifier to output high probabilities for instances from the test window and low probabilities for instances from the reference window is to learn a kernel that outputs high similarities between instances from the same window and low similarities between instances from different windows. The kernel may then be used within an MMD-test for drift. Liu et al. (2020) propose this learned approach and note that it is in fact a generalisation of the above classifier-based method. However, in this case we can train the kernel to directly optimise an estimate of the detector's power, which can result in superior performance.
This can be implemented as shown below. We use PyTorch instead of TensorFlow this time for the sake of variety. Because we are dealing with images we give our projection $\Phi$ a convolutional architecture.
We may then specify a DeepKernel in the following manner. By default GaussianRBF kernels are used for $k_a$ and $k_b$ and here we specify $\epsilon=0.01$, but we could alternatively set eps='trainable'.
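The roles of $k_a$, $k_b$, $\epsilon$ and $\Phi$ can be sketched as follows, assuming the deep kernel takes the form $(1-\epsilon)\,k_a(\Phi(x), \Phi(y)) + \epsilon\,k_b(x, y)$ as in Liu et al. (2020). The projection toy_proj is a hypothetical stand-in for a trained network, and all function names are chosen for this example only.

```python
import math


def gaussian_rbf(u, v, sigma=1.0):
    """Gaussian RBF kernel between two vectors given as lists."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2 * sigma ** 2))


def deep_kernel(x, y, proj, eps=0.01):
    """Assumed form: (1 - eps) * k_a(proj(x), proj(y)) + eps * k_b(x, y)."""
    return (1 - eps) * gaussian_rbf(proj(x), proj(y)) + eps * gaussian_rbf(x, y)


def toy_proj(x):
    """Hypothetical projection standing in for a learned network."""
    return [sum(x), sum(x) / len(x)]


x = [0.1, 0.2, 0.3]
y = [0.3, 0.1, 0.5]
print(deep_kernel(x, y, toy_proj))
```

The small $\epsilon$ term keeps some weight on the raw inputs, so two instances that the projection happens to map to the same point are not automatically treated as identical.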
Since our PyTorch encoder expects the images in a (batch size, channels, height, width) format, we transpose the data. Note that this step could also be passed to the drift detector via the preprocess_fn kwarg:
We then pass the kernel to the LearnedKernelDrift detector. By default $75$% of the data is used to train the kernel and the MMD-test is performed on the other $25$%.
Again, the detector can be saved and loaded:
Finally, let's make some predictions with the detector:
This notebook demonstrates a typical workflow for applying online drift detectors to streams of image data. For those unfamiliar with how the online drift detectors operate in alibi_detect we recommend first checking out the more introductory example Online Drift Detection on the Wine Quality Dataset where online drift detection is performed for the wine quality dataset.
This notebook requires the wilds, torch and torchvision packages which can be installed via pip:
We will use the Camelyon17 dataset, one of the WILDS datasets of Koh et al. (2020) that represent "in-the-wild" distribution shifts for various data modalities. It contains tissue scans to be classified as benign or cancerous. The pre-change distribution corresponds to scans from across three hospitals and the post-change distribution corresponds to scans from a new fourth hospital.
Koh et al. (2020) show that models trained on scans from the pre-change distribution achieve an accuracy of 93.2% on unseen scans from the same distribution, but only 70.3% accuracy on scans from the post-change distribution.
First we create a function that converts the Camelyon dataset to a stream in order to simulate a live deployment environment. We extract N instances to act as the reference set on which a model of interest was trained. We then consider a stream of images from the pre-change (same) distribution and a stream of images from the post-change (drifted) distribution.
The following cell will download the Camelyon dataset (if DOWNLOAD=True). The download size is ~10GB and size on disk is ~15GB.
Shown below are samples from the pre-change distribution:
And samples from the post-change distribution:
The images are of dimension 96x96x3. We train an autoencoder in order to define a more structured representational space of lower dimension. This projection can be thought of as an extension of the kernel. It is important that trained preprocessing components are trained on a split of data that doesn't then form part of the reference data passed to the drift detector.
We can train the autoencoder using a helper function provided for convenience in alibi-detect.
The preprocessing/projection functions are expected to map numpy arrays to numpy arrays, so we wrap the encoder within the function below.
The online drift detectors in alibi-detect window the stream of data in an 'overlapping window' manner such that a test is performed at every time step. We will use an estimator of MMD as the test statistic. The estimate is updated incrementally at low cost. The thresholds are configured via simulation in an initial configuration phase to target the desired expected run-time (ERT) in the absence of change. For a detailed description of this calibration procedure see Cobb et al. (2021).
We define a function which will apply the detector to the streams and return the time at which drift was detected.
First we apply the detector multiple times to the pre-change stream where the distribution is unchanged.
We see that the average run-time in the absence of change is close to the desired ERT, as expected. We can inspect the detector's test_stats and thresholds properties to see how the test statistic varied over time and how close it got to exceeding the threshold.
Now we apply it to the post-change stream where the images are from a drifted distribution.
We see that the detector is quick to flag drift when it has occurred.
In the context of deployed models, data (model queries) usually arrive sequentially and we wish to detect drift as soon as possible after its occurrence. One approach is to perform a test for drift every $W$ time-steps, using the $W$ samples that have arrived since the last test. Such a strategy could be implemented using any of the offline detectors implemented in alibi-detect, but being both sensitive to slight drift and responsive to severe drift is difficult. If the window size $W$ is too small then slight drift will be undetectable. If it is too large then the delay between test-points hampers responsiveness to severe drift.
An alternative strategy is to perform a test each time data arrives. However the usual offline methods are not applicable because the process for computing p-values is too expensive and doesn't account for correlated test outcomes when using overlapping windows of test data.
Online detectors instead work by computing the test-statistic once using the first $W$ data points and then updating the test-statistic sequentially at low cost. When no drift has occurred the test-statistic fluctuates around its expected value and once drift occurs the test-statistic starts to drift upwards. When it exceeds some preconfigured threshold value, drift is detected.
Unlike offline detectors which require the specification of a threshold p-value (a false positive rate), the online detectors in alibi-detect require the specification of an expected run-time (ERT) (an inverted FPR). This is the number of time-steps that we insist our detectors, on average, should run for in the absence of drift before making a false detection. Usually we would like the ERT to be large, however this results in insensitive detectors which are slow to respond when drift does occur. There is a tradeoff between the expected run-time and the expected detection delay.
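The relationship between the ERT and a per-step false positive rate can be checked with a small simulation. It assumes, as an idealisation, that in the absence of drift each time-step independently triggers a false detection with probability 1/ERT, so that run-times follow a Geometric distribution with mean ERT. The helper name is hypothetical.

```python
import random


def simulate_run_time(ert, rng):
    """Number of steps until the first false detection when each step
    independently fires with probability 1 / ert."""
    t = 1
    while rng.random() >= 1.0 / ert:
        t += 1
    return t


rng = random.Random(0)
ert = 50
run_times = [simulate_run_time(ert, rng) for _ in range(5000)]
average = sum(run_times) / len(run_times)
print(f"target ERT: {ert}, simulated average run-time: {average:.1f}")
```

Individual run-times vary widely around the mean, which is why an average over only a handful of runs can sit some distance from the target.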
To target the desired ERT, thresholds are configured during an initial configuration phase via simulation. This configuration process is only suitable when the amount of reference data (most likely the training data of the model of interest) is relatively large (ideally around an order of magnitude larger than the desired ERT). Configuration can be expensive (less so with a GPU) but allows the detector to operate at low cost during deployment.
This notebook demonstrates online drift detection using two different two-sample distance metrics for the test-statistic, the maximum mean discrepancy (MMD) and least-squares density difference (LSDD), both of which can be updated sequentially at low cost.
The online detectors are implemented in both the PyTorch and TensorFlow frameworks with support for CPU and GPU. Various preprocessing steps are also supported out-of-the-box in Alibi Detect for both frameworks and an example will be given in this notebook. However, Alibi Detect does not install PyTorch for you. Check the PyTorch docs for how to do this.
The Wine Quality Data Set consists of 4898 and 1599 samples of white and red wine respectively. Each sample has an associated quality (as determined by experts) and 11 numeric features indicating its acidity, density, pH etc. We consider the regression problem of trying to predict the quality of white wine samples given these features. We will then consider whether the model remains suitable for predicting the quality of red wine samples or whether the associated change in the underlying distribution should be considered as drift.
The Maximum Mean Discrepancy (MMD) is a distance-based measure between 2 distributions p and q based on the mean embeddings $\mu_{p}$ and $\mu_{q}$ in a reproducing kernel Hilbert space $F$:
$$MMD(F, p, q) = \| \mu_{p} - \mu_{q} \|_{F}$$
Given reference samples $\{X_i\}_{i=1}^{N}$ and test samples $\{Y_i\}_{i=t}^{t+W}$ we may compute an unbiased estimate $\widehat{MMD}^2(F, \{X_i\}_{i=1}^{N}, \{Y_i\}_{i=t}^{t+W})$ of the squared MMD between the two underlying distributions. Depending on the size of the reference and test windows, $N$ and $W$ respectively, this can be relatively expensive. However, once computed it is possible to update the statistic to estimate the squared MMD between the distributions underlying $\{X_i\}_{i=1}^{N}$ and $\{Y_i\}_{i=t+1}^{t+1+W}$ at a very low cost, making it suitable for online drift detection.
By default we use a radial basis function kernel, but users are free to pass their own kernel of preference to the detector.
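For intuition, the standard unbiased estimator of the squared MMD with a Gaussian RBF kernel can be written in a few lines of plain Python. This sketch recomputes everything from scratch and does not show the incremental update used by the online detector; the function names and the fixed bandwidth sigma are assumptions for illustration.

```python
import math


def rbf(u, v, sigma=1.0):
    """Gaussian RBF kernel between two vectors given as lists."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.exp(-sq_dist / (2 * sigma ** 2))


def mmd2_unbiased(x, y, sigma=1.0):
    """Unbiased estimate of the squared MMD between samples x and y."""
    n, m = len(x), len(y)
    # Within-sample terms exclude i == j, which is what removes the bias.
    k_xx = sum(rbf(x[i], x[j], sigma) for i in range(n) for j in range(n) if i != j)
    k_yy = sum(rbf(y[i], y[j], sigma) for i in range(m) for j in range(m) if i != j)
    k_xy = sum(rbf(a, b, sigma) for a in x for b in y)
    return k_xx / (n * (n - 1)) + k_yy / (m * (m - 1)) - 2 * k_xy / (n * m)


reference = [[0.0, 0.1], [0.2, -0.1], [-0.1, 0.0]]  # made-up points
window = [[3.0, 3.1], [2.9, 3.2], [3.1, 2.8]]
print(mmd2_unbiased(reference, window))
```

Because the estimator is unbiased it can take small negative values when the two samples come from the same distribution.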
First we load in the data:
We can see that the data for both red and white wine samples take the same format.
We shuffle and normalise the data such that each feature takes a value in [0,1], as does the quality we seek to predict. We assume that our model was trained on white wine samples, which therefore form the reference distribution, and that red wine samples can be considered to be drawn from a drifted distribution.
Although it may not be necessary on this relatively low-dimensional data for which individual features are semantically meaningful, we demonstrate how principal component analysis (PCA) can be performed as a preprocessing stage to project raw data onto a lower dimensional representation which more concisely captures the factors of variation in the data. So as not to bias the detector it is necessary to fit the projection using a split of the data which isn't then passed as reference data. We additionally split off some white wine samples to act as undrifted data during deployment.
Now we define a PCA object to be used as a preprocessing function to project the 11-D data onto a 2-D representation. We learn the first 2 principal components on the training split of the reference data.
Hopefully the learned preprocessing step has learned a projection such that in the lower dimensional space the two samples are distinguishable.
Now we can define our online drift detector. We specify an expected run-time (in the absence of drift) of 50 time-steps, and a window size of 10 time-steps. Upon initialising the detector, thresholds will be computed using 2500 bootstrap samples. These values of ert, window_size and n_bootstraps are lower than in a typical use-case in order to demonstrate the average behaviour of the detector over a large number of runs in a reasonable time.
We now define a function which will simulate a single run and return the run-time. Note how the detector acts on single instances at a time, the run-time is considered as the time elapsed after the test-window has been filled, and that the detector is stateful and must be reset between detections.
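The three points in this paragraph (one instance per call, run-time counted from when the window is full, and resetting state between runs) can be seen in the toy sketch below. ToyOnlineDetector is a hypothetical stand-in that thresholds a sliding-window mean; it is not the library's detector and its interface is assumed for illustration only.

```python
from collections import deque


class ToyOnlineDetector:
    """Hypothetical detector: flags drift once the mean of the last
    window_size values exceeds a fixed threshold."""

    def __init__(self, window_size, threshold):
        self.window_size = window_size
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Clear all state so that a new run starts from scratch."""
        self.window = deque(maxlen=self.window_size)
        self.t = 0

    def predict(self, x):
        """Consume a single instance and report whether drift is flagged."""
        self.t += 1
        self.window.append(x)
        if len(self.window) < self.window_size:
            return False  # no test is possible until the window is full
        return sum(self.window) / self.window_size > self.threshold


def time_run(detector, stream):
    """Run-time of one run, counted from when the window was first full."""
    detector.reset()
    for x in stream:
        if detector.predict(x):
            return detector.t - detector.window_size
    return None  # stream ended without a detection


detector = ToyOnlineDetector(window_size=3, threshold=0.5)
stream = [0, 0, 0, 0, 0, 1, 1, 1]
print(time_run(detector, stream))
```

Without the reset, values left in the window from a previous run would contaminate the first tests of the next one.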
Now we look at the distribution of run-times when operating on the held-out data from the reference distribution of white wine samples. We report the average run-time, but note that the targeted run-time distribution, a Geometric distribution with mean ert, has very high variance, so the empirical average may not be that close to ert over a relatively small number of runs. We can nevertheless see that the detector accurately targets the desired Geometric distribution by inspecting the linearity of a Q-Q plot.
If we run the detector in an identical manner but on data from the drifted distribution of red wine samples the average run-time is much lower.
We additionally show that TensorFlow can also be used as the backend and that sometimes it is not necessary to perform preprocessing, making definition of the drift detector simpler. Moreover, in the absence of a learned preprocessing stage we may use all of the reference data available.
And now we define the LSDD-based online drift detector, again with an ert of 50 and window_size of 10.
We run this new detector on the held-out reference data and again see that in the absence of drift the distribution of run-times follows a Geometric distribution with mean ert.
And when drift has occurred the detector is very fast to respond.
When true outputs/labels are available, we can perform supervised drift detection, monitoring the model's performance directly in order to check for harmful drift. Two detectors ideal for this application are the Fisher’s Exact Test (FET) detector and the Cramér-von Mises (CVM) detector.
The FET detector is designed for use on binary data, such as the instance level performance indicators from a classifier (i.e. 0/1 for each incorrect/correct classification). The CVM detector is designed for use on continuous data, such as a regressor's instance level loss or error scores.
In this example we will use the offline versions of these detectors, which are suitable for use on batches of data. In many cases data may arrive sequentially, and the user may wish to perform drift detection as the data arrives to ensure it is detected as soon as possible. In this case, the online versions of the FET and CVM detectors can be used, as will be explored in a future example.
The palmerpenguins dataset consists of data on 344 penguins from 3 islands in the Palmer Archipelago, Antarctica. There are 3 different species of penguin in the dataset, and a common task is to classify the species of each penguin based upon two features, the length and depth of the penguin's bill, or beak.
Artwork by Allison Horst
This notebook requires the seaborn package for visualization and the palmerpenguins package to load data. These can be installed via pip:
To download the dataset we use the palmerpenguins package:
The data consists of 333 rows (rows containing NaN values are removed), one for each penguin. There are 8 features describing the penguins' physical characteristics, their species and sex, the island each resides on, and the year measurements were taken.
For our first example use case, we will perform the popular species classification task. Here we wish to classify the species based on only bill_length_mm and bill_depth_mm. To start we remove the other features and visualise those that remain.
The above plot shows that the Adelie species can primarily be identified by looking at bill length. Then to further distinguish between Gentoo and Chinstrap, we can look at the bill depth.
Next we separate the data into inputs and outputs, and encode the species data to integers. Finally, we split into three data sets: one to train the classifier, one to act as a reference set when testing for drift, and one to test for drift on.
For this dataset, a relatively shallow decision tree classifier should be sufficient, and so we train an sklearn one on the training data.
As expected, the decision tree is able to give acceptable classification accuracy on the train and test sets.
In order to demonstrate use of the drift detectors, we first need to add some artificial drift to the test data X_test. We add two types of drift here; to create covariate drift we subtract 5mm from the bill length of all the Gentoo penguins. $P(y|\mathbf{X})$ is unchanged here, but clearly we have introduced a delta $\Delta P(\mathbf{X})$. To create concept drift, we switch the labels of the Gentoo and Chinstrap penguins, so that the underlying process $P(y|\mathbf{X})$ is changed.
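The two manipulations can be sketched on a handful of made-up rows. The measurements below are invented for illustration, and the helper names are hypothetical; the notebook applies the same ideas to its own test split.

```python
import copy

# Made-up test rows; real values come from the palmerpenguins data.
test_set = [
    {"species": "Gentoo", "bill_length_mm": 47.5, "bill_depth_mm": 15.0},
    {"species": "Chinstrap", "bill_length_mm": 49.0, "bill_depth_mm": 18.5},
    {"species": "Adelie", "bill_length_mm": 38.8, "bill_depth_mm": 18.3},
]


def add_covariate_drift(rows, shift_mm=5.0):
    """Shift the inputs only: shorten the bill of every Gentoo penguin."""
    drifted = copy.deepcopy(rows)
    for row in drifted:
        if row["species"] == "Gentoo":
            row["bill_length_mm"] -= shift_mm
    return drifted


def add_concept_drift(rows):
    """Leave the inputs alone and swap the Gentoo and Chinstrap labels."""
    swap = {"Gentoo": "Chinstrap", "Chinstrap": "Gentoo"}
    drifted = copy.deepcopy(rows)
    for row in drifted:
        row["species"] = swap.get(row["species"], row["species"])
    return drifted


covariate_set = add_covariate_drift(test_set)
concept_set = add_concept_drift(test_set)
print(covariate_set[0], concept_set[0])
```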
We now define a utility function to plot the classifier's decision boundaries, and we use this to visualise the reference data set, the test set, and the two new data sets where drift is present.
These plots serve as a visualisation of the differences between covariate drift and concept drift. Importantly, the model accuracies shown above also highlight the fact that not all drift is necessarily malicious, in the sense that even relatively significant drift does not always lead to degradation in a model's performance indicators. For example, the model actually gives a slightly higher accuracy on the covariate drift data set than on the no drift set in this case. Conversely, the concept drift unsurprisingly leads to severely degraded model performance.
Before getting to the main task in this example, monitoring malicious drift with a supervised drift detector, we will first use the MMD detector to check for covariate drift. To do this we initialise it in an unsupervised manner by passing it the input data X_ref.
Applying this detector on the no drift, covariate drift and concept drift data sets, we see that the detector only detects drift in the covariate drift case. Not detecting drift in the no drift case is desirable, but not detecting drift in the concept drift case is potentially problematic.
The fact that the unsupervised detector above does not detect the severe concept drift demonstrates the motivation for using supervised drift detectors that directly check for malicious drift, which can include malicious concept drift.
To perform supervised drift detection we first need to compute the model's performance indicators. Since this is a classification task, a suitable performance indicator is the instance level binary losses, which are computed below.
As seen above, these losses are binary data, where 0 represents an incorrect classification for each instance, and 1 represents a correct classification.
Since this is binary data, the FET detector is chosen, and initialised on the reference loss data. The alternative hypothesis is set to less, meaning we will only flag drift if the proportion of 1s to 0s is reduced compared to the reference data. In other words, we only flag drift if the model's performance has degraded.
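The one-sided test can be reproduced by hand from the hypergeometric distribution. The sketch below computes the binary indicators (1 for a correct prediction) and the p-value for the alternative that the test set contains a smaller proportion of 1s. Function names are chosen for this example, and this is not the detector's own code.

```python
from math import comb


def binary_losses(y_true, y_pred):
    """1 for each correct classification, 0 for each incorrect one."""
    return [int(t == p) for t, p in zip(y_true, y_pred)]


def fisher_exact_less(ref_losses, test_losses):
    """One-sided Fisher exact p-value: probability of seeing at most the
    observed number of 1s in the test set if both sets share one rate."""
    n_ref, n_test = len(ref_losses), len(test_losses)
    total = n_ref + n_test
    ones = sum(ref_losses) + sum(test_losses)
    observed = sum(test_losses)
    lowest = max(0, n_test - (total - ones))
    numerator = sum(
        comb(ones, k) * comb(total - ones, n_test - k)
        for k in range(lowest, observed + 1)
    )
    return numerator / comb(total, n_test)


ref = binary_losses([0, 1, 2, 0, 1, 2, 0, 1, 2, 0], [0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
bad = binary_losses([0, 1, 2, 0, 1, 2, 0, 1, 2, 0], [1, 2, 0, 1, 2, 0, 1, 2, 0, 1])
print(fisher_exact_less(ref, ref), fisher_exact_less(ref, bad))
```

With this alternative an improvement in accuracy can never trigger a detection, since only a shortfall of 1s in the test set produces a small p-value.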
Applying this detector to the same three data sets, we see that malicious drift isn't detected in the no drift or covariate drift cases, which is unsurprising since the model performance isn't degraded in these cases. However, with this supervised detector, we now detect the malicious concept drift as desired.
To provide a short example of supervised detection in a regression setting, we now rework the dataset into a regression task, and use the CVM detector on the model's squared error.
Warning: Must have scipy >= 1.7.0 installed for this example.
For a regression task, we take the penguins' flipper length and sex as inputs, and aim to predict the penguins' body mass. Looking at a scatter plot of these features, we can see there is substantial correlation between the chosen inputs and outputs.
Again, we split the dataset into the same three sets; a training set, reference set and test set.
This time we train a linear regressor on the training data, and find that it gives acceptable training and test accuracy.
To generate a copy of the test data with concept drift added, we use the model to create new output data, with a multiplicative factor and some Gaussian noise added. The quality of our synthetic output data is of course affected by the accuracy of the model, but it serves to demonstrate the behavior of the model (and detector) when $P(y|\mathbf{X})$ is changed.
Unsurprisingly, the concept drift leads to degradation in the model accuracy.
As in the classification example, in order to perform supervised drift detection we need to compute the model's performance indicators. For this regression example, the instance level squared errors are used.
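A sketch of the two ingredients follows: the instance level squared errors, and one common form of the two-sample Cramér-von Mises statistic, which sums squared gaps between the two empirical CDFs over the pooled sample. The p-value computation is omitted, the function names are assumptions, and the numbers are made up.

```python
from bisect import bisect_right


def squared_errors(y_true, y_pred):
    """Instance level squared error of a regressor."""
    return [(t - p) ** 2 for t, p in zip(y_true, y_pred)]


def cvm_statistic(x, y):
    """Two-sample Cramer-von Mises statistic: scaled sum of squared
    differences between the empirical CDFs at every pooled point."""
    xs, ys = sorted(x), sorted(y)
    n, m = len(xs), len(ys)
    total = 0.0
    for z in xs + ys:
        total += (bisect_right(xs, z) / n - bisect_right(ys, z) / m) ** 2
    return n * m / (n + m) ** 2 * total


ref_losses = squared_errors([4.0, 5.0, 6.0], [4.1, 4.8, 6.1])
drift_losses = squared_errors([4.0, 5.0, 6.0], [5.5, 3.0, 8.0])
print(cvm_statistic(ref_losses, ref_losses), cvm_statistic(ref_losses, drift_losses))
```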
The CVM detector is trained on the reference losses:
As desired, the CVM detector does not detect drift on the no drift data, but does on the concept drift data.
The Maximum Mean Discrepancy (MMD) detector is a kernel-based method for multivariate 2 sample testing. The MMD is a distance-based measure between 2 distributions p and q based on the mean embeddings $\mu_{p}$ and $\mu_{q}$ in a reproducing kernel Hilbert space $F$:
$$MMD(F, p, q) = \| \mu_{p} - \mu_{q} \|_{F}$$
We can compute unbiased estimates of $MMD^2$ from the samples of the 2 distributions after applying the kernel trick. We use by default a radial basis function kernel, but users are free to pass their own kernel of preference to the detector. We obtain a $p$-value via a permutation test on the values of $MMD^2$. This method is also described in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift.
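The permutation test itself is independent of the statistic being permuted. The sketch below shows the procedure with a deliberately simple stand-in statistic (the absolute difference in sample means) instead of $MMD^2$, so that it stays short; the function names and the default of 200 permutations are assumptions for illustration.

```python
import random


def mean_gap(a, b):
    """Stand-in statistic: absolute difference between the sample means."""
    return abs(sum(a) / len(a) - sum(b) / len(b))


def permutation_p_value(x, y, statistic, n_permutations=200, seed=0):
    """Fraction of random relabellings whose statistic is at least as
    large as the one observed on the original split."""
    rng = random.Random(seed)
    observed = statistic(x, y)
    pooled = list(x) + list(y)
    exceed = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)
        if statistic(pooled[:len(x)], pooled[len(x):]) >= observed:
            exceed += 1
    return exceed / n_permutations


same = permutation_p_value([1, 2, 3], [1, 2, 3], mean_gap)
shifted = permutation_p_value(list(range(10)), list(range(100, 110)), mean_gap)
print(same, shifted)
```

Under the null hypothesis the labels "reference" and "test" are exchangeable, so shuffling them gives a sample from the null distribution of the statistic without any distributional assumptions.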
The method is implemented in both the PyTorch and TensorFlow frameworks with support for CPU and GPU. Various preprocessing steps are also supported out-of-the-box in Alibi Detect for both frameworks and illustrated throughout the notebook. However, Alibi Detect does not install PyTorch for you. Check the PyTorch docs for how to do this.
CIFAR10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes. We evaluate the drift detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in the classification model performance. We also check for drift against the original test set with class imbalances.
Original CIFAR-10 data:
For CIFAR-10-C, we can select from the following corruption types at 5 severity levels:
Let's pick a subset of the corruptions at corruption level 5. Each corruption type consists of perturbations on all of the original test set images.
We split the original test set into a reference dataset and a dataset which should not be rejected under the H0 of the MMD test. We also split the corrupted data by corruption type:
We can visualise the same instance for each corruption type:
We can also verify that the performance of a classification model on CIFAR-10 drops significantly on this perturbed dataset:
Given the drop in performance, it is important that we detect the harmful data drift!
First we try a drift detector using the TensorFlow framework for both the preprocessing and the MMD computation steps.
We are trying to detect data drift on high-dimensional (32x32x3) data using a multivariate MMD permutation test. It therefore makes sense to apply dimensionality reduction first. Some dimensionality reduction methods also used in Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift are readily available: a randomly initialized encoder (UAE or Untrained AutoEncoder in the paper), BBSDs (black-box shift detection using the classifier's softmax outputs) and PCA (using scikit-learn).
Random encoder
First we try the randomly initialized encoder:
Let's check whether the detector thinks drift occurred on the different test sets and time the prediction calls:
As expected, drift was only detected on the corrupted datasets.
BBSDs
For BBSDs, we use the classifier's softmax outputs for black-box shift detection. This method is based on Detecting and Correcting for Label Shift with Black Box Predictors. The ResNet classifier is trained on data standardised by instance so we need to rescale the data.
Initialisation of the drift detector. Here we use the output of the softmax layer to detect the drift, but other hidden layers can be extracted as well by setting 'layer' to the index of the desired hidden layer in the model:
Again drift is only flagged on the perturbed data.
We can do the same thing using the PyTorch backend. We illustrate this using the randomly initialized encoder as preprocessing step:
Since our PyTorch encoder expects the images in a (batch size, channels, height, width) format, we transpose the data:
The drift detector will attempt to use the GPU if available and otherwise falls back on the CPU. We can also explicitly specify the device. Let's compare the GPU speed up with the CPU implementation:
Notice the over 30x acceleration provided by the GPU.
Similar to the TensorFlow implementation, PyTorch can also use the hidden layer output from a pretrained model for the preprocessing step via:
Under the hood drift detectors leverage a function of the data that is expected to be large when drift has occurred and small when it hasn't. In the Learned drift detectors on CIFAR-10 example notebook we note that we can learn a function satisfying this property by training a classifier to distinguish reference and test samples. However we now additionally note that if the classifier is specified in a certain way then when drift is detected we can inspect the weights of the classifier to shine light on exactly which features of the data were used to distinguish reference from test samples and therefore caused drift to be detected.
The SpotTheDiffDrift detector is designed to make this process straightforward. Like the ClassifierDrift detector, it uses a portion of the available data to train a classifier to discriminate between reference and test instances. Letting $\hat{p}_T(x)$ represent the probability assigned by the classifier that the instance $x$ is from the test set rather than reference set, the difference here is that we use a classifier of the form
$$\text{logit}(\hat{p}_T(x)) = b_0 + b_1 k(x, w_1) + \dots + b_J k(x, w_J)$$
where $k(\cdot,\cdot)$ is a kernel specifying a notion of similarity between instances, $w_i$ are learnable test locations and $b_i$ are learnable regression coefficients.
The idea here is that if the detector flags drift and $b_i > 0$, then we know that it reached its decision by considering how similar each instance is to the instance $w_i$, with more similar instances deemed more likely to be test instances than reference instances. Alternatively, if $b_i < 0$ then instances more similar to $w_i$ were deemed more likely to be reference instances.
In order to provide less noisy and therefore more interpretable results, we define each test location as
$$w_i = \bar{x} + d_i,$$
where $\bar{x}$ is the mean reference instance. We may then interpret $d_i$ as the additive transformation deemed to make the average reference instance more ($b_i>0$) or less ($b_i<0$) similar to a test instance. Defining the test locations in this way allows us to instead learn the difference $d_i$ and apply regularisation such that non-zero values must be justified by improved classification performance. This allows us to more clearly identify which features any detected drift should be attributed to.
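The classifier form and the test locations described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the kernel is a Gaussian RBF with a fixed bandwidth, the diffs and coefficients are hand-picked rather than learned, and the function names (`rbf_kernel`, `spot_the_diff_logits`) are hypothetical and not the library's API.

```python
import numpy as np


def rbf_kernel(x, w, sigma=1.0):
    """Gaussian RBF similarity between each row of x and a single location w."""
    sq_dist = np.sum((x - w) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def spot_the_diff_logits(x, x_ref_mean, diffs, coeffs, bias, sigma=1.0):
    """logit(p_T(x)) = b_0 + sum_i b_i * k(x, w_i), with w_i = x_ref_mean + d_i."""
    logits = np.full(x.shape[0], bias, dtype=float)
    for d_i, b_i in zip(diffs, coeffs):
        w_i = x_ref_mean + d_i  # test location = mean reference instance + diff
        logits += b_i * rbf_kernel(x, w_i, sigma)
    return logits


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


x_ref_mean = np.array([0.0, 0.0])
diffs = np.array([[1.0, 0.0]])   # a single diff acting on the first feature only
coeffs = np.array([2.0])         # b_1 > 0: similarity to w_1 suggests "test"
x = np.array([[1.0, 0.0],        # coincides with w_1
              [-3.0, 0.0]])      # far from w_1
p_test = sigmoid(spot_the_diff_logits(x, x_ref_mean, diffs, coeffs, bias=0.0))
print(p_test)
```

With a positive coefficient, the instance that coincides with the test location receives the higher test-set probability, which matches the interpretation of $b_i > 0$ given above.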
This approach to interpretable drift detection is inspired by the work of Jitkrittum et al. (2016); however, several major adaptations have been made.
The method works with both the PyTorch and TensorFlow frameworks. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for how to do this.
We start with an image example in order to provide a visual illustration of how the detector works. For this purpose we use the MNIST dataset of 28 by 28 grayscale handwritten digits. To represent the common problem of new classes emerging during the deployment phase, we consider a reference set of ~9,000 instances containing only the digits 1-9 and a test set of 10,000 instances containing all of the digits 0-9. We would like drift to be detected in this scenario because a model trained on the reference instances will not know how to process instances from the new class.
This notebook requires the torchvision package, which can be installed via pip:
When instantiating the detector we should specify the number of "diffs" we would like it to use to discriminate reference from test instances. Here there is a trade-off. Using n_diffs=1 is the simplest to interpret and seems to work well in practice. Using more diffs may result in stronger detection power, but the diffs may be harder to interpret due to interactions and conditional dependencies.
The strength of the regularisation (l1_reg) to apply to the diffs should also be specified. Stronger regularisation results in sparser diffs as the classifier is encouraged to discriminate using fewer features. This may make the diff more interpretable but may again come at the cost of detection power.
We should also specify how the classifier should be trained with standard arguments such as learning_rate, epochs and batch_size. By default a Gaussian RBF is used for the kernel but alternatives can be specified via the kernel kwarg. Additionally the classifier can be initialised with any desired diffs by passing them with the initial_diffs kwarg; by default they are initialised with Gaussian noise with standard deviation equal to that observed in the reference data.
When we then call the detector to detect drift on the deployment/test set, it trains the classifier (thereby learning the diffs), and the is_drift and p_val properties can be inspected in the usual way:
As expected, the drift was detected. However we may now additionally look at the learned diffs and corresponding coefficients to determine how the detector reached this decision.
The detector has identified the zero that was missing from the reference data. It realised that test instances were on average more (coefficient > 0) similar to an instance with below-average middle pixel values and above-average zero-region pixel values than reference instances were. It used this information to determine that drift had occurred.
To provide an example on tabular data we consider the Wine Quality Data Set consisting of 4898 and 1599 samples of white and red wine respectively. Each sample has an associated quality (as determined by experts) and 11 numeric features indicating its acidity, density, pH etc. To represent the problem of a model being trained on one distribution and deployed on a subtly different one, we take as a reference set the samples of white wine and consider the red wine samples to form a 'corrupted' deployment set.
We can see that the data for both red and white wine samples take the same format.
We extract the features and shuffle and normalise them such that they take values in [0,1].
We then split off half of the reference set to act as an unseen sample from the same underlying distribution for which drift should not be detected.
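The shuffling, normalisation and held-out split described above can be sketched as follows. This is a minimal example on synthetic arrays standing in for the wine features; the min-max scaling with statistics taken from the reference (white) samples is an assumption made for illustration, so only the reference data is guaranteed to land in [0,1]. The variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for the white (reference) and red (corrupted) feature
# matrices; the real data has 11 numeric features per sample.
white = rng.normal(loc=5.0, scale=2.0, size=(200, 11))
red = rng.normal(loc=6.0, scale=2.0, size=(80, 11))

# Shuffle the rows of each set independently.
white = white[rng.permutation(len(white))]
red = red[rng.permutation(len(red))]

# Min-max scale each feature using the reference statistics (an assumption).
col_min = white.min(axis=0)
col_range = white.max(axis=0) - col_min
white = (white - col_min) / col_range
red = (red - col_min) / col_range

# Hold out half of the reference data as an undrifted test set.
half = len(white) // 2
x_ref, x_h0 = white[:half], white[half:]
x_corr = red
print(x_ref.shape, x_h0.shape, x_corr.shape)
```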
We instantiate our detector in the same way as we do above, but this time using the PyTorch backend for the sake of variety. We then get the predictions of the detector on both the undrifted and corrupted test sets.
As expected, drift is detected on the red wine samples but not on the held-out white wine samples from the same distribution. Now we can inspect the returned diff to determine how the detector reached its decision.
We see that the detector was able to discriminate the corrupted (red) wine samples from the reference (white) samples by noting that on average reference samples (coeff < 0) typically contain more sulfur dioxide and residual sugars but have less sulphates and chlorides and have lower pH and volatile and fixed acidity.
We detect drift on text data using both the Maximum Mean Discrepancy (MMD) and Kolmogorov-Smirnov (K-S) detectors. In this example notebook we focus on detecting covariate shift $\Delta p(x)$, as detecting predicted label distribution drift does not differ from other modalities (see the K-S and MMD drift examples on CIFAR-10).
It becomes a little more involved, however, when we want to pick up input data drift $\Delta p(x)$. When we deal with tabular or image data, we can either directly apply the two-sample hypothesis test on the input or do the test after a preprocessing step with, for instance, a randomly initialized encoder as proposed in Failing Loudly (Rabanser et al., 2019), where it is called an Untrained AutoEncoder or UAE. This is not as straightforward when dealing with text, in either string or tokenized format, since neither directly represents the semantics of the input.
As a result, we extract (contextual) embeddings for the text and detect drift on those. This procedure has a significant impact on the type of drift we detect. Strictly speaking we are not detecting $\Delta p(x)$ anymore since the whole training procedure (objective function, training data etc) for the (pre)trained embeddings has an impact on the embeddings we extract.
The library contains functionality to leverage pre-trained embeddings from HuggingFace's transformers library but also allows you to easily use your own embeddings of choice. Both options are illustrated with examples in this notebook.
Note
As is done in this example, it is recommended to pass text data to detectors as a list of strings (List[str]). This allows for seamless integration with HuggingFace's transformers library.
One exception to the above is when custom embeddings are used. Here, it is important to ensure that the data is passed to the custom embedding model in a compatible format. In the custom embedding example, a preprocess_batch_fn is defined in order to convert lists to the np.ndarrays expected by the custom TensorFlow embedding.
The method works with both the PyTorch and TensorFlow frameworks for the statistical tests and preprocessing steps. Alibi Detect does not, however, install PyTorch for you. Check the PyTorch docs for how to do this.
The dataset is a binary sentiment classification task containing $25,000$ movie reviews for training and $25,000$ for testing. Install the nlp library to fetch the dataset:
Let's take a look at respectively a negative and positive review:
We split the original test set in a reference dataset and a dataset which should not be rejected under the H0 of the statistical test. We also create imbalanced datasets and inject selected words in the reference set.
Reference, H0 and imbalanced data:
Inject words in reference data:
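As an illustration of the two perturbations mentioned above (class imbalance and injected words), here is a minimal sketch that works directly on lists of strings. The helpers `imbalanced_sample` and `inject_word` are hypothetical and written only for this example; the exact sampling and injection scheme used in the notebook may differ.

```python
import random
from typing import List


def imbalanced_sample(texts: List[str], labels: List[int],
                      frac_positive: float, n: int, seed: int = 0) -> List[str]:
    """Draw n texts such that roughly frac_positive of them have label 1."""
    rng = random.Random(seed)
    pos = [t for t, y in zip(texts, labels) if y == 1]
    neg = [t for t, y in zip(texts, labels) if y == 0]
    n_pos = int(round(frac_positive * n))
    return rng.sample(pos, n_pos) + rng.sample(neg, n - n_pos)


def inject_word(texts: List[str], word: str, fraction: float,
                n_copies: int = 1, seed: int = 0) -> List[str]:
    """Prepend `word` (repeated n_copies times) to a random fraction of the texts."""
    rng = random.Random(seed)
    n_inject = int(round(fraction * len(texts)))
    chosen = set(rng.sample(range(len(texts)), n_inject))
    prefix = " ".join([word] * n_copies)
    return [f"{prefix} {t}" if i in chosen else t for i, t in enumerate(texts)]


# Toy data: ten short "reviews" with alternating labels.
texts = [f"review {i}" for i in range(10)]
labels = [i % 2 for i in range(10)]

mostly_positive = imbalanced_sample(texts, labels, frac_positive=0.75, n=4)
perturbed = inject_word(texts, "fantastic", fraction=0.5, n_copies=2)
print(mostly_positive)
print(perturbed)
```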
First we need to specify the type of embedding we want to extract from the BERT model. We can extract embeddings from the ...
pooler_output: Last layer hidden-state of the first token of the sequence (classification token; CLS) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pre-training. Note: this output is usually not a good summary of the semantic content of the input; you are often better off averaging or pooling the sequence of hidden-states for the whole input sequence.
last_hidden_state: Sequence of hidden states at the output of the last layer of the model, averaged over the tokens.
hidden_state: Hidden states of the model at the output of each layer, averaged over the tokens.
hidden_state_cls: See hidden_state but use the CLS token output.
If hidden_state or hidden_state_cls is used as embedding type, you also need to pass the layer numbers used to extract the embedding from. As an example we extract embeddings from the last 8 hidden states.
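To make the embedding types concrete, the sketch below applies them to random arrays standing in for a transformer's per-layer hidden states. It is an illustration only: the shapes, the `embed` helper and the choice to pool tokens jointly across the selected layers are assumptions, and the library's own implementation may combine layers differently.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, batch, n_tokens, dim = 12, 2, 5, 768

# hidden_states[l] has shape (batch, tokens, dim); token 0 plays the role of CLS.
hidden_states = [rng.normal(size=(batch, n_tokens, dim)) for _ in range(n_layers + 1)]


def embed(hidden_states, layers, use_cls=False):
    """Average token vectors (or only the CLS vectors) over the selected layers."""
    selected = [hidden_states[l][:, :1, :] if use_cls else hidden_states[l]
                for l in layers]
    stacked = np.concatenate(selected, axis=1)  # pool along the token axis
    return stacked.mean(axis=1)


layers = [-i for i in range(1, 9)]             # the last 8 hidden states
emb_hidden_state = embed(hidden_states, layers)
emb_hidden_state_cls = embed(hidden_states, layers, use_cls=True)
emb_last_hidden_state = hidden_states[-1].mean(axis=1)
print(emb_hidden_state.shape, emb_hidden_state_cls.shape)
```

Each variant yields one vector per instance, so the downstream drift test sees an array of shape (batch size, embedding dimension) regardless of the sequence length.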
Let's check what an embedding looks like:
So the BERT model's embedding space used by the drift detector consists of a $768$-dimensional vector for each instance. We will therefore first apply a dimensionality reduction step with an Untrained AutoEncoder (UAE) before conducting the statistical hypothesis test. We use the embedding model as the input for the UAE which then projects the embedding on a lower dimensional space.
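The dimensionality reduction step can be pictured as a randomly initialised, never-trained network applied to the embeddings. The following is a minimal NumPy sketch of that idea rather than the notebook's actual UAE: the two-layer architecture, the hidden size of 128 and the output size of 32 are arbitrary assumptions, and `make_untrained_encoder` is a hypothetical helper.

```python
import numpy as np


def make_untrained_encoder(in_dim, hidden_dim, enc_dim, seed=0):
    """Return a fixed, randomly initialised two-layer projection (never trained)."""
    rng = np.random.default_rng(seed)
    w1 = rng.normal(scale=1.0 / np.sqrt(in_dim), size=(in_dim, hidden_dim))
    w2 = rng.normal(scale=1.0 / np.sqrt(hidden_dim), size=(hidden_dim, enc_dim))

    def encode(x):
        return np.tanh(x @ w1) @ w2

    return encode


encode = make_untrained_encoder(in_dim=768, hidden_dim=128, enc_dim=32)

# Random stand-ins for ten 768-dimensional text embeddings.
embeddings = np.random.default_rng(1).normal(size=(10, 768))
reduced = encode(embeddings)
print(reduced.shape)
```

Because the weights are fixed after initialisation, the same projection is applied to reference and test data, which is what the hypothesis test requires.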
Let's test this again:
Let’s first check if drift occurs on a similar sample from the training set as the reference data.
Detect drift on imbalanced and perturbed datasets:
H0:
Imbalanced data:
Perturbed data:
We can run the same detector with PyTorch backend for both the preprocessing step and MMD implementation:
H0:
Imbalanced data:
Perturbed data:
So far we used pre-trained embeddings from a BERT model. We can however also use embeddings from a model trained from scratch. First we define and train a simple classification model consisting of an embedding and LSTM layer in TensorFlow.
Load and tokenize data:
Let's check out an instance:
Define and train a simple model:
Extract the embedding layer from the trained model and combine with UAE preprocessing step:
Again, create reference, H0 and perturbed datasets. Also test against the Reuters news topic classification dataset.
H0:
Perturbed data:
The detector is not as sensitive as the Transformer-based K-S drift detector. The embeddings trained from scratch were learned on a small dataset with a simple model and a cross-entropy loss function for only 2 epochs. The pre-trained BERT model, on the other hand, captures the semantics of the data better.
Sample from the Reuters dataset:
Any differentiable PyTorch or TensorFlow module that takes as input two instances and outputs a scalar (representing similarity) can be used as the kernel for this drift detector. However, in order to ensure that MMD=0 implies no drift, the kernel should satisfy a characteristic property. This can be guaranteed by defining a kernel as
$$k(x,y) = (1-\epsilon)\, k_a(\Phi(x), \Phi(y)) + \epsilon\, k_b(x,y),$$
where $\Phi$ is a learnable projection, $k_a$ and $k_b$ are simple characteristic kernels (such as a Gaussian RBF), and $\epsilon>0$ is a small constant. By letting $\Phi$ be very flexible we can learn powerful kernels in this manner.
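The kernel construction above, a mixture of a kernel on projected inputs and a kernel on the raw inputs, can be sketched in NumPy as follows. This is an illustration under stated assumptions: both $k_a$ and $k_b$ are Gaussian RBF kernels with unit bandwidth, the projection is a fixed random map rather than a learned one, and the function names are hypothetical. A biased MMD$^2$ estimate is included only to show the kernel in use.

```python
import numpy as np


def rbf(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix between the rows of x and the rows of y."""
    sq_dist = (np.sum(x ** 2, axis=1)[:, None]
               + np.sum(y ** 2, axis=1)[None, :]
               - 2.0 * x @ y.T)
    return np.exp(-np.maximum(sq_dist, 0.0) / (2.0 * sigma ** 2))


def deep_kernel(x, y, proj, eps=0.01):
    """k(x, y) = (1 - eps) * k_a(proj(x), proj(y)) + eps * k_b(x, y)."""
    return (1.0 - eps) * rbf(proj(x), proj(y)) + eps * rbf(x, y)


def mmd2_biased(x, y, kernel):
    """Biased estimate of the squared MMD between two samples."""
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


rng = np.random.default_rng(0)
weights = rng.normal(size=(5, 3))           # fixed random stand-in for a learned projection


def proj(z):
    return np.tanh(z @ weights)


def kernel(a, b):
    return deep_kernel(a, b, proj)


x_ref = rng.normal(size=(50, 5))
x_shift = rng.normal(loc=3.0, size=(50, 5))
mmd_same = mmd2_biased(x_ref, x_ref, kernel)
mmd_shift = mmd2_biased(x_ref, x_shift, kernel)
print(mmd_same, mmd_shift)
```

The $\epsilon\, k_b(x,y)$ term acts directly on the inputs, so the combined kernel can still separate two distinct distributions even if the projection happens to map them to the same place.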
Here we address the same problem but using the least squares density difference (LSDD) as the two-sample distance in a manner similar to Bu et al. (2017). The LSDD between two distributions $p$ and $q$ on $\mathcal{X}$ is defined as
$$LSDD(p,q) = \int_{\mathcal{X}} (p(x)-q(x))^2 \,dx$$
and also has an empirical estimate $\widehat{LSDD}(\{X_i\}_{i=1}^N, \{Y_i\}_{i=t}^{t+W})$ that can be updated at low cost as the test window is updated to $\{Y_i\}_{i=t+1}^{t+1+W}$.
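For intuition, a batch estimate of the LSDD can be sketched by modelling the density difference as a weighted sum of Gaussian basis functions and solving a regularised least-squares problem. This is a minimal sketch of that standard estimator and not the library's implementation: the bandwidth, the regularisation strength, the choice of kernel centers and the name `lsdd_estimate` are assumptions, and the low-cost window update mentioned above is not shown.

```python
import numpy as np


def lsdd_estimate(x, y, centers, sigma=1.0, lam=1e-3):
    """Estimate the LSDD between samples x and y with Gaussian basis functions."""
    d = x.shape[1]

    def phi(z):
        # Basis function values, shape (n_samples, n_centers).
        sq = np.sum((z[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        return np.exp(-sq / (2.0 * sigma ** 2))

    # h: difference between the mean basis activations of the two samples.
    h = phi(x).mean(axis=0) - phi(y).mean(axis=0)
    # H: closed-form integrals of products of Gaussian basis functions.
    c_sq = np.sum((centers[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    H = (np.pi * sigma ** 2) ** (d / 2.0) * np.exp(-c_sq / (4.0 * sigma ** 2))
    # Regularised least-squares fit of the density-difference coefficients.
    theta = np.linalg.solve(H + lam * np.eye(len(centers)), h)
    return 2.0 * theta @ h - theta @ H @ theta


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(200, 2))
x_shift = rng.normal(loc=2.0, size=(200, 2))
centers = x_ref[:20]                 # use a few reference points as kernel centers

lsdd_same = lsdd_estimate(x_ref, x_ref, centers)
lsdd_shift = lsdd_estimate(x_ref, x_shift, centers)
print(lsdd_same, lsdd_shift)
```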
We proceed to initialize the drift detector. From here on the detector works the same as for other modalities such as images. Please check the corresponding example or the documentation for more information about each of the possible parameters.
Again, check the corresponding example or the documentation for more information about each of the possible parameters.