Counterfactual with Reinforcement Learning (CFRL) on MNIST
This method is described in Model-agnostic and Scalable Counterfactual Explanations via Reinforcement Learning and can generate counterfactual instances for any black-box model. The usual optimization procedure is transformed into a learnable process, allowing batches of counterfactual instances to be generated in a single forward pass, even for high-dimensional data. The training pipeline is model-agnostic and relies only on prediction feedback obtained by querying the black-box model. Furthermore, the method allows target and feature conditioning.
We exemplify the use case for the TensorFlow backend. This means that all models (the autoencoder, the actor, and the critic) are TensorFlow models. Our implementation supports the PyTorch backend as well.
CFRL uses Deep Deterministic Policy Gradient (DDPG), interleaving the training of a state-action function approximator, called the critic, with the learning of an approximator, called the actor, that predicts the optimal action. The method assumes that the critic is differentiable with respect to the action argument, which allows the actor's parameters to be optimized efficiently through gradient-based methods.
The DDPG algorithm requires two separate networks, an actor $\mu$ and a critic $Q$. Given the encoded representation of the input instance $z = enc(x)$, the model prediction $y_M$, the target prediction $y_T$ and the conditioning vector $c$, the actor outputs the counterfactual’s latent representation $z_{CF} = \mu(z, y_M, y_T, c)$. The decoder then projects the embedding $z_{CF}$ back to the original input space, followed by optional post-processing.
The training step consists of simultaneously optimizing the actor and critic networks. The critic regresses on the reward $R$ determined by the model prediction, while the actor maximizes the critic's output for the given instance through $L_{max}$. The actor also minimizes two objectives to encourage the generation of sparse, in-distribution counterfactuals. The sparsity loss $L_{sparsity}$ operates on the decoded counterfactual $x_{CF}$ and combines the $L_1$ loss over the standardized numerical features and the $L_0$ loss over the categorical ones. The consistency loss $L_{consistency}$ aims to encode the counterfactual $x_{CF}$ back to the same latent representation it was decoded from, which helps produce in-distribution counterfactual instances. Formally, the actor's loss can be written as: $L_{actor} = L_{max} + \lambda_{1}L_{sparsity} + \lambda_{2}L_{consistency}$
Note
To enable support for CounterfactualRL with the TensorFlow backend, you may need to run:
```
pip install alibi[tensorflow]
```

```python
import os
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict

import tensorflow as tf
import tensorflow.keras as keras

from alibi.explainers import CounterfactualRL
from alibi.models.tensorflow import AE
from alibi.models.tensorflow import Actor, Critic
from alibi.models.tensorflow import MNISTEncoder, MNISTDecoder, MNISTClassifier
from alibi.explainers.cfrl_base import Callback
```

Load MNIST dataset
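Below is a minimal sketch of loading and preprocessing the data with `keras.datasets.mnist`; the scaling to `[-1, 1]` is an illustrative preprocessing choice, not a requirement of the method.

```python
# Load MNIST and bring it into a shape/range suitable for the CNN and the autoencoder.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

# Scale pixel values to [-1, 1] (illustrative choice).
X_train = (X_train.astype(np.float32) / 255.0 - 0.5) * 2
X_test = (X_test.astype(np.float32) / 255.0 - 0.5) * 2

# Add the channel dimension -> (N, 28, 28, 1).
X_train = X_train[..., np.newaxis]
X_test = X_test[..., np.newaxis]
```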
Define and train CNN classifier
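A possible training setup is sketched below using the `MNISTClassifier` imported above; the `output_dim` argument, optimizer, loss (assuming the model outputs logits), and number of epochs are assumptions rather than prescribed values.

```python
# Define the CNN classifier; output_dim is assumed to be the number of classes.
classifier = MNISTClassifier(output_dim=10)

# Illustrative training configuration; from_logits=True assumes the model
# returns unnormalized logits.
classifier.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
classifier.fit(X_train, Y_train, batch_size=64, epochs=5,
               validation_data=(X_test, Y_test))
```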
Define the predictor (black-box)
Now that we've trained the CNN classifier, we can define the black-box model. Note that the output of the black-box is a distribution, which can be either a soft-label distribution (probabilities/logits for each class) or a hard-label distribution (one-hot encoding). Internally, CFRL takes the argmax. Moreover, the output DOES NOT HAVE TO BE DIFFERENTIABLE.
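A minimal sketch of such a predictor, wrapping the classifier trained above and returning softmax probabilities as a NumPy array:

```python
def predictor(X: np.ndarray) -> np.ndarray:
    # Black-box predictor: only prediction feedback is exposed to CFRL.
    # Here we return a soft-label distribution; CFRL takes the argmax internally.
    return tf.nn.softmax(classifier(X, training=False), axis=-1).numpy()
```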
Define and train autoencoder
Instead of directly modeling the perturbation vector in the potentially high-dimensional input space, we first train an autoencoder. The weights of the encoder are frozen and the actor applies the counterfactual perturbations in the latent space of the encoder. The pre-trained decoder maps the counterfactual embedding back to the input feature space.
The autoencoder follows a standard design. The model is composed of two submodules, the encoder and the decoder. The forward pass consists of passing the input to the encoder, obtaining the input embedding, and passing the embedding through the decoder.
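A sketch of how the autoencoder could be assembled and trained from the `AE`, `MNISTEncoder`, and `MNISTDecoder` classes imported above; the latent dimension, constructor arguments, optimizer, loss, and number of epochs are illustrative assumptions and should be checked against the alibi API reference.

```python
LATENT_DIM = 32  # illustrative latent dimension

# Compose the autoencoder from the MNIST-specific encoder and decoder.
ae = AE(encoder=MNISTEncoder(latent_dim=LATENT_DIM), decoder=MNISTDecoder())

# Train the autoencoder to reconstruct its inputs. With inputs scaled to
# [-1, 1] above, a mean squared error reconstruction loss is a reasonable choice.
ae.compile(optimizer=keras.optimizers.Adam(1e-3),
           loss=keras.losses.MeanSquaredError())
ae.fit(X_train, X_train, batch_size=64, epochs=10)
```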
Test the autoencoder
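As a sanity check, we can reconstruct a few test digits and compare them with the originals (an illustrative plot, not a quantitative evaluation):

```python
# Reconstruct a handful of test images and plot original vs. reconstruction.
X_sample = X_test[:5]
X_rec = ae.predict(X_sample)

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i in range(5):
    axes[0, i].imshow(X_sample[i, ..., 0], cmap="gray")
    axes[0, i].axis("off")
    axes[1, i].imshow(X_rec[i, ..., 0], cmap="gray")
    axes[1, i].axis("off")
axes[0, 0].set_title("original")
axes[1, 0].set_title("reconstruction")
plt.show()
```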

Counterfactual with Reinforcement Learning
Define and fit the explainer
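A sketch of defining and fitting the explainer; the sparsity/consistency coefficients, number of training steps, and batch size below are illustrative hyperparameter choices.

```python
explainer = CounterfactualRL(
    predictor=predictor,          # black-box predictor defined above
    encoder=ae.encoder,           # frozen pre-trained encoder
    decoder=ae.decoder,           # pre-trained decoder
    latent_dim=LATENT_DIM,
    coeff_sparsity=7.5,           # sparsity loss coefficient (illustrative)
    coeff_consistency=0.0,        # consistency loss coefficient (illustrative)
    train_steps=50000,            # number of training steps (illustrative)
    batch_size=100,
    backend="tensorflow",
)

# Fit the explainer; only prediction feedback from the black-box is used.
explainer.fit(X=X_train)
```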
Test explainer
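For example, we can request counterfactuals classified as a chosen target digit (here 1; the target choice and batch sizes are arbitrary, and we assume the counterfactuals are exposed under `explanation.data['cf']['X']` as in the CFRL explanation schema):

```python
# Request counterfactuals for a batch of test images with target class 1.
target = 1
explanation = explainer.explain(X=X_test[:100], Y_t=np.array([target]), batch_size=100)

# Retrieve the counterfactuals and check the black-box predictions on them.
X_cf = explanation.data["cf"]["X"]
Y_cf = predictor(X_cf).argmax(axis=1)
print("Fraction of counterfactuals reaching the target:", np.mean(Y_cf == target))
```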

Logging
Logging is clearly important when dealing with deep learning models. Thus, we provide an interface to write custom callbacks for logging purposes, called after each training step. In the following cells we provide some examples of logging to Weights and Biases.
Logging reward callback
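A sketch of such a callback, assuming Weights and Biases is installed and a run has been initialized with `wandb.init`; the `sample` dictionary keys and the `reward_func` entry in `model.params` follow the CFRL callback interface but should be checked against the API reference.

```python
import wandb  # assumes wandb is installed and wandb.init(...) has been called

class RewardCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # Log only every 100 updates to limit overhead.
        if (step + update) % 100 != 0:
            return

        # Current counterfactuals and their targets from the training sample.
        X_cf = sample["X_cf"]
        Y_t = sample["Y_t"]

        # Query the black-box model on the counterfactuals and compute the reward.
        Y_m_cf = predictor(X_cf)
        reward = np.mean(model.params["reward_func"](Y_m_cf, Y_t))
        wandb.log({"reward": reward})
```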
Logging images callback
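Similarly, a sketch of a callback that logs the original instances, their counterfactuals, and the absolute difference between them as images (again assuming an active wandb run and the same `sample` keys):

```python
class ImagesCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        if (step + update) % 100 != 0:
            return

        # Take a few originals and counterfactuals and log them side by side.
        X = sample["X"][:16]
        X_cf = sample["X_cf"][:16]
        diff = np.abs(X - X_cf)

        wandb.log({
            "X": [wandb.Image(img) for img in X],
            "X_cf": [wandb.Image(img) for img in X_cf],
            "diff": [wandb.Image(img) for img in diff],
        })
```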
Logging losses callback
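Finally, a sketch of a callback that simply forwards the training losses passed to the callback interface:

```python
class LossCallback(Callback):
    def __call__(self,
                 step: int,
                 update: int,
                 model: CounterfactualRL,
                 sample: Dict[str, np.ndarray],
                 losses: Dict[str, float]):
        # Log the actor/critic (and auxiliary) losses every 100 updates.
        if (step + update) % 100 == 0:
            wandb.log(losses)
```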
Having defined the callbacks, we can define a new explainer that will include logging.
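For instance, assuming CounterfactualRL accepts a callbacks argument (a list of Callback instances), the explainer with logging could be defined as:

```python
explainer = CounterfactualRL(
    predictor=predictor,
    encoder=ae.encoder,
    decoder=ae.decoder,
    latent_dim=LATENT_DIM,
    coeff_sparsity=7.5,
    coeff_consistency=0.0,
    train_steps=50000,
    batch_size=100,
    backend="tensorflow",
    callbacks=[RewardCallback(), ImagesCallback(), LossCallback()],
)
explainer.fit(X=X_train)
```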