Counterfactual with Reinforcement Learning (CFRL) on Adult Census

This method is described in Model-agnostic and Scalable Counterfactual Explanations via Reinforcement Learning and can generate counterfactual instances for any black-box model. The usual optimization procedure is transformed into a learnable process, which makes it possible to generate batches of counterfactual instances in a single forward pass, even for high-dimensional data. The training pipeline is model-agnostic and relies only on prediction feedback obtained by querying the black-box model. Furthermore, the method allows target and feature conditioning.

We illustrate the method with the TensorFlow backend, which means that all models (the autoencoder, the actor and the critic) are TensorFlow models. Our implementation supports the PyTorch backend as well.

CFRL uses Deep Deterministic Policy Gradient (DDPG), which interleaves learning a state-action function approximator, called the critic, with learning an approximator, called the actor, that predicts the optimal action. The method assumes that the critic is differentiable with respect to the action argument, so the actor's parameters can be optimized efficiently with gradient-based methods.
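The role of the critic's differentiability can be illustrated with a toy example. The sketch below is not DDPG itself: the critic is a fixed, hand-written quadratic rather than a learned network, and the linear actor, the state value and the learning rate are all hypothetical choices made so that the chain rule can be written out by hand.

```python
# Toy deterministic policy-gradient update with a fixed, hand-written critic.
# In DDPG the critic is a learned network; here it is a known quadratic
# (an assumption made purely for illustration).

def critic(action: float) -> float:
    """Hypothetical critic Q(a): its value is highest at action = 3."""
    return -(action - 3.0) ** 2


def critic_grad(action: float) -> float:
    """dQ/da, available because the critic is differentiable in the action."""
    return -2.0 * (action - 3.0)


def actor(theta: float, state: float) -> float:
    """Hypothetical linear actor: a = theta * state."""
    return theta * state


theta, state, learning_rate = 0.0, 1.5, 0.1
for _ in range(200):
    action = actor(theta, state)
    # Chain rule: dQ/dtheta = dQ/da * da/dtheta, with da/dtheta = state.
    theta += learning_rate * critic_grad(action) * state

final_action = actor(theta, state)
print(theta, final_action)
```

Gradient ascent on the critic's output moves the actor's parameter until the predicted action sits at the critic's maximum, which is the mechanism the actor update relies on.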

The DDPG algorithm requires two separate networks, an actor $\mu$ and a critic $Q$. Given the encoded representation of the input instance $z = enc(x)$, the model prediction $y_M$, the target prediction $y_T$ and the conditioning vector $c$, the actor outputs the counterfactual’s latent representation $z_{CF} = \mu(z, y_M, y_T, c)$. The decoder then projects the embedding $z_{CF}$ back to the original input space, followed by optional post-processing.

The training step consists of simultaneously optimizing the actor and critic networks. The critic regresses on the reward $R$ determined by the model prediction, while the actor maximizes the critic’s output for the given instance through $L_{max}$. The actor also minimizes two objectives to encourage the generation of sparse, in-distribution counterfactuals. The sparsity loss $L_{sparsity}$ operates on the decoded counterfactual $x_{CF}$ and combines the $L_1$ loss over the standardized numerical features and the $L_0$ loss over the categorical ones. The consistency loss $L_{consist}$ encourages the encoder to map the counterfactual $x_{CF}$ back to the same latent representation it was decoded from, which helps to produce in-distribution counterfactual instances. Formally, the actor's loss can be written as: $L_{actor} = L_{max} + \lambda_{1}L_{sparsity} + \lambda_{2}L_{consist}$
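As a rough illustration of how these terms fit together for a single instance, here is a minimal sketch in plain Python. The function names, the choice to sum the $L_1$ and $L_0$ parts, the squared-distance form of the consistency term, and the values of $\lambda_1$ and $\lambda_2$ are assumptions made for this example and are not taken from the Alibi implementation.

```python
from typing import Sequence


def sparsity_loss(x_num: Sequence[float], x_cf_num: Sequence[float],
                  x_cat: Sequence[str], x_cf_cat: Sequence[str]) -> float:
    """L1 over standardized numerical features plus L0 over categorical ones."""
    l1 = sum(abs(a - b) for a, b in zip(x_num, x_cf_num))
    l0 = sum(1 for a, b in zip(x_cat, x_cf_cat) if a != b)
    return l1 + l0


def consistency_loss(z_cf: Sequence[float], z_reencoded: Sequence[float]) -> float:
    """Squared distance between z_CF and enc(x_CF) (assumed form)."""
    return sum((a - b) ** 2 for a, b in zip(z_cf, z_reencoded))


def actor_loss(l_max: float, l_sparsity: float, l_consist: float,
               lambda_1: float, lambda_2: float) -> float:
    """L_actor = L_max + lambda_1 * L_sparsity + lambda_2 * L_consist."""
    return l_max + lambda_1 * l_sparsity + lambda_2 * l_consist


# Hypothetical values: one numerical feature moved, one category changed.
l_sparsity = sparsity_loss([0.5, -1.0], [0.5, -0.25],
                           ["Private", "Bachelors"], ["Private", "Dropout"])
l_consist = consistency_loss([1.0, 2.0], [1.0, 1.5])
# The actor is minimized, so L_max is taken as the negated critic output.
l_max = -0.75
total = actor_loss(l_max, l_sparsity, l_consist, lambda_1=0.5, lambda_2=0.5)
print(l_sparsity, l_consist, total)
```

Larger values of $\lambda_1$ push the actor towards counterfactuals that change fewer features, while larger values of $\lambda_2$ keep the counterfactual closer to what the autoencoder can reconstruct.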

This example uses the xgboost library, which can be installed with:

pip install xgboost

Note

To enable support for CounterfactualRLTabular with the TensorFlow backend, you may need to run

pip install alibi[tensorflow]

import os
import numpy as np
import pandas as pd
from copy import deepcopy
from typing import List, Tuple, Dict, Callable

import tensorflow as tf
import tensorflow.keras as keras

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

from alibi.explainers import CounterfactualRLTabular, CounterfactualRL
from alibi.datasets import fetch_adult
from alibi.models.tensorflow import HeAE
from alibi.models.tensorflow import Actor, Critic
from alibi.models.tensorflow import ADULTEncoder, ADULTDecoder
from alibi.explainers.cfrl_base import Callback
from alibi.explainers.backends.cfrl_tabular import get_he_preprocessor, get_statistics, \
    get_conditional_vector, apply_category_mapping

Load Adult Census Dataset

Train black-box classifier

Define the predictor (black-box)

Now that we've trained the classifier, we can define the black-box model. Note that the output of the black-box is a distribution, which can be either a soft-label distribution (probabilities/logits for each class) or a hard-label distribution (one-hot encoding). Internally, CFRL takes the argmax. Moreover, the output does not have to be differentiable.
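To make the expected output format concrete, the following sketch wraps a hypothetical, non-differentiable rule-based model so that it returns a hard-label (one-hot) distribution per row. The rule, the two input features and the helper names are invented for illustration; in the notebook the black-box is the trained classifier.

```python
from typing import List, Sequence


def hard_label_model(row: Sequence[float]) -> int:
    """Hypothetical black box: a hand-written, non-differentiable rule."""
    capital_gain, hours_per_week = row
    return 1 if capital_gain > 5000 or hours_per_week > 60 else 0


def predictor(batch: Sequence[Sequence[float]], num_classes: int = 2) -> List[List[float]]:
    """Return one one-hot distribution per input row."""
    output = []
    for row in batch:
        label = hard_label_model(row)
        output.append([1.0 if k == label else 0.0 for k in range(num_classes)])
    return output


def argmax(distribution: Sequence[float]) -> int:
    """Index of the largest entry, i.e. the predicted class."""
    return max(range(len(distribution)), key=distribution.__getitem__)


batch = [[7298, 40], [0, 50]]  # (Capital Gain, Hours per week), made-up rows
distributions = predictor(batch)
labels = [argmax(d) for d in distributions]
print(distributions, labels)
```

A soft-label predictor would return probabilities or logits in the same shape, and taking the argmax yields the same kind of class index.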

Define and train autoencoder

Instead of directly modelling the perturbation vector in the potentially high-dimensional input space, we first train an autoencoder. The weights of the encoder are frozen and the actor applies the counterfactual perturbations in the latent space of the encoder. The pre-trained decoder maps the counterfactual embedding back to the input feature space.

The autoencoder follows a standard design. The model is composed of two submodules, the encoder and the decoder. The forward pass sends the input through the encoder to obtain the input embedding and then passes the embedding through the decoder.

The heterogeneous variant used in this example adds a type check to ensure that the output of the decoder is a list of tensors.
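The forward pass and the type check can be sketched without any deep learning framework. The class below is a toy stand-in, not the HeAE model imported earlier: the encoder and decoder are hypothetical plain functions, and plain Python lists take the place of tensors.

```python
from typing import Callable, List, Sequence

Vector = List[float]


class ToyHeAE:
    """Toy heterogeneous autoencoder: encoder, decoder, and an output type check."""

    def __init__(self, encoder: Callable[[Sequence[float]], Vector],
                 decoder: Callable[[Vector], object]) -> None:
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, x: Sequence[float]) -> list:
        z = self.encoder(x)       # input -> embedding
        x_hat = self.decoder(z)   # embedding -> reconstruction
        if not isinstance(x_hat, list):
            raise ValueError("The decoder must return a list of outputs, one per head.")
        return x_hat


def toy_encoder(x: Sequence[float]) -> Vector:
    return [0.5 * v for v in x]


def toy_decoder(z: Vector) -> list:
    numerical_head = [2.0 * v for v in z]           # stands in for the numerical output
    total = sum(abs(v) for v in z) or 1.0
    categorical_head = [abs(v) / total for v in z]  # stands in for a categorical distribution
    return [numerical_head, categorical_head]


autoencoder = ToyHeAE(toy_encoder, toy_decoder)
outputs = autoencoder([2.0, 6.0])
print(outputs)
```

Returning one entry per output head lets the numerical reconstruction and each categorical distribution be handled separately downstream.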

Heterogeneous datasets require special treatment. In this work we model the numerical features by normal distributions with constant standard deviation and the categorical features by categorical distributions. Because of this modeling choice, some reconstructed numerical features can end up with a different type than the original ones. For example, a feature like Age of type int can become a float after the autoencoder reconstruction (e.g., Age=26 -> Age=26.3). This behavior can be undesirable, so we cast the features back to their original types when processing the output of the autoencoder (decoder component).
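A minimal sketch of such a casting step is shown below. The helper name and the dictionary-based representation are hypothetical, and rounding to the nearest integer (rather than truncating) is an assumption made for this example.

```python
from typing import Dict, Union

Number = Union[int, float]


def cast_like_original(reconstruction: Dict[str, float],
                       dtypes: Dict[str, type]) -> Dict[str, Number]:
    """Cast reconstructed numerical features back to their original types."""
    casted: Dict[str, Number] = {}
    for name, value in reconstruction.items():
        if dtypes.get(name) is int:
            casted[name] = int(round(value))  # assumption: round, do not truncate
        else:
            casted[name] = value              # leave float features untouched
    return casted


reconstruction = {"Age": 26.3, "Hours per week": 39.6, "Capital Gain": 0.0}
original_dtypes = {"Age": int, "Hours per week": int, "Capital Gain": float}
casted = cast_like_original(reconstruction, original_dtypes)
print(casted)
```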

Counterfactual with Reinforcement Learning

Define dataset specific attributes and constraints

A desirable property of a method for generating counterfactuals is to allow feature conditioning. Real-world datasets usually include immutable features such as Sex or Race, which should remain unchanged throughout the counterfactual search procedure. Similarly, a numerical feature such as Age should only increase for a counterfactual to be actionable.
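The two kinds of constraints described above can be expressed as a simple check on a pair of original and counterfactual rows. This is only a sketch of the idea: the constant names, the helper function and the example rows are hypothetical, and the explainer enforces its constraints during generation rather than by checking afterwards.

```python
from typing import Dict, List

IMMUTABLE_FEATURES = ["Race", "Sex"]  # must stay unchanged
NON_DECREASING_FEATURES = ["Age"]     # may only stay the same or increase


def constraint_violations(original: Dict[str, object],
                          counterfactual: Dict[str, object]) -> List[str]:
    """List every feature-conditioning rule that the counterfactual breaks."""
    problems = []
    for name in IMMUTABLE_FEATURES:
        if original[name] != counterfactual[name]:
            problems.append(f"{name} changed")
    for name in NON_DECREASING_FEATURES:
        if counterfactual[name] < original[name]:
            problems.append(f"{name} decreased")
    return problems


original = {"Age": 39, "Race": "White", "Sex": "Female", "Education": "Masters"}
actionable = {"Age": 39, "Race": "White", "Sex": "Female", "Education": "Dropout"}
not_actionable = {"Age": 30, "Race": "White", "Sex": "Male", "Education": "Masters"}

print(constraint_violations(original, actionable))
print(constraint_violations(original, not_actionable))
```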

Define and fit the explainer

Test explainer

| | Age | Workclass | Education | Marital Status | Occupation | Relationship | Race | Sex | Capital Gain | Capital Loss | Hours per week | Country | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 7298 | 0 | 40 | United-States | >50K |
| 1 | 35 | Private | High School grad | Married | White-Collar | Husband | White | Male | 7688 | 0 | 50 | United-States | >50K |
| 2 | 39 | State-gov | Masters | Married | Professional | Wife | White | Female | 5178 | 0 | 38 | United-States | >50K |
| 3 | 44 | Self-emp-inc | High School grad | Married | Sales | Husband | White | Male | 0 | 0 | 50 | United-States | >50K |
| 4 | 39 | Private | Bachelors | Separated | White-Collar | Not-in-family | White | Female | 13550 | 0 | 50 | United-States | >50K |
| 5 | 45 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 0 | 1902 | 40 | ? | >50K |
| 6 | 50 | Private | Bachelors | Married | Professional | Husband | White | Male | 0 | 0 | 50 | United-States | >50K |
| 7 | 29 | Private | Bachelors | Married | White-Collar | Wife | White | Female | 0 | 0 | 50 | United-States | >50K |
| 8 | 47 | Private | Bachelors | Married | Professional | Husband | White | Male | 0 | 0 | 50 | United-States | >50K |
| 9 | 35 | Private | Bachelors | Married | White-Collar | Husband | White | Male | 0 | 0 | 70 | United-States | >50K |

| | Age | Workclass | Education | Marital Status | Occupation | Relationship | Race | Sex | Capital Gain | Capital Loss | Hours per week | Country | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 320 | 0 | 40 | United-States | <=50K |
| 1 | 35 | Private | Dropout | Married | Blue-Collar | Husband | White | Male | 125 | 0 | 50 | United-States | <=50K |
| 2 | 39 | State-gov | Dropout | Married | Service | Wife | White | Female | 538 | 15 | 39 | United-States | <=50K |
| 3 | 44 | Self-emp-inc | High School grad | Married | Sales | Husband | White | Male | 0 | 0 | 50 | United-States | >50K |
| 4 | 39 | Private | Bachelors | Separated | White-Collar | Not-in-family | White | Female | 1922 | 0 | 51 | United-States | <=50K |
| 5 | 45 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 0 | 1900 | 41 | Latin-America | >50K |
| 6 | 50 | Private | Dropout | Married | Service | Husband | White | Male | 0 | 0 | 51 | United-States | <=50K |
| 7 | 29 | Private | Dropout | Married | Sales | Wife | White | Female | 0 | 0 | 50 | United-States | <=50K |
| 8 | 47 | Private | Dropout | Married | Service | Husband | White | Male | 0 | 0 | 51 | United-States | <=50K |
| 9 | 35 | Private | Dropout | Married | Sales | Husband | White | Male | 0 | 0 | 71 | United-States | <=50K |

Diversity

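| | Age | Workclass | Education | Marital Status | Occupation | Relationship | Race | Sex | Capital Gain | Capital Loss | Hours per week | Country | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 7298 | 0 | 40 | United-States | >50K |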

| | Age | Workclass | Education | Marital Status | Occupation | Relationship | Race | Sex | Capital Gain | Capital Loss | Hours per week | Country | Label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 60 | Private | Dropout | Married | Blue-Collar | Husband | White | Male | 143 | 0 | 40 | United-States | <=50K |
| 1 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 49 | 0 | 40 | United-States | <=50K |
| 2 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 84 | 0 | 40 | United-States | <=50K |
| 3 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 87 | 0 | 41 | United-States | <=50K |
| 4 | 60 | Private | High School grad | Married | Blue-Collar | Husband | White | Male | 97 | 0 | 40 | United-States | <=50K |

Logging

Logging is important when training deep learning models. Thus, we provide an interface for writing custom callbacks that are called after each training step for logging purposes (the Callback class imported above). In the following cells we provide some examples that log to Weights and Biases.

Logging reward callback

Logging losses callback
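A minimal sketch of a loss-logging callback is given below. To keep it self-contained it defines a stand-in base class instead of importing the real one, and it appends to an in-memory list instead of sending data to Weights and Biases. The argument list (step, update, model, sample, losses) and the class names are assumptions about the interface and should be checked against the library.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class Callback(ABC):
    """Stand-in for the callback interface; the argument list is an assumption."""

    @abstractmethod
    def __call__(self, step: int, update: int, model: Any,
                 sample: Dict[str, Any], losses: Dict[str, float]) -> None:
        ...


class LossHistoryCallback(Callback):
    """Record the losses every `every` training steps."""

    def __init__(self, every: int = 100) -> None:
        self.every = every
        self.records: List[Dict[str, float]] = []

    def __call__(self, step: int, update: int, model: Any,
                 sample: Dict[str, Any], losses: Dict[str, float]) -> None:
        if step % self.every != 0:
            return  # skip steps between logging points
        # A real callback would hand this dictionary to the logging service.
        self.records.append({"step": step, **losses})


# Simulated training loop with made-up loss values.
history = LossHistoryCallback(every=2)
for step in range(5):
    history(step=step, update=0, model=None, sample={},
            losses={"actor_loss": 1.0 / (step + 1)})

print(history.records)
```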

Logging tables callback

Having defined the callbacks, we can define a new explainer that will include logging.
