Geoff Ruddock

Identify root cause of NULLs using DecisionTreeClassifier

Goal

Suppose you have a complex dataset, one you don’t fully understand, and you want to find out why one field is unexpectedly NULL in some rows.

You could identify potentially relevant dimensions and manually flip through each one to see whether its values correlate with the missingness of that field.
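Here is a minimal sketch of that manual approach on a small hypothetical DataFrame. The column values are made up for illustration. For each candidate dimension, it computes the share of rows where the target field is NULL, so you can look for values where that share stands out.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: 'boat' is missing more often for 3rd class and for 'm'
df = pd.DataFrame({
    "pclass": [1, 1, 2, 2, 3, 3, 3, 3],
    "sex": ["f", "m", "f", "m", "f", "m", "m", "m"],
    "boat": ["4", "7", "12", np.nan, "C", np.nan, np.nan, np.nan],
})

# Boolean mask: True where the field of interest is NULL
target_is_null = df["boat"].isnull()

# For each candidate dimension, compute the NULL rate per distinct value
null_rates = {
    col: target_is_null.groupby(df[col]).mean()
    for col in ["pclass", "sex"]
}

for col, rates in null_rates.items():
    print(col, rates.round(2).to_dict())
```

With two dimensions this is easy. With dozens of columns, and interactions between them, it quickly becomes a chore.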

But this can be tedious. And since you’d effectively be acting like a decision tree, why not try to solve the problem with a decision tree classifier?

💡 Inspiration: Anomalo » Root Causing Data Failures

Prep data

import numpy as np
import pandas as pd

Fetch

from sklearn.datasets import fetch_openml

df, _ = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)

EDA

df.isnull().mean().sort_values(ascending=False).head()
body         0.907563
cabin        0.774637
boat         0.628724
home.dest    0.430863
age          0.200917
dtype: float64

Transform

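# Target: 1 where 'boat' is NULL, 0 where it has a value
y = df['boat'].isnull().astype(int)
# Features: every other column, with categorical columns one-hot encoded
X = df.drop('boat', axis=1).pipe(pd.get_dummies)
from sklearn.impute import SimpleImputer

# Fill the remaining numeric NULLs (e.g. age, fare, body) with column means
imp = SimpleImputer()
imp.fit(X)
X_imputed = imp.transform(X)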
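The transform step can be tried offline on a tiny hypothetical dataset (the column names and values below are invented for illustration). This sketch builds the same kind of 0/1 "is NULL" target, one-hot encodes the categorical column, and mean-imputes the numeric gaps. Passing `dtype=float` to `get_dummies` is a small assumption on my part that keeps every feature column numeric for the imputer.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Hypothetical mini-dataset mixing numeric, categorical and missing values
df = pd.DataFrame({
    "age": [22.0, np.nan, 35.0, 58.0],
    "embarked": ["S", "C", None, "S"],
    "boat": ["5", np.nan, np.nan, "9"],
})

# Target: 1 where the field of interest is NULL, 0 where it exists
y = df["boat"].isnull().astype(int)

# Features: every other column; 'embarked' becomes embarked_C / embarked_S.
# A missing 'embarked' simply yields all-zero dummy columns for that row.
X = pd.get_dummies(df.drop("boat", axis=1), dtype=float)

# Replace remaining NaNs (here only 'age') with the column mean
imp = SimpleImputer(strategy="mean")
X_imputed = imp.fit_transform(X)

print(list(X.columns))
print(X_imputed)
```

One design choice worth noting: mean imputation hides *which* rows were missing in other columns. If that missingness might itself be part of the root cause, `SimpleImputer(add_indicator=True)` appends binary missing-indicator columns that the tree can split on.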

Model

from sklearn.tree import DecisionTreeClassifier, export_graphviz
import graphviz

Fit

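model = DecisionTreeClassifier(
    random_state=42,
    # Each leaf must hold at least 10% of the (unweighted) samples
    min_weight_fraction_leaf=0.1,
    # min_samples_leaf=100,  # alternative: an absolute minimum leaf size
)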

model.fit(X_imputed, y, sample_weight=None)
DecisionTreeClassifier(min_weight_fraction_leaf=0.1, random_state=42)
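To see what the fitted model gives you beyond the diagram, here is a self-contained sketch on synthetic data. The columns `source_legacy` and `noise` are hypothetical, and the target is constructed so that NULLs occur exactly when `source_legacy == 1`. It checks that, without sample weights, `min_weight_fraction_leaf=0.1` forces every leaf to contain at least 10% of the rows, and it ranks `feature_importances_` to shortlist candidate root causes.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 1000

# Hypothetical synthetic data: the field is NULL whenever source_legacy == 1,
# while 'noise' has no relationship with the missingness
X = pd.DataFrame({
    "source_legacy": rng.integers(0, 2, n),
    "noise": rng.normal(size=n),
})
y = X["source_legacy"].to_numpy()  # 1 = NULL, 0 = value exists

model = DecisionTreeClassifier(random_state=42, min_weight_fraction_leaf=0.1)
model.fit(X, y)

# apply() returns the leaf index for each row; count rows per leaf
leaf_ids = model.apply(X)
leaf_sizes = np.bincount(leaf_ids)[np.unique(leaf_ids)]
print("smallest leaf share:", leaf_sizes.min() / n)

# Rank features by impurity-based importance
importances = pd.Series(model.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))
```

Keep in mind that impurity-based importances can favour high-cardinality features, so treat the ranking as a starting point for inspecting the tree, not as a verdict.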

Visualize

dot_data = export_graphviz(
    model,
    out_file=None,
    feature_names=list(X.columns),
    class_names=['Exists', 'Null'],  # class 0 = value present, class 1 = NULL
    filled=True,
    rounded=True,
    proportion=True,
)

graph = graphviz.Source(dot_data)
graph
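If the Graphviz binaries aren't available, scikit-learn's `export_text` prints the same split rules as plain text. This sketch uses hypothetical synthetic data in which the field is NULL for about 90% of rows with `pclass == 3` and never otherwise, so the root split should isolate third class.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
n = 1000

# Hypothetical data: NULLs concentrated in pclass == 3; 'age' is irrelevant
pclass = rng.integers(1, 4, n)
age = rng.uniform(1, 80, n)
is_null = (pclass == 3) & (rng.random(n) < 0.9)

X = pd.DataFrame({"pclass": pclass, "age": age})
y = is_null.astype(int)  # 1 = NULL, 0 = value exists

model = DecisionTreeClassifier(random_state=42, min_weight_fraction_leaf=0.1)
model.fit(X, y)

# Plain-text rendering of the fitted tree; no graphviz binary required
rules = export_text(model, feature_names=list(X.columns))
print(rules)
```

The text form is also easier to paste into a ticket or diff between runs than an SVG.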


