Counterfactuals guided by prototypes on California housing dataset
This notebook goes through an example of prototypical counterfactuals using k-d trees to build the prototypes. Please check out this notebook for a more in-depth application of the method on MNIST using (auto-)encoders and trust scores.
In this example, we will train a simple neural net to predict whether house prices in California districts are above the median value or not. We can then find a counterfactual to see which variables need to be changed to increase or decrease a house price above or below the median value.
Note
To enable support for CounterfactualProto, you may need to run
pip install alibi[tensorflow]
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import os
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from alibi.explainers import CounterfactualProto
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False
TF version: 2.7.4
Eager execution enabled: False
Load and prepare California housing dataset
california = fetch_california_housing(as_frame=True)
X = california.data.to_numpy()
target = california.target.to_numpy()
feature_names = california.feature_names
california.data.head()
0
8.3252
41.0
6.984127
1.023810
322.0
2.555556
37.88
-122.23
1
8.3014
21.0
6.238137
0.971880
2401.0
2.109842
37.86
-122.22
2
7.2574
52.0
8.288136
1.073446
496.0
2.802260
37.85
-122.24
3
5.6431
52.0
5.817352
1.073059
558.0
2.547945
37.85
-122.25
4
3.8462
52.0
6.281853
1.081081
565.0
2.181467
37.85
-122.25
Each row represents a whole census group. Explanation of features:
MedInc
- median income in block groupHouseAge
- median house age in block groupAveRooms
- average number of rooms per householdAveBedrms
- average number of bedrooms per householdPopulation
- block group populationAveOccup
- average number of household membersLatitude
- block group latitudeLongitude
- block group longitude
For more details on the dataset, refer to the scikit-learn documentation.
Transform into classification task: target becomes whether house price is above the overall median or not
y = np.zeros((target.shape[0],))
y[np.where(target > np.median(target))[0]] = 1
Standardize data
mu = X.mean(axis=0)
sigma = X.std(axis=0)
X = (X - mu) / sigma
Define train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
Train model
np.random.seed(42)
tf.random.set_seed(42)
def nn_model():
x_in = Input(shape=(8,))
x = Dense(40, activation='relu')(x_in)
x = Dense(40, activation='relu')(x)
x_out = Dense(2, activation='softmax')(x)
nn = Model(inputs=x_in, outputs=x_out)
nn.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
return nn
nn = nn_model()
nn.summary()
nn.fit(X_train, y_train, batch_size=64, epochs=500, verbose=0)
nn.save('nn_california.h5', save_format='h5')
nn = load_model('nn_california.h5')
score = nn.evaluate(X_test, y_test, verbose=0)
print('Test accuracy: ', score[1])
Test accuracy: 0.87863374
Generate counterfactual guided by the nearest class prototype
Original instance:
X = X_test[1].reshape((1,) + X_test[1].shape)
shape = X.shape
Run counterfactual:
# define model
nn = load_model('nn_california.h5')
# initialize and fit the explainer
cf = CounterfactualProto(nn, shape, use_kdtree=True, theta=10., max_iterations=1000,
feature_range=(X_train.min(axis=0), X_train.max(axis=0)),
c_init=1., c_steps=10)
cf.fit(X_train)
# generate a counterfactual
explanation = cf.explain(X)
The prediction flipped from 0 (value below the median) to 1 (above the median):
print(f'Original prediction: {explanation.orig_class}')
print(f'Counterfactual prediction: {explanation.cf["class"]}')
Original prediction: 0
Counterfactual prediction: 1
Let's take a look at the counterfactual. To make the results more interpretable, we will first undo the pre-processing step and then check where the counterfactual differs from the original instance:
orig = X * sigma + mu
counterfactual = explanation.cf['X'] * sigma + mu
delta = counterfactual - orig
for i, f in enumerate(feature_names):
if np.abs(delta[0][i]) > 1e-4:
print(f'{f}: {delta[0][i]}')
AveOccup: -0.9049749915631999
Latitude: -0.31885583625280134
So in order for the model to consider the census group as having above median house prices, the average occupancy would have to be lower by almost a whole household member, and the location of the census group would need to shift slightly South.
Comparing the original instance and the counterfactual side-by-side:
pd.DataFrame(orig, columns=feature_names)
0
2.5313
30.0
5.039384
1.193493
1565.0
2.679795
35.14
-119.46
pd.DataFrame(counterfactual, columns=feature_names)
0
2.5313
30.0
5.039384
1.193493
1565.000004
1.77482
34.821144
-119.46
Clean up:
os.remove('nn_california.h5')
Last updated
Was this helpful?