Context-aware drift detection on news articles

Introduction

In this notebook we show how to detect drift on text data given a specific context using the context-aware MMD detector (Cobb and Van Looveren, 2022). Consider the following simple example: upcoming elections result in an increase in political news articles compared to other topics such as sports or science. Given the context (the elections), it is not surprising that we observe this uptick. Moreover, assume we have a machine learning model trained to classify news topics, and that this model performs well on political articles. Since we fully expect the uptick given the context, and our model performs fine on political news articles, we do not want to flag this type of drift in the data. This scenario corresponds more closely to many real-life settings than traditional drift detection, which assumes that both the reference and test data are i.i.d. samples from their underlying distributions.

In our news topics example, each topic such as politics, sports or weather represents a subpopulation of the data. Our context-aware drift detector can then detect changes in the data distribution which cannot be attributed to a permissible change in the relative prevalences of these subpopulations. As a cherry on the cake, the context-aware detector allows you to see which subpopulations are present in both the reference and test data, providing deeper insight into the distribution underlying the test data.

Useful context (or conditioning) variables for the context-aware drift detector include but are not limited to:

  1. Domain or application specific contexts such as the time of day or the weather.

  2. Conditioning on the relative prevalences of known subpopulations, such as the frequency of political articles. It is important to note that while the relative frequency of each subpopulation might change, the distribution underlying each subpopulation must remain unchanged.

  3. Conditioning on model predictions. Assume we trained a classifier which tries to figure out which news topic an article belongs to. Given our model predictions we then want to understand whether our test data follows the same underlying distribution as reference instances with similar model predictions. This conditioning would also be useful in case of trending news topics which cause the model prediction distribution to shift but not necessarily the distribution within each of the news topics.

  4. Conditioning on model uncertainties, which allows increases in model uncertainty due to drift into familiar regions of high aleatoric uncertainty (often fine) to be distinguished from drift into unfamiliar regions of high epistemic uncertainty (often problematic).

The following settings will be illustrated throughout the notebook:

  1. A change in the prevalences of subpopulations (i.e. news topics) relative to their prevalences in the training data. Contrary to traditional drift detection approaches, the context-aware detector does not flag drift as this change in frequency of news topics is permissible given the context provided (e.g. more political news articles around elections).

  2. A change in the underlying distribution of one or more subpopulations takes place. While we allow changes in the prevalence of the subpopulations accounted for by the context variable, we do not allow changes of the subpopulations themselves. Let's assume that a newspaper usually has a certain tone (e.g. more conservative) when it comes to politics. If this tone changes (to less conservative) around elections (increased frequency of political news articles), then we want to flag it as drift since the change cannot be attributed to the context given to the detector.

  3. A change in the distribution as we observe a previously unseen news topic. A newspaper might for instance add a classified ads section, which was not present in the reference data.

Under setting 1. we want our detector to be well-calibrated (a controlled False Positive Rate (FPR) and, more generally, a p-value which is uniformly distributed between 0 and 1), while under settings 2. and 3. we want our detector to be powerful and flag the drift. Lastly, we show how the detector can help you better understand the connection between the reference and test data distributions.

Data

We use the 20 newsgroups dataset, which contains about 18,000 newsgroups posts across 20 topics, including politics, science, sports and religion.

Requirements

The notebook requires the umap-learn, torch, sentence-transformers, statsmodels, seaborn and datasets packages to be installed, which can be done via pip:
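For example, from within the notebook:

```python
!pip install umap-learn torch sentence-transformers statsmodels seaborn datasets
```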

Before we start let's fix the random seeds for reproducibility:
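A minimal sketch of what this could look like; the seed value itself is arbitrary:

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Fix the random seeds of all relevant libraries for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(2022)
```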

Load data

First we load the data, show which classes (news topics) are present and what an instance looks like.
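A sketch of this step, assuming scikit-learn's fetch_20newsgroups loader (the original notebook may load the data differently):

```python
from sklearn.datasets import fetch_20newsgroups

# Load all posts and strip metadata so we focus on the article text itself.
data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
x, y = data.data, data.target

print(f'{len(x)} instances across {len(data.target_names)} topics:')
print(data.target_names)
```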

Let's take a look at an instance from the dataset:
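For instance:

```python
i = 0  # arbitrary instance
print(f'Topic: {data.target_names[y[i]]}\n')
print(x[i])
```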

Define models and train a classifier

We embed the news posts using SentenceTransformers pre-trained embeddings and optionally add a dimensionality reduction step with UMAP. UMAP also allows us to leverage the reference data labels.

We define, respectively, a generic clustering model using UMAP, a model to embed the text input using pre-trained SentenceTransformers embeddings, a text classifier, and a utility function to place the data on the right device.
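A minimal sketch of these components; the 'all-MiniLM-L6-v2' backbone (384-dimensional embeddings) and the layer sizes are assumptions:

```python
import numpy as np
import torch
import torch.nn as nn
import umap
from sentence_transformers import SentenceTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(*tensors):
    """Utility function to place the data on the right device."""
    out = tuple(t.to(device) for t in tensors)
    return out if len(out) > 1 else out[0]

class Embedding:
    """Embed raw text with a frozen pre-trained SentenceTransformer."""
    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        self.model = SentenceTransformer(model_name, device=str(device))

    def __call__(self, x: list) -> np.ndarray:
        return self.model.encode(x, convert_to_numpy=True, show_progress_bar=False)

class Classifier(nn.Module):
    """Simple MLP head mapping sentence embeddings to news topic logits."""
    def __init__(self, n_in: int = 384, n_hidden: int = 128, n_classes: int = 20):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)

def fit_umap(x: np.ndarray, y: np.ndarray = None, n_components: int = 2):
    """Generic clustering model; passing y leverages the reference labels."""
    return umap.UMAP(n_components=n_components).fit(x, y=y)
```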

First we train a classifier on a small subset of the data. The aim of the classifier is to predict the news topic of each instance. Below we define a few simple training and evaluation functions.
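A sketch of such functions, reusing the hypothetical components above (the signatures and hyperparameters are assumptions):

```python
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

def loader_from(x_emb, y, batch_size=32, shuffle=False):
    """Wrap pre-computed embeddings and labels in a DataLoader."""
    ds = TensorDataset(torch.tensor(x_emb, dtype=torch.float32),
                       torch.tensor(y, dtype=torch.long))
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

def train_model(model, x_emb, y, epochs=5, lr=1e-3):
    """Finetune the MLP head on pre-computed (frozen) sentence embeddings."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for x_batch, y_batch in loader_from(x_emb, y, shuffle=True):
            x_batch, y_batch = to_device(x_batch, y_batch)
            loss = F.cross_entropy(model(x_batch), y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def eval_model(model, x_emb, y) -> float:
    """Return the classification accuracy of the model."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x_batch, y_batch in loader_from(x_emb, y):
            x_batch, y_batch = to_device(x_batch, y_batch)
            correct += (model(x_batch).argmax(-1) == y_batch).sum().item()
            total += len(y_batch)
    return correct / total
```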

We now split the data into two sets. The first set (x_train) is used to train our text classifier, while the second set (x_drift) is held out to test our drift detector on.
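For instance (the split size is an assumption):

```python
from sklearn.model_selection import train_test_split

# Stratify so every news topic is represented in the training set.
x_train, x_drift, y_train, y_drift = train_test_split(
    x, y, train_size=2000, stratify=y, random_state=0
)
```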

Let's train our classifier. The classifier consists of a simple MLP head on top of a pre-trained SentenceTransformer model as the backbone. The SentenceTransformer remains frozen during training and only the MLP head is finetuned.
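Putting this together, a sketch of the training run (reusing the hypothetical helpers defined above):

```python
embedding = Embedding()
clf = Classifier().to(device)

# The frozen backbone lets us embed the training texts once, up front.
x_train_emb = embedding(x_train)
train_model(clf, x_train_emb, y_train)
print(f'Train accuracy: {eval_model(clf, x_train_emb, y_train):.3f}')
```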

Detector calibration under no change

We start with an example where no drift occurs and the reference and test data are both sampled randomly from all news topics. Under this scenario, we expect no drift to be detected by either a normal MMD detector or by the context-aware MMD detector.

First we define some helper functions. The first one visualises the clustered text data while the second function samples disjoint reference and test sets with a specified number of instances per class (i.e. per news topic).
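A sketch of these two helpers; the exact signatures are assumptions:

```python
import matplotlib.pyplot as plt

def plot_clusters(x_2d: np.ndarray, y: np.ndarray, title: str):
    """Visualise the 2D UMAP projection of the text data, coloured by topic."""
    plt.scatter(x_2d[:, 0], x_2d[:, 1], c=y, cmap='tab20', s=3)
    plt.title(title)
    plt.show()

def split_data(y: np.ndarray, n_ref: dict, n_test: dict, seed: int = 0):
    """Sample disjoint reference and test indices with a specified number
    of instances per class (i.e. per news topic)."""
    rng = np.random.default_rng(seed)
    idx_ref, idx_test = [], []
    for c in np.unique(y):
        idx_c = rng.permutation(np.where(y == c)[0])
        n_r, n_t = n_ref.get(c, 0), n_test.get(c, 0)
        idx_ref += list(idx_c[:n_r])            # disjoint slices guarantee no
        idx_test += list(idx_c[n_r:n_r + n_t])  # overlap between ref and test
    return np.array(idx_ref), np.array(idx_test)
```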

We first define the embedding model using the pre-trained SentenceTransformer embeddings and then embed both the reference and test sets.

By applying UMAP clustering on the SentenceTransformer embeddings, we can visually inspect the various news topic clusters. Note that we fit the clustering model on the held out data first, and then make predictions on the reference and test sets.

We can visually see that the reference and test set are made up of similar clusters of data, grouped by news topic. As a result, we would not expect drift to be flagged. If the data distribution has not changed, we expect the p-values of our statistical test to be uniformly distributed between 0 and 1. So let's see if this assumption holds.

First we need to define our context variable for the context-aware MMD detector. In our experiments we allow the relative prevalences of subpopulations to vary while the distributions underlying each of the subpopulations remain unchanged. To achieve this we condition on the prediction probabilities of the classifier we trained earlier to distinguish the 20 different news topics. We can do this because the prediction probabilities account for the frequency of occurrence of each of the topics (albeit imperfectly, given that our classifier makes the occasional mistake), as sketched below.
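A sketch of such a context function, reusing the hypothetical classifier from above:

```python
def context(x_emb: np.ndarray) -> np.ndarray:
    """Context: the classifier's prediction probabilities over the 20 topics."""
    clf.eval()
    with torch.no_grad():
        logits = clf(to_device(torch.tensor(x_emb, dtype=torch.float32)))
    return torch.softmax(logits, -1).cpu().numpy()
```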

Before we run our experiments, we embed all the instances in x_drift and compute all contexts c_drift so we don't have to call our transformer model on every pass of the for loop.
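A sketch of this pre-computation and of the experiment loop, using alibi-detect's MMDDrift and ContextMMDDrift detectors; the number of runs and per-topic sample sizes are assumptions:

```python
from alibi_detect.cd import ContextMMDDrift, MMDDrift

# Embed all held-out instances and compute their contexts once, up front.
x_drift_emb = embedding(x_drift)
c_drift = context(x_drift_emb)

n_runs = 100                          # hypothetical number of repeated experiments
n_ref = {c: 20 for c in range(20)}    # instances per topic in the reference set
n_test = {c: 20 for c in range(20)}   # instances per topic in the test set

p_vals_mmd, p_vals_cad = [], []
for run in range(n_runs):
    idx_ref, idx_test = split_data(y_drift, n_ref, n_test, seed=run)
    x_ref, x_test = x_drift_emb[idx_ref], x_drift_emb[idx_test]
    c_ref, c_test = c_drift[idx_ref], c_drift[idx_test]
    dd_mmd = MMDDrift(x_ref, backend='pytorch')
    dd_cad = ContextMMDDrift(x_ref, c_ref, backend='pytorch')
    p_vals_mmd.append(dd_mmd.predict(x_test)['data']['p_val'])
    p_vals_cad.append(dd_cad.predict(x_test, c_test)['data']['p_val'])
```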

The figure below shows Q-Q (quantile-quantile) plots of the obtained p-values from the vanilla and context-aware MMD detectors against random samples from the uniform distribution U[0,1], illustrating how well both detectors are calibrated. A perfectly calibrated detector should have a Q-Q plot which closely follows the diagonal. Only the middle plot in the grid shows the detector's p-values; the other plots correspond to n_runs p-values actually sampled from U[0,1] to contextualise how well the central plot follows the diagonal given the limited number of samples.

As expected we can see that both the normal MMD and the context-aware MMD detectors are well-calibrated.

Changing the relative subpopulation prevalence

We now focus our attention on a more realistic problem where the relative frequency of one or more subpopulations (i.e. news topics) is changing in a way which can be attributed to external events. Importantly, the distribution underlying each subpopulation (e.g. the distribution of hockey news itself) remains unchanged, only its frequency changes.

In our example we assume that the World Series and Stanley Cup coincide on the calendar, leading to a spike in news articles on baseball and hockey respectively. Furthermore, there is not much news on Mac or Windows since no new releases or products are planned anytime soon.
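As a concrete example, the test set prevalences could look as follows (a sketch; the class indices assume scikit-learn's alphabetical ordering of the topic names):

```python
n_test_spike = {c: 20 for c in range(20)}
n_test_spike.update({
    9: 60,   # rec.sport.baseball: World Series spike
    10: 60,  # rec.sport.hockey: Stanley Cup spike
    2: 5,    # comp.os.ms-windows.misc: little Windows news
    4: 5,    # comp.sys.mac.hardware: little Mac news
})
```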

While the context-aware detector remains well calibrated, the MMD detector consistently flags drift (low p-values). Note that this is the expected behaviour since the vanilla MMD detector cannot take any external context into account and correctly detects that the reference and test data do not follow the same underlying distribution.

We can also easily see this on the plot below where the p-values of the context-aware detector are uniformly distributed while the MMD detector's p-values are consistently close to 0. Note that we limited the y-axis range to make the plot easier to read.

Changing the subpopulation distribution

In the following example we change the distribution of one or more of the underlying subpopulations. Notice that now we do want to flag drift since our context variable, which permits changes in relative subpopulation prevalences, can no longer explain the change in distribution.

Imagine our news topic classification model is not as granular as before and instead of the 20 categories only predicts the 6 super classes, organised by subject matter:

  1. Computers: comp.graphics; comp.os.ms-windows.misc; comp.sys.ibm.pc.hardware; comp.sys.mac.hardware; comp.windows.x

  2. Recreation: rec.autos; rec.motorcycles; rec.sport.baseball; rec.sport.hockey

  3. Science: sci.crypt; sci.electronics; sci.med; sci.space

  4. Miscellaneous: misc.forsale

  5. Politics: talk.politics.misc; talk.politics.guns; talk.politics.mideast

  6. Religion: talk.religion.misc; alt.atheism; soc.religion.christian

What if baseball and hockey become less popular and the distribution underlying the Recreation class changes? We want to detect this since the change in the distributions of the subpopulations (the 6 super classes) can no longer be explained by the context variable.

In order to reuse our pretrained classifier for the super classes, we add the following helper function to map the predictions onto the super classes and return one-hot encoded predictions over the 6 super classes. Note that our context variable now changes from a probability distribution over the 20 news topics to a one-hot encoded representation over the 6 super classes.
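A sketch of this helper; the index grouping assumes scikit-learn's alphabetical ordering of the 20 topic names:

```python
# Group the 20 fine-grained topic indices into the 6 super classes.
super_classes = [
    [1, 2, 3, 4, 5],      # computers
    [7, 8, 9, 10],        # recreation
    [11, 12, 13, 14],     # science
    [6],                  # miscellaneous
    [16, 17, 18],         # politics
    [0, 15, 19],          # religion
]
topic_to_super = np.zeros(20, dtype=int)
for i, idxs in enumerate(super_classes):
    topic_to_super[idxs] = i

def super_context(x_emb: np.ndarray) -> np.ndarray:
    """One-hot encoding over the 6 super classes of the predicted topic."""
    preds = context(x_emb).argmax(-1)        # predicted fine-grained topic
    return np.eye(6)[topic_to_super[preds]]  # one-hot over the super classes
```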

We can see that the context-aware detector has the power to detect changes in the distributions of the subpopulations.

Detect unseen topics

Next we illustrate the effectiveness of the context-aware detector to detect new topics which are not present in the reference data. Obviously we also want to flag drift in this case. As an example we introduce movie reviews in the test data.
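A sketch of the injection step, assuming the rotten_tomatoes corpus from the datasets package (any corpus of movie reviews would do; the number of injected reviews is arbitrary):

```python
from datasets import load_dataset

reviews = load_dataset('rotten_tomatoes', split='train')['text']
x_reviews_emb = embedding(reviews[:500])

# Append the unseen topic to the test set; contexts for the injected reviews
# still come from the news topic classifier.
x_test_unseen = np.concatenate([x_test, x_reviews_emb])
c_test_unseen = np.concatenate([c_test, context(x_reviews_emb)])

dd = ContextMMDDrift(x_ref, c_ref, backend='pytorch')
preds = dd.predict(x_test_unseen, c_test_unseen)
print(f"p-value: {preds['data']['p_val']}")
```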

Changing the context variable

So far we have conditioned the context-aware detector on the model predictions. There are however many other useful contexts possible. One such example would be to condition on the predictions of an unsupervised clustering algorithm. To facilitate this, we first apply kernel PCA on the embedding vectors, followed by a Gaussian mixture model which clusters the data into 6 classes (same as the super classes). We will test both the calibration under the null hypothesis (no distribution change) as well as the power when a new topic (movie reviews) is injected.
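A sketch of this unsupervised context, using scikit-learn's KernelPCA and GaussianMixture (the number of principal components is an assumption):

```python
from sklearn.decomposition import KernelPCA
from sklearn.mixture import GaussianMixture

# Kernel PCA on the embeddings, followed by a 6-component Gaussian mixture.
kpca = KernelPCA(n_components=10, kernel='rbf').fit(x_drift_emb)
gmm = GaussianMixture(n_components=6).fit(kpca.transform(x_drift_emb))

def cluster_context(x_emb: np.ndarray) -> np.ndarray:
    """Soft cluster assignments over the 6 components serve as context."""
    return gmm.predict_proba(kpca.transform(x_emb))
```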

Next we change the number of instances in each cluster between the reference and test sets. Note that we do not alter the underlying distribution of each of the clusters, just the frequency.

Now we run the experiment and show the context-aware detector's calibration when changing the cluster frequencies. We also show how the usual MMD detector will consistently flag drift. Furthermore, we inject instances from the movie reviews dataset and illustrate that the context-aware detector remains powerful when the underlying cluster distribution changes (by including a previously unseen topic).

Interpretability of the context-aware detector

The test statistic $\hat{t}$ of the context-aware MMD detector can be formulated as follows: $\hat{t} = \langle K_{0,0}, W_{0,0} \rangle + \langle K_{1,1}, W_{1,1} \rangle - 2\langle K_{0,1}, W_{0,1} \rangle$ where $0$ refers to the reference data, $1$ to the test data, and $W_{.,.}$ and $K_{.,.}$ are the weight and kernel matrices, respectively. The weight matrices $W_{.,.}$ allow us to focus on the subpopulations of interest in the distribution. Reference instances which have similar contexts to the test data will have higher values for their entries in $W_{0,1}$ than instances with dissimilar contexts. We can therefore interpret $W_{0,1}$ as the coupling matrix between instances in the reference and test sets. It allows us to investigate which subpopulations from the reference set are present and which are missing in the test data. If we also have a good understanding of the model performance on the various subpopulations of the reference data, we could even use this coupling matrix as a rough proxy for model performance on the unlabeled test instances. Note that in this case we would require labels for the reference data and would need to make sure the reference instances come from the validation set, not the training set.

In the following example we only pick 2 classes to be present in the test set while all 20 are present in the reference set. Via the coupling matrix $W_{0,1}$ we can then investigate whether the test statistic $\hat{t}$ focused on the right classes in the reference data. More concretely, we can sum over the columns (the test instances) of $W_{0,1}$ and check which reference instances obtained the highest weights.
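A sketch of this inspection, assuming the detector exposes the coupling matrix via return_coupling=True and reusing the hypothetical variables from the earlier sketches:

```python
# x_ref/c_ref and x_test/c_test sampled as before, with only 2 topics
# present in the test set.
dd = ContextMMDDrift(x_ref, c_ref, backend='pytorch')
preds = dd.predict(x_test, c_test, return_coupling=True)
W_01 = preds['data']['coupling_xy']   # coupling between reference and test

w_ref = W_01.sum(axis=1)              # total weight per reference instance
top = np.argsort(w_ref)[::-1][:10]    # most heavily weighted reference instances
print([data.target_names[t] for t in y_drift[idx_ref][top]])
```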
