Counterfactuals guided by prototypes on California housing dataset
This notebook walks through an example of counterfactuals guided by prototypes, using k-d trees to build the prototypes. For a more in-depth application of the method on MNIST using (auto-)encoders and trust scores, please check out the companion notebook.
In this example, we train a simple neural net to predict whether house prices in California districts are above the median value or not. We can then find a counterfactual to see which variables need to change to move the predicted house price across the median, in either direction.
Note
To enable support for CounterfactualProto, you may need to run
pip install alibi[tensorflow]

%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
tf.get_logger().setLevel(40) # suppress deprecation messages
tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import to_categorical
import os
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from alibi.explainers import CounterfactualProto
print('TF version: ', tf.__version__)
print('Eager execution enabled: ', tf.executing_eagerly()) # False

Load and prepare California housing dataset
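A minimal sketch of how the data could be loaded, assuming scikit-learn's `fetch_california_housing` with `as_frame=True`. The helper name `load_california` and the constant `FEATURE_NAMES` are hypothetical; the column order follows the feature list given later in this notebook.

```python
# Column order as given in the feature explanation of this notebook.
FEATURE_NAMES = ['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms',
                 'Population', 'AveOccup', 'Latitude', 'Longitude']


def load_california():
    """Return (features DataFrame, target array). Hypothetical helper."""
    # Imported lazily: scikit-learn downloads the data on first use.
    from sklearn.datasets import fetch_california_housing

    california = fetch_california_housing(as_frame=True)
    df = california.data[FEATURE_NAMES]
    target = california.target.to_numpy()  # median house value per block group
    return df, target
```

Calling `df, target = load_california()` followed by `df.head()` should then show the first five rows of the feature table.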
   MedInc  HouseAge  AveRooms  AveBedrms  Population  AveOccup  Latitude  Longitude
0  8.3252      41.0  6.984127   1.023810       322.0  2.555556     37.88    -122.23
1  8.3014      21.0  6.238137   0.971880      2401.0  2.109842     37.86    -122.22
2  7.2574      52.0  8.288136   1.073446       496.0  2.802260     37.85    -122.24
3  5.6431      52.0  5.817352   1.073059       558.0  2.547945     37.85    -122.25
4  3.8462      52.0  6.281853   1.081081       565.0  2.181467     37.85    -122.25
Each row represents a whole census block group. Explanation of features:
MedInc - median income in block group
HouseAge - median house age in block group
AveRooms - average number of rooms per household
AveBedrms - average number of bedrooms per household
Population - block group population
AveOccup - average number of household members
Latitude - block group latitude
Longitude - block group longitude
For more details on the dataset, refer to the scikit-learn documentation.
Transform into classification task: target becomes whether house price is above the overall median or not
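A minimal sketch of this step on toy values, assuming the label is 1 when the target is strictly above the overall median and 0 otherwise. The numbers in `target` are illustrative and not taken from the dataset.

```python
import numpy as np

# Toy median house values; illustrative numbers only.
target = np.array([4.526, 3.585, 0.9, 1.2, 2.0])

# Label 1 where the value is strictly above the overall median, else 0.
median = np.median(target)
y = (target > median).astype(int)
print(median, y)
```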
Standardize data
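As a sketch of what standardization does, the snippet below rescales each feature to zero mean and unit variance and keeps `mu` and `sigma` so the step can be undone later. It uses the MedInc and HouseAge values from the first four rows of the table above; computing the statistics over all rows of this toy array is an assumption about the setup.

```python
import numpy as np

# MedInc and HouseAge for the first four rows of the table above.
X = np.array([[8.3252, 41.0],
              [8.3014, 21.0],
              [7.2574, 52.0],
              [5.6431, 52.0]])

mu = X.mean(axis=0)       # per-feature mean
sigma = X.std(axis=0)     # per-feature standard deviation
X_std = (X - mu) / sigma  # standardized features

# The transformation is invertible, which is what makes the
# counterfactual interpretable on the original scale later on.
X_back = X_std * sigma + mu
```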
Define train and test set
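One way to sketch this step is a simple positional split followed by one-hot encoding of the labels, which a softmax output layer expects. The data, the split point `idx` and the use of `np.eye` are assumptions made for illustration; `to_categorical`, imported above, produces the same kind of encoding.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 8))     # 10 toy instances with 8 features
y = rng.integers(0, 2, size=10)  # toy binary labels

idx = 8  # hypothetical split point: first 8 rows for training
X_train, y_train = X[:idx], y[:idx]
X_test, y_test = X[idx:], y[idx:]

# One-hot encode: class k becomes row k of the 2x2 identity matrix.
y_train_oh = np.eye(2)[y_train]
y_test_oh = np.eye(2)[y_test]
```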
Train model
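A sketch of a simple feed-forward classifier for this task, assuming the Keras functional API imported above. The layer width, optimizer and the name `nn_model` are illustrative assumptions, not necessarily the settings used to produce the results shown here.

```python
N_FEATURES, N_CLASSES = 8, 2
HIDDEN_UNITS = 40  # illustrative choice


def nn_model():
    """Build and compile a small two-hidden-layer classifier."""
    # Imported lazily so the sketch can be loaded without TensorFlow installed.
    from tensorflow.keras.layers import Dense, Input
    from tensorflow.keras.models import Model

    x_in = Input(shape=(N_FEATURES,))
    x = Dense(HIDDEN_UNITS, activation='relu')(x_in)
    x = Dense(HIDDEN_UNITS, activation='relu')(x)
    x_out = Dense(N_CLASSES, activation='softmax')(x)  # class probabilities
    nn = Model(inputs=x_in, outputs=x_out)
    nn.compile(loss='categorical_crossentropy', optimizer='sgd',
               metrics=['accuracy'])
    return nn
```

Training would then be a call such as `nn_model().fit(X_train, y_train, batch_size=64, epochs=..., verbose=0)` with one-hot encoded labels; the number of epochs is left open here.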
Generate counterfactual guided by the nearest class prototype
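Conceptually, the prototype that guides the search is a nearby training instance belonging to a class other than the one predicted for the original instance. The sketch below illustrates the idea with a brute-force nearest-neighbour search standing in for the k-d tree query (a k-d tree returns the same neighbour, only faster). The function name and toy data are hypothetical, and how the library assigns classes to training points and selects among neighbours may differ.

```python
import numpy as np


def nearest_other_class_prototype(x, X_train, y_train, orig_class):
    """Return the closest training point (and its class) not in orig_class."""
    mask = y_train != orig_class
    candidates = X_train[mask]
    dists = np.linalg.norm(candidates - x, axis=1)  # Euclidean distances
    best = np.argmin(dists)
    return candidates[best], y_train[mask][best]


# Toy data: two points of class 0 and two of class 1.
X_train = np.array([[0.0, 0.0], [0.2, 0.1], [1.0, 1.0], [2.0, 2.0]])
y_train = np.array([0, 0, 1, 1])
x = np.array([0.1, 0.1])  # assumed to be predicted as class 0

proto, proto_class = nearest_other_class_prototype(x, X_train, y_train, 0)
print(proto, proto_class)
```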
Original instance:
Run counterfactual:
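A hedged sketch of how the explainer could be set up and run with alibi's `CounterfactualProto`. Only `use_kdtree=True` follows from the text; the other values in `PROTO_KWARGS` and the helper name `explain_with_prototypes` are assumptions chosen for illustration.

```python
# Illustrative settings; only use_kdtree=True is implied by the text.
PROTO_KWARGS = dict(
    use_kdtree=True,      # build prototypes with k-d trees (no encoder)
    theta=10.,            # weight of the prototype loss term
    max_iterations=1000,
    c_init=1.,
    c_steps=10,
)


def explain_with_prototypes(nn, X_train, x):
    """Fit the explainer on X_train and explain x of shape (1, n_features)."""
    # Imported lazily: requires `pip install alibi[tensorflow]`.
    from alibi.explainers import CounterfactualProto

    # Keep the counterfactual within the observed range of each feature.
    feature_range = (X_train.min(axis=0).reshape(x.shape),
                     X_train.max(axis=0).reshape(x.shape))
    cf = CounterfactualProto(nn, x.shape, feature_range=feature_range,
                             **PROTO_KWARGS)
    cf.fit(X_train)  # builds the class-specific k-d trees
    return cf.explain(x)
```

With a trained model `nn`, a call such as `explanation = explain_with_prototypes(nn, X_train, X_test[1:2])` returns an explanation whose `orig_class`, `cf['class']` and `cf['X']` fields hold the original prediction, the counterfactual prediction and the counterfactual instance.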
The prediction flipped from 0 (value below the median) to 1 (above the median):
Let's take a look at the counterfactual. To make the results more interpretable, we will first undo the pre-processing step and then check where the counterfactual differs from the original instance:
So for the model to consider the census block group as having above-median house prices, the average occupancy would have to drop by almost a whole household member (from about 2.68 to 1.77), and the block group would need to shift slightly south (latitude 35.14 to about 34.82).
Comparing the original instance and the counterfactual side-by-side:
                MedInc  HouseAge  AveRooms  AveBedrms   Population  AveOccup   Latitude  Longitude
original        2.5313      30.0  5.039384   1.193493       1565.0  2.679795      35.14    -119.46
counterfactual  2.5313      30.0  5.039384   1.193493  1565.000004   1.77482  34.821144    -119.46
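The comparison can be reproduced directly from the two instances above, which are already on the original (unstandardized) scale. The sketch below computes the per-feature difference and keeps only changes above a small tolerance; the tolerance `tol` is a hypothetical choice that filters out numerical noise such as the tiny Population difference.

```python
import numpy as np

feature_names = ['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms',
                 'Population', 'AveOccup', 'Latitude', 'Longitude']

# Values copied from the side-by-side comparison above.
orig = np.array([2.5313, 30.0, 5.039384, 1.193493,
                 1565.0, 2.679795, 35.14, -119.46])
counterfactual = np.array([2.5313, 30.0, 5.039384, 1.193493,
                           1565.000004, 1.77482, 34.821144, -119.46])

delta = counterfactual - orig
tol = 1e-4  # hypothetical tolerance for ignoring numerical noise
changed = {f: d for f, d in zip(feature_names, delta) if abs(d) > tol}

for f, d in changed.items():
    print(f'{f}: {d:.6f}')
```

Only AveOccup and Latitude exceed the tolerance, matching the interpretation given above.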
Clean up:

