Context-aware drift detection on ECGs
Introduction
In this notebook we show how to detect drift on ECG data given a specific context using the context-aware MMD detector (Cobb and Van Looveren, 2022). Consider the following simple example: we have a heartbeat monitoring system that is trained on a wide variety of heartbeats sampled from people of all ages across a variety of activities (e.g. resting or running). We then deploy the system to monitor individual people during certain activities. The distribution of the heartbeats monitored during deployment will then drift away from the reference data, which resembles the full training distribution, simply because only individual people in a specific setting are being tracked. However, this does not mean that the system is broken and requires retraining. We are instead interested in flagging drift given the relevant context, such as the person's characteristics (e.g. age or medical history) and the activity. Traditional drift detectors cannot flexibly deal with this setting, since they rely on the i.i.d. assumption when sampling the reference and test sets. The context-aware detector, however, allows us to pass this context to the detector and flag drift appropriately. More generally, the context-aware drift detector detects changes in the data distribution that cannot be attributed to a permissible change in the context variable. On top of that, the detector allows you to understand which subpopulations are present in both the reference and test data, which provides deeper insight into the distribution underlying the test data.
Useful context (or conditioning) variables for the context-aware drift detector include, but are not limited to:
Domain or application specific contexts such as the time of day or the activity (e.g. running or resting).
Conditioning on the relative prevalences of known subpopulations, such as the frequency of different types of heartbeats. It is important to note that while the relative frequency of each subpopulation (e.g. the different heartbeat types) may change, the distribution underlying each individual subpopulation (e.g. each specific type of heartbeat) is not allowed to change; such a change is flagged as drift.
Conditioning on model predictions. If we have trained a classifier that detects arrhythmia, we can provide the classifier's predictions as context and test whether, given the model prediction, the data comes from the same underlying distribution as the reference data.
Conditioning on model uncertainties. This allows increases in model uncertainty caused by drift into familiar regions of high aleatoric uncertainty (often fine) to be distinguished from increases caused by drift into unfamiliar regions of high epistemic uncertainty (often problematic).
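As a minimal sketch of an uncertainty-based context variable, assuming a classifier that outputs class probabilities, the hypothetical helper below computes the predictive entropy of each prediction. Note that entropy measures total uncertainty only; separating aleatoric from epistemic uncertainty would require more than a single softmax output (e.g. an ensemble), so this is a simplified stand-in.

```python
import numpy as np


def predictive_entropy(probs, eps=1e-12):
    """Hypothetical helper: entropy of each row of class probabilities.

    Higher values indicate a less certain prediction. Probabilities are clipped
    to avoid log(0).
    """
    probs = np.clip(np.asarray(probs, dtype=float), eps, 1.0)
    return -(probs * np.log(probs)).sum(axis=1)


probs = np.array([
    [0.98, 0.01, 0.01],   # confident prediction
    [1 / 3, 1 / 3, 1 / 3],  # maximally uncertain prediction over 3 classes
])
ent = predictive_entropy(probs)
print(ent)  # the second entry equals log(3)
```

The resulting entropies (shape `(n_instances,)`) could be reshaped to a column and used as, or appended to, the context array.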
The following settings will be showcased throughout the notebook:
The prevalences of subpopulations that are also present in the reference data change (i.e. different types of heartbeats, as determined by an unsupervised clustering model or an ECG classifier). Unlike traditional drift detection approaches, the context-aware detector does not flag drift, as this change in the frequency of the various heartbeats is permissible given the provided context.
The distribution underlying one or more subpopulations changes. While we allow the prevalences of the subpopulations accounted for by the context variable to change, we do not allow changes to the subpopulations themselves. If, for instance, the ECGs are corrupted by noise in the sensor measurements, we want to flag drift.
We also show how to condition the detector on different context variables, such as the ECG classifier's predictions, cluster memberships from an unsupervised clustering algorithm, and timestamps.
In the first setting we want our detector to be well-calibrated (a controlled false positive rate (FPR) and, more generally, p-values that are uniformly distributed between 0 and 1), while in the second setting we want our detector to be powerful and flag drift. Lastly, we show how the detector can help you better understand the connection between the reference and test data distributions.
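Both requirements can be quantified from the p-values collected over repeated runs. A minimal sketch, assuming a significance level of 0.05 (an illustrative choice): the rejection rate estimates the FPR under no impermissible drift and the power under drift, while the Kolmogorov-Smirnov distance to U[0,1] measures how far the p-values are from uniform (scipy.stats.kstest(p_vals, 'uniform') computes the same statistic).

```python
import numpy as np


def rejection_rate(p_vals, alpha=0.05):
    """Fraction of runs whose p-value falls below alpha.

    Without impermissible drift this estimates the FPR; under drift, the power.
    """
    return float(np.mean(np.asarray(p_vals) < alpha))


def ks_uniform_statistic(p_vals):
    """Kolmogorov-Smirnov distance between the empirical CDF of the p-values and U[0, 1]."""
    p = np.sort(np.asarray(p_vals, dtype=float))
    n = len(p)
    ecdf_hi = np.arange(1, n + 1) / n  # empirical CDF just after each sorted p-value
    ecdf_lo = np.arange(0, n) / n      # empirical CDF just before each sorted p-value
    return float(max(np.max(ecdf_hi - p), np.max(p - ecdf_lo)))


rng = np.random.default_rng(0)
calibrated = rng.uniform(size=1000)  # stand-in for the p-values of a calibrated detector
print(rejection_rate(calibrated))       # close to 0.05
print(ks_uniform_statistic(calibrated))  # small for uniform p-values
```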
Data
The dataset contains 5000 ECGs, originally obtained from PhysioNet's BIDMC Congestive Heart Failure Database, record chf07. The data has been pre-processed in two steps: first each heartbeat is extracted, and then each beat is interpolated to the same length. The data is labeled and contains 5 classes. The first class, $N$, which contains almost 60% of the observations, is considered normal, while the others are supraventricular ectopic beats ($S$), ventricular ectopic beats ($V$), fusion beats ($F$) and unknown beats ($Q$).
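The dataset is already pre-processed, so this step is not needed to follow along. As a minimal sketch of the second pre-processing step, assuming plain linear interpolation onto 140 points (the length of each observation in this dataset; the interpolation scheme actually used is not stated here), a beat of arbitrary length could be resampled with NumPy's np.interp:

```python
import numpy as np


def resample_beat(beat, length=140):
    """Linearly interpolate a 1D heartbeat onto `length` equally spaced points."""
    beat = np.asarray(beat, dtype=float)
    old_grid = np.linspace(0.0, 1.0, num=len(beat))  # original sample positions
    new_grid = np.linspace(0.0, 1.0, num=length)     # target sample positions
    return np.interp(new_grid, old_grid, beat)


short_beat = np.array([0.0, 1.0, 0.0])
print(resample_beat(short_beat, length=5))  # [0.  0.5 1.  0.5 0. ]
```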
Requirements
The notebook requires the torch and statsmodels packages to be installed, which can be done via pip:
Before we start let's fix the random seeds for reproducibility:
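A minimal sketch, assuming Python's random module, NumPy's global generator and (if installed) PyTorch are the sources of randomness; set_seed is a hypothetical helper and the seed value 2022 is arbitrary:

```python
import random

import numpy as np


def set_seed(seed: int) -> None:
    """Seed Python's and NumPy's global RNGs, and PyTorch's if it is installed."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available: nothing more to seed
    torch.manual_seed(seed)


set_seed(2022)
```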
Load data
First we load the data, show the distribution across the ECG classes and visualise some ECGs from each class.
We can see that most heartbeats are classified as normal, followed by the unknown class. We will now sample 500 heartbeats to train a simple ECG classifier. Importantly, we leave out the $F$ and $V$ classes, which are later used to detect drift. First we define a helper function to sample data.
We use a fraction prop_train of all samples to train the classifier, after removing the instances of the $F$ and $V$ classes from this training split. The rest of the data is used by our drift detectors.
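A minimal sketch of such a split, using a hypothetical helper split_train_drift (not the notebook's own helper). It assumes integer-encoded labels and that the excluded $F$ and $V$ classes are encoded as 2 and 4, which is consistent with classes 0, 1 and 3 remaining in the reference data. In this sketch the $F$ and $V$ instances drawn for training are simply discarded rather than returned to the drift pool.

```python
import numpy as np


def split_train_drift(x, y, prop_train, exclude_classes, seed=0):
    """Hypothetical helper: split (x, y) into a classifier training set and a drift pool.

    A random fraction `prop_train` of all instances is drawn for training, then
    instances of `exclude_classes` are dropped from that split. All instances not
    drawn for training form the pool used by the drift detectors.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    n_train = int(prop_train * len(y))
    idx_train, idx_drift = idx[:n_train], idx[n_train:]
    idx_train = idx_train[~np.isin(y[idx_train], exclude_classes)]
    return (x[idx_train], y[idx_train]), (x[idx_drift], y[idx_drift])


# Toy data standing in for the ECGs: 100 instances of length 140, labels 0-4.
x = np.zeros((100, 140))
y = np.arange(100) % 5
(x_tr, y_tr), (x_dr, y_dr) = split_train_drift(x, y, prop_train=0.1, exclude_classes=(2, 4))
```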
Train an ECG classifier
Now we define and train our classifier on the training set.
Let's evaluate our classifier on both the training and drift portions of the dataset.
Detector calibration under no change
We start with an example where no drift occurs and the reference and test data are both sampled randomly from all classes present in the reference data (classes 0, 1 and 3). Under this scenario, we expect no drift to be detected by either the normal MMD detector or the context-aware MMD detector.
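For intuition about the normal MMD detector, here is a minimal, library-free sketch of a standard MMD two-sample test. It is not the detector implementation used in the notebook: the Gaussian RBF kernel with a fixed bandwidth sigma, the unbiased MMD² estimate and the permutation test are illustrative choices. Note how the permutation step treats reference and test instances as exchangeable, which is exactly the i.i.d. assumption the context-aware detector relaxes.

```python
import numpy as np


def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel matrix between the rows of a and the rows of b."""
    sq_dists = (a**2).sum(1)[:, None] + (b**2).sum(1)[None, :] - 2 * a @ b.T
    return np.exp(-np.maximum(sq_dists, 0.0) / (2 * sigma**2))


def mmd2_unbiased(k_xx, k_yy, k_xy):
    """Unbiased squared-MMD estimate from precomputed kernel matrices."""
    n, m = k_xx.shape[0], k_yy.shape[0]
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (n * (n - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2 * k_xy.mean()


def mmd_permutation_test(x_ref, x_test, sigma=1.0, n_permutations=200, seed=0):
    """Return (MMD^2 estimate, permutation p-value) for H0: same distribution."""
    rng = np.random.default_rng(seed)
    z = np.concatenate([x_ref, x_test])
    k = rbf_kernel(z, z, sigma)
    n = len(x_ref)

    def stat(idx):
        i, j = idx[:n], idx[n:]
        return mmd2_unbiased(k[np.ix_(i, i)], k[np.ix_(j, j)], k[np.ix_(i, j)])

    observed = stat(np.arange(len(z)))
    perms = [stat(rng.permutation(len(z))) for _ in range(n_permutations)]
    # Share of permuted statistics at least as large as the observed one (+1 smoothing).
    p_val = (1 + sum(p >= observed for p in perms)) / (1 + n_permutations)
    return observed, p_val


rng = np.random.default_rng(0)
x_ref = rng.normal(size=(100, 2))
x_same = rng.normal(size=(100, 2))
x_shift = rng.normal(loc=1.0, size=(100, 2))
_, p_same = mmd_permutation_test(x_ref, x_same)
_, p_shift = mmd_permutation_test(x_ref, x_shift)
print(p_same, p_shift)  # p_shift is tiny; p_same is typically much larger
```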
Before we can start using the context-aware drift detector, we first need to define our context variable. In our experiments we allow the relative prevalences of subpopulations (i.e. the relative frequency of the different types of heartbeats also present in the reference data) to vary, while the distributions underlying each of the subpopulations remain unchanged. To achieve this we condition on the prediction probabilities of the classifier we trained earlier to distinguish the different types of ECGs. We can do this because the prediction probabilities can account for the frequency of occurrence of each of the heartbeat types (albeit imperfectly, since our classifier occasionally makes mistakes).
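To see why prediction probabilities carry prevalence information, consider this minimal sketch, assuming a classifier that outputs raw logits (the logits below are hypothetical). Each instance's softmax vector is its context, and averaging these vectors over a batch approximately recovers the class prevalences in that batch, with the approximation degrading as the classifier becomes less accurate.

```python
import numpy as np


def softmax(logits):
    """Row-wise, numerically stable softmax turning logits into class probabilities."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Hypothetical logits for 4 instances and 3 heartbeat classes.
logits = np.array([
    [5.0, 0.0, 0.0],
    [5.0, 0.0, 0.0],
    [0.0, 5.0, 0.0],
    [0.0, 0.0, 5.0],
])
c_test = softmax(logits)    # context: one probability vector per instance
print(c_test.mean(axis=0))  # approx. [0.5, 0.25, 0.25], the batch's class prevalences
```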
The figure below shows Q-Q (quantile-quantile) plots of a random sample from the uniform distribution U[0,1] against the p-values obtained from the vanilla and context-aware MMD detectors, illustrating how well both detectors are calibrated. A perfectly calibrated detector should have a Q-Q plot that closely follows the diagonal. Only the middle plot in the grid shows the detector's p-values. The other plots correspond to n_runs p-values actually sampled from U[0,1], to contextualise how closely the central plot should follow the diagonal given the limited number of samples.
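The data behind such a grid can be sketched as follows, assuming a 3x3 grid (an illustrative size) and hypothetical helper names. Each panel pairs sorted U[0,1] samples (x-axis) with sorted values (y-axis): the centre panel uses the detector's p-values, every other panel a fresh U[0,1] sample of the same size. Plotting is omitted; with matplotlib one would scatter each (x, y) pair and draw the diagonal.

```python
import numpy as np


def qq_points(values, seed=0):
    """Sorted U[0, 1] sample (x-axis) paired with the sorted values (y-axis)."""
    rng = np.random.default_rng(seed)
    y_sorted = np.sort(np.asarray(values, dtype=float))
    x_sorted = np.sort(rng.uniform(size=len(y_sorted)))
    return x_sorted, y_sorted


def qq_grid(p_vals, n_rows=3, n_cols=3, seed=0):
    """Q-Q data per panel; only the centre panel uses the detector's p-values.

    The other panels show how far genuinely uniform samples of the same size
    (n_runs) deviate from the diagonal by chance.
    """
    rng = np.random.default_rng(seed)
    n = len(p_vals)
    centre = (n_rows // 2, n_cols // 2)
    grid = {}
    for r in range(n_rows):
        for c in range(n_cols):
            ys = p_vals if (r, c) == centre else rng.uniform(size=n)
            grid[(r, c)] = qq_points(ys, seed=int(rng.integers(1_000_000)))
    return grid


p_vals = np.random.default_rng(3).uniform(size=50)  # stand-in for n_runs detector p-values
grid = qq_grid(p_vals)
```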
As expected we can see that both the normal MMD and the context-aware MMD detectors are well-calibrated.
Changing the relative subpopulation prevalences
We now focus our attention on a more realistic problem where the relative frequency of one or more subpopulations (i.e. types of heartbeats) changes while the distributions of the underlying subpopulations stay the same. This is the setting we would expect when monitoring the heartbeat of a specific person (e.g. only normal heartbeats), and we don't want to flag drift.
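A change in prevalences can be simulated by sampling a fixed number of instances per class. A minimal sketch with a hypothetical helper sample_by_class and toy data; the class counts are illustrative, and in practice the reference and test sets should be drawn from disjoint instances, which this sketch does not enforce.

```python
import numpy as np


def sample_by_class(x, y, n_per_class, seed=0):
    """Hypothetical helper: draw n_per_class[k] instances of class k without replacement."""
    rng = np.random.default_rng(seed)
    idx = [
        rng.choice(np.flatnonzero(y == label), size=n, replace=False)
        for label, n in n_per_class.items()
    ]
    idx = rng.permutation(np.concatenate(idx))  # shuffle so classes are interleaved
    return x[idx], y[idx]


# Toy data: classes 0, 1 and 3 with 100 instances of length 140 each.
y = np.repeat([0, 1, 3], 100)
x = np.random.default_rng(1).normal(size=(len(y), 140))
# Reference: a mix of classes; test: only normal heartbeats (class 0).
x_ref, y_ref = sample_by_class(x, y, {0: 50, 1: 50, 3: 50})
x_test, y_test = sample_by_class(x, y, {0: 80}, seed=2)
```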
While the usual MMD detector only returns very low p-values (mostly 0), the context-aware MMD detector remains calibrated.
Changing the subpopulation distribution
In the following example we change the distribution of one or more of the underlying subpopulations (i.e. the different types of heartbeats). Notice that now we do want to flag drift since our context variable, which permits changes in relative subpopulation prevalences, can no longer explain the change in distribution.
We will again sample from the normal heartbeats, but now we will add random noise to a fraction of the extracted heartbeats to change the distribution. This could be the result of an error with some of the sensors. The perturbation is illustrated below:
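A minimal sketch of such a perturbation, assuming additive Gaussian noise; the fraction and noise level below are illustrative and not the values used in the notebook, and add_sensor_noise is a hypothetical helper.

```python
import numpy as np


def add_sensor_noise(x, frac=0.5, noise_std=0.5, seed=0):
    """Add Gaussian noise to a random fraction `frac` of the heartbeats in x (n x length).

    Returns a perturbed copy and a boolean mask marking the corrupted heartbeats.
    """
    rng = np.random.default_rng(seed)
    x_noisy = np.array(x, dtype=float, copy=True)
    mask = np.zeros(len(x), dtype=bool)
    mask[rng.choice(len(x), size=int(frac * len(x)), replace=False)] = True
    x_noisy[mask] += rng.normal(scale=noise_std, size=x_noisy[mask].shape)
    return x_noisy, mask


x = np.zeros((10, 140))
x_noisy, mask = add_sensor_noise(x, frac=0.3)
```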
As we can see from the Q-Q plot and the power of the detector, the changes in the subpopulation are easily detected:
Changing the context variable
We now use the cluster membership probabilities of a Gaussian mixture model fitted on the training instances as context variables instead of the model predictions. We test both the calibration when the frequencies of the subpopulations (the cluster memberships) change and the power when the $F$ and $V$ heartbeats are included.
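In practice the mixture could be fitted with, for example, scikit-learn's GaussianMixture, whose predict_proba method returns the membership probabilities. To show what those probabilities are, here is a NumPy sketch assuming already-known parameters and isotropic (spherical) components: each row is the normalised, prior-weighted Gaussian likelihood of the instance under every component.

```python
import numpy as np


def gmm_membership_probs(x, weights, means, variances):
    """Cluster membership probabilities under a GMM with isotropic components.

    x: (n, d) instances; weights: (k,) mixture weights; means: (k, d);
    variances: (k,) per-component variances. Returns an (n, k) array with rows summing to 1.
    """
    d = x.shape[1]
    sq_dists = ((x[:, None, :] - means[None, :, :]) ** 2).sum(-1)  # (n, k)
    log_dens = (
        np.log(weights)[None, :]
        - 0.5 * d * np.log(2 * np.pi * variances)[None, :]
        - 0.5 * sq_dists / variances[None, :]
    )
    # Normalise in log space (log-sum-exp trick) for numerical stability.
    log_dens -= log_dens.max(axis=1, keepdims=True)
    probs = np.exp(log_dens)
    return probs / probs.sum(axis=1, keepdims=True)


x = np.array([[0.0], [10.0], [5.0]])
probs = gmm_membership_probs(
    x,
    weights=np.array([0.5, 0.5]),
    means=np.array([[0.0], [10.0]]),
    variances=np.array([1.0, 1.0]),
)
```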
Interpretability of the context-aware detector
The test statistic $\hat{t}$ of the context-aware MMD detector can be formulated as $\hat{t} = \langle K_{0,0}, W_{0,0} \rangle + \langle K_{1,1}, W_{1,1} \rangle - 2\langle K_{0,1}, W_{0,1}\rangle$, where $0$ refers to the reference data, $1$ to the test data, $W_{.,.}$ and $K_{.,.}$ are the weight and kernel matrices, respectively, and $\langle \cdot, \cdot \rangle$ denotes the Frobenius inner product (the sum of elementwise products). The weight matrices $W_{.,.}$ allow us to focus on the subpopulations of interest in the distribution. Reference instances with contexts similar to those of the test data will have higher values for their entries in $W_{0,1}$ than instances with dissimilar contexts. We can therefore interpret $W_{0,1}$ as the coupling matrix between instances in the reference and test sets. This allows us to investigate which subpopulations from the reference set are present in the test data and which are missing. If we also have a good understanding of the model performance on various subpopulations of the reference data, we could even try to use this coupling matrix to roughly proxy model performance on the unlabeled test instances. Note that in this case we would require labels for the reference data and would need to make sure the reference instances come from the validation set rather than the training set.
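The statistic itself is a weighted combination of kernel matrices. A minimal sketch of the formula, assuming the weight matrices are given: the context-aware weights are derived from the context variables as described by Cobb and Van Looveren (2022) and are not reproduced here. As a sanity check, placeholder uniform weights ($1/n_0^2$, $1/n_1^2$ and $1/(n_0 n_1)$) reduce $\hat{t}$ to the biased estimate of the standard squared MMD.

```python
import numpy as np


def frobenius(a, b):
    """Frobenius inner product <A, B> = sum of elementwise products."""
    return float(np.sum(a * b))


def context_mmd_statistic(k00, k11, k01, w00, w11, w01):
    """t = <K00, W00> + <K11, W11> - 2 <K01, W01>."""
    return frobenius(k00, w00) + frobenius(k11, w11) - 2.0 * frobenius(k01, w01)


def rbf(a, b):
    """Gaussian RBF kernel with unit bandwidth (an illustrative choice)."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / 2.0)


rng = np.random.default_rng(0)
x0, x1 = rng.normal(size=(20, 3)), rng.normal(loc=0.5, size=(30, 3))
k00, k11, k01 = rbf(x0, x0), rbf(x1, x1), rbf(x0, x1)
n0, n1 = len(x0), len(x1)
# Placeholder uniform weights, NOT the context-aware weights.
w00 = np.full((n0, n0), 1 / n0**2)
w11 = np.full((n1, n1), 1 / n1**2)
w01 = np.full((n0, n1), 1 / (n0 * n1))
t = context_mmd_statistic(k00, k11, k01, w00, w11, w01)
biased_mmd2 = k00.mean() + k11.mean() - 2 * k01.mean()
print(np.isclose(t, biased_mmd2))  # True
```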
In the following example, only one type of heartbeat (the normal one) is present in the test set, while three types are present in the reference set. We can then use the coupling matrix $W_{0,1}$ to investigate whether the test statistic $\hat{t}$ focused on the right types of heartbeats in the reference data. More concretely, we can sum over the columns (the test instances) of $W_{0,1}$ and check which reference instances obtained the highest weights.
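A minimal sketch of this aggregation, with a small hypothetical coupling matrix (rows are reference instances, columns are test instances): summing each row over the test instances yields one weight per reference instance, which can then be sorted in descending order.

```python
import numpy as np


def reference_weights(w01):
    """Sum W_{0,1} over its columns (test instances): one weight per reference instance."""
    return w01.sum(axis=1)


# Hypothetical coupling matrix: 4 reference instances x 3 test instances.
w01 = np.array([
    [0.30, 0.25, 0.35],  # reference instance with a context similar to the test data
    [0.02, 0.01, 0.03],
    [0.28, 0.30, 0.27],  # another similar reference instance
    [0.00, 0.01, 0.00],
])
w_ref = reference_weights(w01)
order = np.argsort(w_ref)[::-1]  # reference indices sorted by descending weight
print(order)  # [0 2 1 3]
```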
As expected, no drift was detected since the test set only contains normal heartbeats. We now sort the reference weights w_ref (the sums of $W_{0,1}$ over the test instances) in descending order. We expect the top 400 entries to be fairly high and consistent, since these represent the normal heartbeats in the reference set. Beyond them, the weights attributed to the other instances in the reference set should be low. The plot below confirms that this is indeed what happens.
Time conditioning
The dataset consists of nicely extracted and aligned ECGs of 140 data points each. In reality, however, we are likely to observe instances continuously or periodically without such alignment. We could instead assign a timestamp to the data (e.g. counting from a peak) and use time as the context variable. This is illustrated in the example below.
First we create a new dataset by splitting each instance into non-overlapping ECG segments. Each segment has an associated timestamp as its context variable. We can then check the calibration under no change (besides the time-varying behaviour, which is accounted for), as well as the power for ECG segments where some of the segments are given incorrect timestamps.
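A minimal sketch of both steps, with hypothetical helpers: segments are taken row by row, each labelled with its start index within the original ECG, and incorrect timestamps are simulated by shuffling the timestamps of a random fraction of segments. The segment length of 14 and the shuffling scheme are assumptions; shuffling may leave a few timestamps unchanged by chance.

```python
import numpy as np


def segment_with_timestamps(x, seg_len):
    """Split each ECG of length L into L // seg_len non-overlapping segments.

    Returns segments of shape (n * n_seg, seg_len) and, as context, each
    segment's start index within its ECG, shape (n * n_seg, 1).
    """
    n, length = x.shape
    n_seg = length // seg_len
    # Drop trailing points that do not fill a complete segment.
    segments = x[:, : n_seg * seg_len].reshape(n * n_seg, seg_len)
    starts = np.tile(np.arange(n_seg) * seg_len, n)
    return segments, starts[:, None].astype(float)


def corrupt_timestamps(t, frac=0.5, seed=0):
    """Shuffle the timestamps of a random fraction of segments among themselves."""
    rng = np.random.default_rng(seed)
    t_bad = t.copy()
    idx = rng.choice(len(t), size=int(frac * len(t)), replace=False)
    t_bad[idx] = t_bad[rng.permutation(idx)]
    return t_bad


x = np.arange(2 * 140, dtype=float).reshape(2, 140)  # two toy ECGs of length 140
segs, t = segment_with_timestamps(x, seg_len=14)
t_bad = corrupt_timestamps(t, frac=0.5)
```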