Model distillation is a technique that is used to transfer knowledge from a large network to a smaller network. Typically, it consists of training a second model with a simplified architecture on soft targets (the output distributions or the logits) obtained from the original model.
Here, we apply model distillation to obtain harmfulness scores, by comparing the output distributions of the original model with the output distributions of the distilled model, in order to detect adversarial data, malicious data drift or data corruption. We use the following definition of harmful and harmless data points:
Harmful data points are defined as inputs for which the model's predictions on the uncorrupted data are correct while the model's predictions on the corrupted data are wrong.
Harmless data points are defined as inputs for which the model's predictions on the uncorrupted data are correct and the model's predictions on the corrupted data remain correct.
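The two definitions can be read as a labelling rule over the classifier's predictions on clean and corrupted copies of the same inputs. The sketch below only illustrates that rule and is not part of the library. `label_harm` is a hypothetical helper. Instances that are already misclassified on the uncorrupted data are left unlabelled (`None`) because neither definition covers them.

```python
from typing import List, Optional, Sequence


def label_harm(y_true: Sequence[int], pred_clean: Sequence[int],
               pred_corrupt: Sequence[int]) -> List[Optional[str]]:
    """Hypothetical helper: label each instance 'harmful', 'harmless' or None.

    Only instances predicted correctly on the uncorrupted data are labelled,
    following the definitions above; all others get None.
    """
    labels: List[Optional[str]] = []
    for y, p_clean, p_corrupt in zip(y_true, pred_clean, pred_corrupt):
        if p_clean != y:
            labels.append(None)          # not covered by either definition
        elif p_corrupt == y:
            labels.append("harmless")    # still correct after the corruption
        else:
            labels.append("harmful")     # correct before, wrong after
    return labels


print(label_harm([0, 1, 2, 1], [0, 1, 0, 1], [0, 2, 0, 1]))
# ['harmless', 'harmful', None, 'harmless']
```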
Analogously to the adversarial AE detector, which is also part of the library, the model distillation detector picks up drift that reduces the performance of the classification model.
The detector can be used as follows:
Given an input $x$, an adversarial score $S(x)$ is computed. $S(x)$ equals the value of the loss function employed for distillation, calculated between the original model's output and the distilled model's output on $x$.
If $S(x)$ is above a threshold (explicitly defined or inferred from training data), the instance is flagged as adversarial.
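This scoring rule can be sketched in plain Python. It is a minimal illustration and not the library's implementation. It assumes the outputs of the original and distilled models are already available as lists of class probabilities. The helper names (`kld`, `xent`, `distillation_scores`, `flag_instances`) are hypothetical. The two loss options mirror the supported `loss_type` values, 'kld' and 'xent'.

```python
import math
from typing import List, Sequence

EPS = 1e-7  # clip probabilities so log() never sees 0


def kld(p: Sequence[float], q: Sequence[float]) -> float:
    """K-L divergence D_KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(max(pi, EPS) / max(qi, EPS)) for pi, qi in zip(p, q))


def xent(p: Sequence[float], q: Sequence[float]) -> float:
    """Cross-entropy H(p, q) = -sum_c p_c log q_c."""
    return -sum(pi * math.log(max(qi, EPS)) for pi, qi in zip(p, q))


def distillation_scores(preds_model: Sequence[Sequence[float]],
                        preds_distilled: Sequence[Sequence[float]],
                        loss_type: str = "kld") -> List[float]:
    """Score S(x): distillation loss between original and distilled outputs."""
    loss = {"kld": kld, "xent": xent}[loss_type]
    return [loss(p, q) for p, q in zip(preds_model, preds_distilled)]


def flag_instances(scores: Sequence[float], threshold: float) -> List[bool]:
    """An instance is flagged when its score exceeds the threshold."""
    return [s > threshold for s in scores]


# Hypothetical outputs for two instances: the models agree on the first only.
preds_model = [[0.9, 0.1], [0.6, 0.4]]
preds_distilled = [[0.9, 0.1], [0.1, 0.9]]
scores = distillation_scores(preds_model, preds_distilled)
print(flag_instances(scores, threshold=0.5))  # [False, True]
```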
Parameters:
threshold
: threshold value above which the instance is flagged as an adversarial instance.
distilled_model
: tf.keras.Sequential instance containing the model used for distillation.
model
: the classifier as a tf.keras.Model.
loss_type
: type of loss used for distillation. Supported losses: 'kld', 'xent'.
temperature
: Temperature used for model prediction scaling. Temperature <1 sharpens the prediction probability distribution which can be beneficial for prediction distributions with high entropy.
data_type
: can specify data type added to metadata. E.g. 'tabular' or 'image'.
Initialized detector example:
We then need to train the detector. The following parameters can be specified:
X
: training batch as a numpy array.
loss_fn
: loss function used for training. Defaults to the custom model distillation loss.
optimizer
: optimizer used for training. Defaults to Adam with learning rate 1e-3.
epochs
: number of training epochs.
batch_size
: batch size used during training.
verbose
: boolean whether to print training progress.
log_metric
: additional metrics whose progress will be displayed if verbose equals True.
preprocess_fn
: optional data preprocessing function applied per batch during training.
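The training step can be made concrete with a toy distillation loop in pure Python. This only illustrates the idea and does not show how the library trains `distilled_model`. In this sketch the student is a hypothetical softmax-regression model. It is fitted by full-batch gradient descent on the mean K-L divergence to the teacher's soft targets, and all names are illustrative. For a softmax output $q$ and a soft target $p$, the gradient of $D_{KL}(p\,\|\,q)$ with respect to the student logits is $q - p$. The update below uses exactly that gradient.

```python
import math
import random
from typing import List, Sequence, Tuple

EPS = 1e-7


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def kld(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(pi * math.log(max(pi, EPS) / max(qi, EPS)) for pi, qi in zip(p, q))


def distill(X: Sequence[Sequence[float]], teacher_probs: Sequence[Sequence[float]],
            n_classes: int, epochs: int = 200, lr: float = 0.5,
            seed: int = 0) -> Tuple[List[List[float]], List[float], List[float]]:
    """Fit a softmax-regression student to the teacher's soft targets.

    Full-batch gradient descent on the mean KL(teacher || student); returns the
    student weights, biases and the loss recorded at the start of every epoch.
    """
    rng = random.Random(seed)
    d = len(X[0])
    W = [[rng.gauss(0.0, 0.01) for _ in range(d)] for _ in range(n_classes)]
    b = [0.0] * n_classes
    history: List[float] = []
    n = len(X)
    for _ in range(epochs):
        gW = [[0.0] * d for _ in range(n_classes)]
        gb = [0.0] * n_classes
        loss = 0.0
        for x, p in zip(X, teacher_probs):
            logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(n_classes)]
            q = softmax(logits)
            loss += kld(p, q)
            for k in range(n_classes):
                g = q[k] - p[k]  # gradient of KL(p || softmax(z)) w.r.t. logit z_k
                gb[k] += g
                for j in range(d):
                    gW[k][j] += g * x[j]
        for k in range(n_classes):
            b[k] -= lr * gb[k] / n
            for j in range(d):
                W[k][j] -= lr * gW[k][j] / n
        history.append(loss / n)
    return W, b, history


# Hypothetical inputs and teacher soft targets (e.g. the original model's softmax outputs).
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
teacher = [[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.5, 0.5]]
W, b, history = distill(X, teacher, n_classes=2)
print(f"mean KL: {history[0]:.4f} -> {history[-1]:.4f}")
```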
The threshold for the adversarial / harmfulness score can be set via infer_threshold. We need to pass a batch of instances $X$ and specify what percentage of those we consider to be normal via threshold_perc. Even if we only have normal instances in the batch, it might be best to set threshold_perc a bit lower (e.g. $95$%) since the model could have misclassified some of the training instances.
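Under the assumption that the inferred threshold is the `threshold_perc`-th percentile of the instance scores on $X$, the threshold can be sketched as follows. The sketch uses a linear-interpolation percentile, which matches NumPy's default for `np.percentile`. `percentile` and `infer_threshold_sketch` are hypothetical names, not library functions.

```python
from typing import Sequence


def percentile(values: Sequence[float], perc: float) -> float:
    """Linear-interpolation percentile (NumPy's default method)."""
    xs = sorted(values)
    if not xs:
        raise ValueError("percentile of an empty sequence")
    pos = (len(xs) - 1) * perc / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def infer_threshold_sketch(scores: Sequence[float], threshold_perc: float = 95.0) -> float:
    """Threshold such that roughly threshold_perc % of the scores lie at or below it."""
    return percentile(scores, threshold_perc)


scores = [float(i) for i in range(1, 101)]   # hypothetical scores on a normal batch
threshold = infer_threshold_sketch(scores, threshold_perc=95.0)
print(sum(s > threshold for s in scores))    # 5 instances flagged
```

With `threshold_perc=95`, about 5% of the reference batch ends up above the threshold. This is the margin the paragraph above recommends keeping for misclassified training instances.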
We detect adversarial / harmful instances by simply calling predict on a batch of instances X. We can also return the instance level score by setting return_instance_score to True.
The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata while data is also a dictionary which contains the actual predictions stored in the following keys:
is_adversarial
: boolean whether instances are above the threshold and therefore adversarial instances. The array is of shape (batch size,).
instance_score
: contains instance level scores if return_instance_score equals True.
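The structure of this return value can be mimicked in a few lines. `predict_sketch` is a hypothetical function, not part of the library. It assumes the instance scores have already been computed and only shows how `is_adversarial` and `instance_score` relate to the threshold. Returning `None` for `instance_score` when it is not requested is also an assumption.

```python
from typing import Any, Dict, Optional, Sequence


def predict_sketch(scores: Sequence[float], threshold: float,
                   return_instance_score: bool = True,
                   meta: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical stand-in for the detector output: {'meta': ..., 'data': ...}."""
    data: Dict[str, Any] = {
        # True where the score is above the threshold, one entry per instance
        "is_adversarial": [s > threshold for s in scores],
        # instance-level scores only when requested (None otherwise, by assumption)
        "instance_score": list(scores) if return_instance_score else None,
    }
    return {"meta": dict(meta or {}), "data": data}


preds = predict_sketch([0.2, 1.3, 0.7], threshold=0.9, meta={"data_type": "image"})
print(preds["data"]["is_adversarial"])  # [False, True, False]
```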
Harmful drift detection through model distillation on CIFAR10
The adversarial detector follows the method explained in the Adversarial Detection and Correction by Matching Prediction Distributions paper. Usually, autoencoders are trained to find a transformation $T$ that reconstructs the input instance $x$ as accurately as possible, with loss functions suited to capture the similarities between $x$ and $x'$ such as the mean squared reconstruction error. The novelty of the adversarial autoencoder (AE) detector lies in the use of a classification model-dependent loss function, based on a distance metric in the output space of the model, to train the autoencoder network. Given a classification model $M$, we optimise the weights of the autoencoder such that the K-L divergence between the model predictions on $x$ and on $x'$ is minimised. Without a reconstruction loss term, the only requirement on $x'$ is that the prediction probabilities $M(x')$ and $M(x)$ match; the proximity of $x'$ to $x$ is not enforced. As a result, $x'$ is allowed to live in different areas of the input feature space than $x$, with different decision boundary shapes with respect to the model $M$. The carefully crafted adversarial perturbation which is effective around $x$ does not transfer to the new location of $x'$ in the feature space, and the attack is therefore neutralised. Training of the autoencoder is unsupervised since we only need access to the model prediction probabilities and the normal training instances. We do not require any knowledge about the underlying adversarial attack, and the classifier weights are frozen during training.
The detector can be used as follows:
An adversarial score $S$ is computed. $S$ equals the K-L divergence between the model predictions on $x$ and $x'$.
If $S$ is above a threshold (explicitly defined or inferred from training data), the instance is flagged as adversarial.
For adversarial instances, the model $M$ uses the reconstructed instance $x'$ to make a prediction. If the adversarial score is below the threshold, the model makes a prediction on the original instance $x$.
This procedure is illustrated in the diagram below:
The method is very flexible and can also be used to detect common data corruptions and perturbations which negatively impact the model performance. The algorithm works well on tabular and image data.
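The training objective and the decision rule described above can be summarised as follows. Here $x' = AE(x)$ is the reconstruction, $\tau$ is the threshold and $D_{KL}(p\,\|\,q) = \sum_c p_c \log(p_c / q_c)$ is taken over the classes $c$. This compact restatement of the text leaves out the optional reconstruction and hidden-layer terms. It is not a formula quoted from the paper, and the order of the arguments in $D_{KL}$ is assumed.

$$
\min_{AE}\; \mathbb{E}_{x}\left[ D_{KL}\big(M(x)\,\|\,M(AE(x))\big) \right]
$$

$$
S(x) = D_{KL}\big(M(x)\,\|\,M(x')\big), \qquad
\hat{y}(x) =
\begin{cases}
\arg\max_c M(x')_c & \text{if } S(x) > \tau,\\
\arg\max_c M(x)_c & \text{otherwise.}
\end{cases}
$$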
Parameters:
threshold
: threshold value above which the instance is flagged as an adversarial instance.
encoder_net
: tf.keras.Sequential instance containing the encoder network.
decoder_net
: tf.keras.Sequential instance containing the decoder network.
ae
: instead of using a separate encoder and decoder, the AE can also be passed as a tf.keras.Model.
model
: the classifier as a tf.keras.Model.
hidden_layer_kld
: dictionary whose keys are the indices of the hidden layer(s) in the classification model that are extracted and used during training of the adversarial AE, and whose values are the output dimensions for those hidden layers. Extending the training methodology to the hidden layers is optional and can further improve the adversarial correction mechanism.
model_hl
: instead of passing a dictionary to hidden_layer_kld, a list of tf.keras models for the hidden layer K-L divergence computation can be passed directly.
w_model_hl
: weights assigned to the loss of each model in model_hl. Also used to weight the K-L divergence contribution of each model in model_hl when computing the adversarial score.
temperature
: Temperature used for model prediction scaling. Temperature <1 sharpens the prediction probability distribution which can be beneficial for prediction distributions with high entropy.
data_type
: can specify data type added to metadata. E.g. 'tabular' or 'image'.
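The hidden-layer terms can be pictured as a weighted sum added to the score. The score is the output-layer K-L divergence plus, for each hidden-layer model, its own K-L divergence scaled by the matching entry of w_model_hl. The sketch below is a minimal pure-Python version that assumes this additive combination. The function names are hypothetical, and all distributions are assumed to be precomputed softmax outputs.

```python
import math
from typing import Optional, Sequence

EPS = 1e-7  # avoids log(0)


def kld(p: Sequence[float], q: Sequence[float]) -> float:
    """K-L divergence D_KL(p || q)."""
    return sum(pi * math.log(max(pi, EPS) / max(qi, EPS)) for pi, qi in zip(p, q))


def adversarial_score(pred_x: Sequence[float], pred_xr: Sequence[float],
                      hl_x: Sequence[Sequence[float]] = (),
                      hl_xr: Sequence[Sequence[float]] = (),
                      w_model_hl: Optional[Sequence[float]] = None) -> float:
    """Output-layer KL plus weighted hidden-layer KL terms (assumed combination).

    pred_x / pred_xr: classifier softmax outputs on x and on the reconstruction x'.
    hl_x / hl_xr: one softmax output per hidden-layer model, on x and on x'.
    """
    score = kld(pred_x, pred_xr)
    weights = w_model_hl if w_model_hl is not None else [1.0] * len(hl_x)
    for w, p, q in zip(weights, hl_x, hl_xr):
        score += w * kld(p, q)
    return score


# Identical output-layer predictions; one hidden-layer model disagrees.
s = adversarial_score([0.7, 0.3], [0.7, 0.3],
                      hl_x=[[0.5, 0.5]], hl_xr=[[0.25, 0.75]], w_model_hl=[2.0])
print(round(s, 4))  # 0.2877
```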
Initialized adversarial detector example:
We then need to train the adversarial detector. The following parameters can be specified:
X
: training batch as a numpy array.
loss_fn
: loss function used for training. Defaults to the custom adversarial loss.
w_model
: weight on the loss term minimizing the K-L divergence between model prediction probabilities on the original and reconstructed instance. Defaults to 1.
w_recon
: weight on the mean squared error reconstruction loss term. Defaults to 0.
optimizer
: optimizer used for training. Defaults to Adam with learning rate 1e-3.
epochs
: number of training epochs.
batch_size
: batch size used during training.
verbose
: boolean whether to print training progress.
log_metric
: additional metrics whose progress will be displayed if verbose equals True.
preprocess_fn
: optional data preprocessing function applied per batch during training.
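The roles of w_model and w_recon can be illustrated with a sketch of the combined objective. It has a K-L term between the classifier's predictions on the original and reconstructed instances, plus an optional mean squared reconstruction error. This is a plain-Python illustration of that assumption, not the library's default `loss_fn`, and the names are hypothetical. With the default w_recon = 0, the loss does not penalise $x'$ drifting away from $x$. The method relies on exactly that freedom.

```python
import math
from typing import Sequence

EPS = 1e-7


def kld(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(pi * math.log(max(pi, EPS) / max(qi, EPS)) for pi, qi in zip(p, q))


def adversarial_ae_loss(X: Sequence[Sequence[float]], X_recon: Sequence[Sequence[float]],
                        preds: Sequence[Sequence[float]],
                        preds_recon: Sequence[Sequence[float]],
                        w_model: float = 1.0, w_recon: float = 0.0) -> float:
    """Batch loss: w_model * mean KL(M(x) || M(x')) + w_recon * mean MSE(x, x')."""
    n = len(X)
    kl_term = sum(kld(p, q) for p, q in zip(preds, preds_recon)) / n
    mse_term = sum(
        sum((a - b) ** 2 for a, b in zip(x, xr)) / len(x)
        for x, xr in zip(X, X_recon)
    ) / n
    return w_model * kl_term + w_recon * mse_term


X = [[0.0, 1.0]]
X_recon = [[1.0, 1.0]]          # the reconstruction moved away from x ...
preds = [[0.8, 0.2]]
preds_recon = [[0.8, 0.2]]      # ... but the classifier's prediction is unchanged
print(adversarial_ae_loss(X, X_recon, preds, preds_recon))               # 0.0
print(adversarial_ae_loss(X, X_recon, preds, preds_recon, w_recon=1.0))  # 0.5
```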
The threshold for the adversarial score can be set via infer_threshold. We need to pass a batch of instances $X$ and specify what percentage of those we consider to be normal via threshold_perc. Even if we only have normal instances in the batch, it might be best to set threshold_perc a bit lower (e.g. $95$%). The model could have misclassified some training instances, which leads to a higher score if the reconstruction picked up features from the correct class, and some instances might look adversarial in the first place.
We detect adversarial instances by simply calling predict on a batch of instances X. We can also return the instance level adversarial score by setting return_instance_score to True.
The prediction takes the form of a dictionary with meta and data keys. meta contains the detector's metadata while data is also a dictionary which contains the actual predictions stored in the following keys:
is_adversarial
: boolean whether instances are above the threshold and therefore adversarial instances. The array is of shape (batch size,).
instance_score
: contains instance level scores if return_instance_score equals True.
We can immediately apply the procedure sketched out in the above diagram via correct. The method also returns a dictionary with meta and data keys. On top of the information returned by predict, 3 additional fields are returned under data:
corrected
: model predictions by following the adversarial detection and correction procedure.
no_defense
: model predictions without the adversarial correction.
defense
: model predictions where each instance is corrected by the defense, regardless of the adversarial score.
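The relationship between the three fields can be sketched in plain Python, assuming the adversarial scores and the classifier's probabilities on $x$ and on the reconstruction $x'$ have already been computed. `correct_sketch` is a hypothetical helper, not the library method. It only restates the selection rule: use the prediction on $x'$ where the score is above the threshold, and the prediction on $x$ elsewhere.

```python
from typing import Dict, List, Sequence


def argmax(p: Sequence[float]) -> int:
    return max(range(len(p)), key=lambda c: p[c])


def correct_sketch(scores: Sequence[float], threshold: float,
                   probs_x: Sequence[Sequence[float]],
                   probs_xr: Sequence[Sequence[float]]) -> Dict[str, List]:
    """Hypothetical illustration of the correct() output fields under 'data'."""
    is_adversarial = [s > threshold for s in scores]
    no_defense = [argmax(p) for p in probs_x]    # predictions on x
    defense = [argmax(p) for p in probs_xr]      # predictions on x' for every instance
    corrected = [d if adv else nd                # use x' only where flagged
                 for adv, d, nd in zip(is_adversarial, defense, no_defense)]
    return {"is_adversarial": is_adversarial, "corrected": corrected,
            "no_defense": no_defense, "defense": defense}


out = correct_sketch(scores=[0.1, 2.0], threshold=1.0,
                     probs_x=[[0.9, 0.1], [0.2, 0.8]],
                     probs_xr=[[0.6, 0.4], [0.7, 0.3]])
print(out["corrected"])  # [0, 0]
```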
CIFAR10 consists of 60,000 32 by 32 RGB images equally distributed over 10 classes.
Note: in order to run this notebook, it is advised to use Python 3.7 and have a GPU enabled.
Standardise the dataset by instance:
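To standardise by instance, subtract each image's own pixel mean and divide by its own standard deviation. The minimal sketch below assumes each image has been flattened to a list of H×W×C pixel values. `scale_by_instance` is a hypothetical name, and the guard against dividing by zero for constant images is an added assumption. The per-instance means and standard deviations are returned so the transformation could be inverted later.

```python
import statistics
from typing import List, Sequence, Tuple


def scale_by_instance(X: Sequence[Sequence[float]]
                      ) -> Tuple[List[List[float]], List[float], List[float]]:
    """Standardise each flattened instance with its own mean and population std."""
    scaled, means, stds = [], [], []
    for x in X:
        mu = sum(x) / len(x)
        sd = statistics.pstdev(x, mu) or 1.0  # assumption: avoid division by zero
        scaled.append([(v - mu) / sd for v in x])
        means.append(mu)
        stds.append(sd)
    return scaled, means, stds


X = [[0.0, 2.0, 4.0, 6.0], [1.0, 1.0, 1.0, 1.0]]  # two tiny "images"
scaled, means, stds = scale_by_instance(X)
print(means)  # [3.0, 1.0]
```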
Check that the predictions on the test set reach $93.15$% accuracy:
We investigate both Carlini-Wagner (C&W) and SLIDE attacks. You can simply load previously found adversarial instances for the pretrained ResNet-56 model. The attacks were generated using Foolbox:
Check if the prediction accuracy of the model on the adversarial instances is close to $0$%.
Let's visualise some adversarial instances:
We can again either fetch the pretrained detector from a Google Cloud Bucket or train one from scratch:
The detector first reconstructs the input instances which can be adversarial. The reconstructed input is then fed to the classifier if the adversarial score for the instance is above the threshold. Let's investigate what happens when we reconstruct attacked instances and make predictions on them:
Accuracy on attacked vs. reconstructed instances:
The detector restores the accuracy after the attacks from almost $0$% to well over $80$%! We can compute the adversarial scores and inspect some of the reconstructed instances:
The ROC curves and AUC values show the effectiveness of the adversarial score to detect adversarial instances:
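The AUC summarises how well the score separates the two groups. It is the probability that a randomly chosen adversarial instance gets a higher adversarial score than a randomly chosen normal instance, with ties counted as one half. The pure-Python sketch below computes it in that pairwise form. Its cost is quadratic in the batch sizes, which is fine for illustration. `sklearn.metrics.roc_auc_score` computes the same quantity efficiently from binary labels and scores. `roc_auc` is a hypothetical name.

```python
from typing import Sequence


def roc_auc(scores_normal: Sequence[float], scores_adv: Sequence[float]) -> float:
    """Probability that a random adversarial instance outscores a random normal one,
    counting ties as one half (the Mann-Whitney form of the ROC AUC)."""
    wins = 0.0
    for sa in scores_adv:
        for sn in scores_normal:
            if sa > sn:
                wins += 1.0
            elif sa == sn:
                wins += 0.5
    return wins / (len(scores_adv) * len(scores_normal))


# Hypothetical adversarial scores for normal vs. attacked instances.
print(roc_auc([0.1, 0.2, 0.4], [0.3, 0.5, 0.9]))
```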
The threshold for the adversarial score can be set via infer_threshold. We need to pass a batch of instances $X$ and specify what percentage of those we consider to be normal via threshold_perc. Assume we have only normal instances, some of which the model has misclassified. These can receive a higher score if the reconstruction picked up features from the correct class, and some might look adversarial in the first place. As a result, we set our threshold at $95$%:
Let's save the updated detector:
We can also load it easily as follows:
The correct method of the detector executes the procedure shown in the diagram in Figure 1. First, the adversarial scores are computed. For instances where the score is above the threshold, the classifier prediction on the reconstructed instance is returned. Otherwise, the original prediction is kept. The method returns a dictionary containing the metadata of the detector, whether the instances in the batch are adversarial (above the threshold) or not, the classifier predictions using the correction mechanism, and both the original and reconstructed predictions. Let's illustrate this on a batch containing some adversarial (C&W) and original test set instances:
Let's check the model performance:
This can be improved with the correction mechanism:
We can further improve the correction performance by applying temperature scaling on the original model predictions $M(x)$ during both training and inference when computing the adversarial scores. We can again load a pretrained detector or train one from scratch:
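One way to apply temperature scaling to prediction probabilities is to raise each probability to the power $1/T$ and renormalise. This is equivalent to applying a softmax to $\log M(x) / T$. The plain-Python illustration below uses hypothetical helper names and shows that $T < 1$ lowers the entropy of the distribution, as described for the temperature parameter. Whether the library applies the scaling in exactly this form is an assumption.

```python
import math
from typing import List, Sequence


def temperature_scale(probs: Sequence[float], temperature: float) -> List[float]:
    """Return p_c^(1/T) / sum_k p_k^(1/T); T < 1 sharpens, T > 1 flattens."""
    powered = [p ** (1.0 / temperature) for p in probs]
    total = sum(powered)
    return [v / total for v in powered]


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


p = [0.5, 0.3, 0.2]
sharp = temperature_scale(p, temperature=0.5)
print([round(v, 3) for v in sharp])  # [0.658, 0.237, 0.105]
print(entropy(sharp) < entropy(p))   # True
```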
On CIFAR-10, applying temperature scaling improves the ROC curve and AUC values.
The performance of the correction mechanism can also be improved by extending the training methodology to one of the hidden layers of the classification model. We extract a flattened feature map from the hidden layer, feed it into a linear layer and apply the softmax function. The K-L divergence between predictions on the hidden layer for $x$ and $x'$ is optimised and included in the adversarial score during inference:
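The hidden-layer branch can be sketched in plain Python: flatten the feature map, apply a linear layer, then a softmax. The weights here are arbitrary hypothetical values. In practice they would be learned, and the feature maps would come from the classifier's hidden layer on $x$ and $x'$. The two resulting distributions are then compared with a K-L divergence. All function names are illustrative.

```python
import math
from typing import List, Sequence

EPS = 1e-7


def flatten(fmap: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """Flatten an H x W x C feature map (nested lists) into one vector."""
    return [v for row in fmap for pixel in row for v in pixel]


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def hidden_layer_head(fmap: Sequence[Sequence[Sequence[float]]],
                      W: Sequence[Sequence[float]], b: Sequence[float]) -> List[float]:
    """Flatten -> linear layer (W, b) -> softmax over the head's output units."""
    h = flatten(fmap)
    logits = [sum(w * hi for w, hi in zip(w_row, h)) + b_k for w_row, b_k in zip(W, b)]
    return softmax(logits)


def kld(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(pi * math.log(max(pi, EPS) / max(qi, EPS)) for pi, qi in zip(p, q))


# 2 x 1 x 2 feature maps for x and its reconstruction x' (illustrative values).
fmap_x = [[[1.0, 0.0]], [[0.0, 1.0]]]
fmap_xr = [[[0.0, 0.0]], [[0.0, 1.0]]]
W = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]]  # 2 output units, 4 inputs
b = [0.0, 0.0]
p_x = hidden_layer_head(fmap_x, W, b)
p_xr = hidden_layer_head(fmap_xr, W, b)
print(p_x, kld(p_x, p_xr))  # hidden-layer term added to the adversarial score
```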
The adversarial detector proves to be very flexible and can be used to measure the harmfulness of the data drift on the classifier. We evaluate the detector on the CIFAR-10-C dataset (Hendrycks & Dietterich, 2019). The instances in CIFAR-10-C have been corrupted and perturbed by various types of noise, blur, brightness etc. at different levels of severity, leading to a gradual decline in model performance.
We can select from the following corruption types:
Fetch the CIFAR-10-C data for a list of corruptions at each severity level (from 1 to 5), make classifier predictions on the corrupted data, compute adversarial scores and identify which perturbations were malicious or harmful and which were not. We can then store and visualise the adversarial scores for the harmful and harmless corruptions. The score for the harmful perturbations is significantly higher than for the harmless ones. As a result, the adversarial detector also functions as a data drift detector.
Compute mean scores and standard deviation per severity level and plot:
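The statistics can be computed by grouping the scores by severity level and by whether the corruption was harmful, then taking the mean and standard deviation of each group. In the sketch below, the `(severity, is_harmful, score)` record layout and the function name are hypothetical. Using the population standard deviation, which is NumPy's default for `np.std`, is an assumption.

```python
import statistics
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def score_stats_by_severity(
    records: Iterable[Tuple[int, bool, float]]
) -> Dict[Tuple[int, bool], Tuple[float, float]]:
    """Map (severity, is_harmful) -> (mean score, population std of scores)."""
    groups: Dict[Tuple[int, bool], List[float]] = defaultdict(list)
    for severity, harmful, score in records:
        groups[(severity, harmful)].append(score)
    return {key: (statistics.mean(v), statistics.pstdev(v))
            for key, v in sorted(groups.items())}


records = [(1, True, 2.0), (1, True, 4.0), (1, False, 0.5), (2, True, 3.0)]
for (severity, harmful), (mean, std) in score_stats_by_severity(records).items():
    label = "harmful" if harmful else "harmless"
    print(f"severity {severity} {label}: mean={mean:.2f} std={std:.2f}")
```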