
Scikit-Learn ML Model Explainability

SHAP (SHapley Additive exPlanations) is a game theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions.
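To make that additivity concrete, here is a minimal sketch on hypothetical toy data (the RandomForestRegressor and random features are illustrative stand-ins, not the model built below): the base value plus the per-feature SHAP values reconstructs each prediction.

import numpy as np
import shap
from sklearn.ensemble import RandomForestRegressor

# hypothetical toy data: 100 samples, 3 features
X = np.random.RandomState(0).rand(100, 3)
y = 2 * X[:, 0] + X[:, 1]

model = RandomForestRegressor(random_state=0).fit(X, y)
explainer = shap.Explainer(model.predict, X)  # model-agnostic explainer
sv = explainer(X[:5])

# Shapley efficiency/additivity: base value + sum of attributions == model output
print(np.allclose(sv.base_values + sv.values.sum(axis=1), model.predict(X[:5])))
# expected: True (up to numerical tolerance)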

import boto3
from datetime import date
import lightgbm as lgb
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap

from sklearn import metrics
from sklearn.model_selection import StratifiedShuffleSplit, GridSearchCV, train_test_split


Dataset

df_train = pd.read_csv('data/train.csv')
df_test = pd.read_csv('data/test.csv')

df = pd.concat([df_train, df_test])
df.shape

(21287, 14)

df.head().transpose()

| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| enrollee_id | 8949 | 29725 | 11561 | 33241 | 666 |
| city | city_103 | city_40 | city_21 | city_115 | city_162 |
| city_development_index | 0.92 | 0.776 | 0.624 | 0.789 | 0.767 |
| gender | Male | Male | NaN | NaN | Male |
| relevent_experience | Has relevent experience | No relevent experience | No relevent experience | No relevent experience | Has relevent experience |
| enrolled_university | no_enrollment | no_enrollment | Full time course | NaN | no_enrollment |
| education_level | Graduate | Graduate | Graduate | Graduate | Masters |
| major_discipline | STEM | STEM | STEM | Business Degree | STEM |
| experience | >20 | 15 | 5 | <1 | >20 |
| company_size | NaN | 50-99 | NaN | NaN | 50-99 |
| company_type | NaN | Pvt Ltd | NaN | Pvt Ltd | Funded Startup |
| last_new_job | 1 | >4 | never | never | 4 |
| training_hours | 36 | 47 | 83 | 52 | 8 |
| target | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 |

Preprocessing

# can ID be used as unique identifier?
df['enrollee_id'].nunique() == df.shape[0]
# there are as many IDs as there are rows in the dataset

True

df['duplicated'] = df.drop(columns=['enrollee_id']).duplicated()
df['duplicated'].value_counts()
# 50 rows are exact duplicates once enrollee_id is ignored

False    21237
True        50
Name: duplicated, dtype: int64

# drop all rows that are duplicated entries
df.drop(df[df['duplicated'] == True].index, inplace=True)
df['duplicated'].value_counts()

False    21235
Name: duplicated, dtype: int64

for col in df.columns:
    print(f' \n {col.upper()} :: {df[col].nunique()} UNIQUE :: {df[col].isnull().sum()} NULL')

ENROLLEE_ID :: 21235 UNIQUE :: 0 NULL

CITY :: 123 UNIQUE :: 0 NULL

CITY_DEVELOPMENT_INDEX :: 93 UNIQUE :: 0 NULL

GENDER :: 3 UNIQUE :: 5015 NULL

RELEVENT_EXPERIENCE :: 2 UNIQUE :: 0 NULL

ENROLLED_UNIVERSITY :: 3 UNIQUE :: 417 NULL

EDUCATION_LEVEL :: 5 UNIQUE :: 512 NULL

MAJOR_DISCIPLINE :: 6 UNIQUE :: 3121 NULL

EXPERIENCE :: 22 UNIQUE :: 70 NULL

COMPANY_SIZE :: 8 UNIQUE :: 6541 NULL

COMPANY_TYPE :: 6 UNIQUE :: 6755 NULL

LAST_NEW_JOB :: 6 UNIQUE :: 463 NULL

TRAINING_HOURS :: 241 UNIQUE :: 0 NULL

TARGET :: 2 UNIQUE :: 2127 NULL

DUPLICATED :: 1 UNIQUE :: 0 NULL

# remove all entries that don't have a target value
df.dropna(subset=['target'], how='all', inplace=True)
df['target'].value_counts()
# we need a model that can handle imbalanced datasets -> gradient boosted decision trees (LightGBM); see the sketch after the output below

0.0    14343
1.0     4765
Name: target, dtype: int64
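Given this roughly 3:1 class imbalance, LightGBM also offers built-in reweighting. A hedged sketch of two options (values illustrative, not the configuration tuned below):

# option 1: automatic reweighting of the minority class
clf_auto = lgb.LGBMClassifier(objective='binary', is_unbalance=True)

# option 2: explicit positive-class weight, roughly n_negative / n_positive
clf_manual = lgb.LGBMClassifier(objective='binary', scale_pos_weight=14343 / 4765)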

df['company_size'].value_counts()

50-99        3080
100-500      2560
10000+       2009
10/49        1470
1000-4999    1322
<10          1308
500-999       877
5000-9999     562
Name: company_size, dtype: int64

df['company_size'] = df['company_size'].apply(lambda x: '10-49' if x == '10/49' else x)
df['company_size'].value_counts()

50-99        3080
100-500      2560
10000+       2009
10-49        1470
1000-4999    1322
<10          1308
500-999       877
5000-9999     562
Name: company_size, dtype: int64
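The same mapping can be written more idiomatically without a lambda using pandas' Series.replace:

# equivalent to the apply/lambda above
df['company_size'] = df['company_size'].replace('10/49', '10-49')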

df['training_hours'].hist()

[Figure: histogram of training_hours]

sns.boxplot(data=df, x='education_level', y='training_hours', hue='target')

[Figure: box plot of training_hours per education_level, grouped by target]

# drop columns that will not have an influence on the target variable
df_train = df.drop(['duplicated', 'enrollee_id'], axis=1)
# make target variable int
df_train['target'] = pd.to_numeric(df_train['target'])
df_train['target'] = df_train['target'].astype(int)

Data Transformation

# encode categorical columns
categorical_columns = []

for col in df_train.columns:
    print(f'INFO :: Column {col} is of type {df_train[col].dtype}.')
    if df_train[col].dtype == 'object':
        print('PROCESSING :: Added for dummy variable creation.')
        categorical_columns.append(col)

df_train = pd.get_dummies(df_train, columns=categorical_columns)

INFO :: Column city is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column city_development_index is of type float64.

INFO :: Column gender is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column relevent_experience is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column enrolled_university is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column education_level is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column major_discipline is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column experience is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column company_size is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column company_type is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column last_new_job is of type object.
PROCESSING :: Added for dummy variable creation.

INFO :: Column training_hours is of type int64.

INFO :: Column target is of type int64.
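As an aside, one-hot encoding is not strictly required with LightGBM: it can split on categorical columns natively when they carry the pandas category dtype. A sketch of that alternative (not used in the rest of this post):

# convert object columns to 'category' instead of creating dummy variables
df_native = df.drop(['duplicated', 'enrollee_id'], axis=1).copy()
for col in df_native.select_dtypes('object').columns:
    df_native[col] = df_native[col].astype('category')
# LightGBM's sklearn API then picks these columns up automatically
# (categorical_feature='auto' is the default)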

# correlation to target
def gen_correlation_df(dataframe, target):
    correlations = {
        'variable': [],
        'corr': []
    }

    for col in list(dataframe.columns):
        if dataframe[col].dtype not in ['object', '<M8[ns]'] and col != target:
            r = dataframe[target].corr(dataframe[col])
            correlations['variable'].append(col)
            correlations['corr'].append(r)

    return pd.DataFrame(correlations).sort_values(by=['corr'], ascending=False)

df_cor = gen_correlation_df(df_train, 'target')
df_cor
| | variable | corr |
|---:|---|---:|
| 66 | city_city_21 | 0.319434 |
| 130 | enrolled_university_Full time course | 0.149529 |
| 129 | relevent_experience_No relevent experience | 0.127994 |
| 13 | city_city_11 | 0.091457 |
| 133 | education_level_Graduate | 0.087214 |
| 164 | experience_<1 | 0.079047 |
| 144 | experience_1 | 0.069583 |
| 157 | experience_3 | 0.066304 |
| 26 | city_city_128 | 0.050778 |
| 143 | major_discipline_STEM | 0.050013 |
| 158 | experience_4 | 0.049482 |
| 5 | city_city_101 | 0.047004 |
| 155 | experience_2 | 0.046997 |
| 104 | city_city_74 | 0.046142 |
| 185 | last_new_job_never | 0.045534 |
| 41 | city_city_145 | 0.044923 |
| 46 | city_city_155 | 0.033568 |
| 180 | last_new_job_1 | 0.029297 |
| 75 | city_city_33 | 0.027429 |
| 81 | city_city_42 | 0.026713 |
| 24 | city_city_126 | 0.025354 |
| 159 | experience_5 | 0.025253 |
| 108 | city_city_78 | 0.024855 |
| 161 | experience_7 | 0.024617 |
| 63 | city_city_19 | 0.023563 |
| 34 | city_city_139 | 0.020590 |
| 100 | city_city_70 | 0.020259 |
| 160 | experience_6 | 0.019471 |
| 82 | city_city_43 | 0.019350 |
| 86 | city_city_48 | 0.017434 |
| 16 | city_city_115 | 0.017167 |
| 11 | city_city_107 | 0.017094 |
| 118 | city_city_90 | 0.015416 |
| 39 | city_city_143 | 0.015099 |
| 60 | city_city_179 | 0.013111 |
| 56 | city_city_171 | 0.012551 |
| 119 | city_city_91 | 0.011924 |
| 121 | city_city_94 | 0.011539 |
| 125 | gender_Female | 0.008742 |
| 69 | city_city_25 | 0.008382 |
| 52 | city_city_162 | 0.007534 |
| 19 | city_city_118 | 0.007300 |
| 48 | city_city_158 | 0.006651 |
| 38 | city_city_142 | 0.006401 |
| 142 | major_discipline_Other | 0.006048 |
| 83 | city_city_44 | 0.005959 |
| 109 | city_city_79 | 0.005943 |
| 42 | city_city_146 | 0.005943 |
| 88 | city_city_53 | 0.004976 |
| 85 | city_city_46 | 0.004568 |
| 12 | city_city_109 | 0.004213 |
| 10 | city_city_106 | 0.004213 |
| 29 | city_city_131 | 0.004213 |
| 103 | city_city_73 | 0.004204 |
| 139 | major_discipline_Business Degree | 0.004155 |
| 106 | city_city_76 | 0.003626 |
| 18 | city_city_117 | 0.003517 |
| 59 | city_city_176 | 0.003467 |
| 115 | city_city_84 | 0.003467 |
| 31 | city_city_134 | 0.003260 |
| 127 | gender_Other | 0.002882 |
| 92 | city_city_59 | 0.002678 |
| 55 | city_city_167 | 0.002678 |
| 23 | city_city_123 | 0.002450 |
| 22 | city_city_121 | 0.002432 |
| 40 | city_city_144 | 0.002387 |
| 90 | city_city_55 | 0.002275 |
| 117 | city_city_9 | 0.002016 |
| 131 | enrolled_university_Part time course | 0.001623 |
| 62 | city_city_180 | 0.001608 |
| 35 | city_city_14 | 0.000056 |
| 61 | city_city_18 | 0.000021 |
| 74 | city_city_31 | 0.000021 |
| 54 | city_city_166 | 0.000021 |
| 44 | city_city_150 | -0.000435 |
| 141 | major_discipline_No Major | -0.000687 |
| 116 | city_city_89 | -0.001449 |
| 177 | company_type_Other | -0.001790 |
| 89 | city_city_54 | -0.002196 |
| 58 | city_city_175 | -0.002196 |
| 43 | city_city_149 | -0.002384 |
| 30 | city_city_133 | -0.002611 |
| 25 | city_city_127 | -0.002611 |
| 162 | experience_8 | -0.003014 |
| 68 | city_city_24 | -0.003108 |
| 70 | city_city_26 | -0.003364 |
| 21 | city_city_120 | -0.003388 |
| 4 | city_city_100 | -0.003633 |
| 45 | city_city_152 | -0.004028 |
| 36 | city_city_140 | -0.004170 |
| 53 | city_city_165 | -0.004531 |
| 112 | city_city_81 | -0.004713 |
| 120 | city_city_93 | -0.004869 |
| 156 | experience_20 | -0.005391 |
| 99 | city_city_7 | -0.005581 |
| 20 | city_city_12 | -0.005761 |
| 174 | company_type_Early Stage Startup | -0.005792 |
| 71 | city_city_27 | -0.006643 |
| 77 | city_city_37 | -0.006666 |
| 14 | city_city_111 | -0.007223 |
| 27 | city_city_129 | -0.007223 |
| 80 | city_city_41 | -0.007451 |
| 111 | city_city_80 | -0.007518 |
| 181 | last_new_job_2 | -0.007570 |
| 51 | city_city_160 | -0.007645 |
| 73 | city_city_30 | -0.008166 |
| 110 | city_city_8 | -0.008340 |
| 113 | city_city_82 | -0.008340 |
| 37 | city_city_141 | -0.008801 |
| 47 | city_city_157 | -0.008868 |
| 98 | city_city_69 | -0.009085 |
| 94 | city_city_62 | -0.009325 |
| 146 | experience_11 | -0.009632 |
| 166 | company_size_10-49 | -0.010248 |
| 138 | major_discipline_Arts | -0.010679 |
| 64 | city_city_2 | -0.011034 |
| 102 | city_city_72 | -0.011176 |
| 2 | city_city_1 | -0.011431 |
| 65 | city_city_20 | -0.013150 |
| 78 | city_city_39 | -0.013833 |
| 182 | last_new_job_3 | -0.013902 |
| 79 | city_city_40 | -0.014132 |
| 28 | city_city_13 | -0.014426 |
| 124 | city_city_99 | -0.014593 |
| 17 | city_city_116 | -0.014710 |
| 183 | last_new_job_4 | -0.015328 |
| 178 | company_type_Public Sector | -0.015365 |
| 49 | city_city_159 | -0.015976 |
| 140 | major_discipline_Humanities | -0.016998 |
| 163 | experience_9 | -0.017082 |
| 114 | city_city_83 | -0.017769 |
| 148 | experience_13 | -0.019034 |
| 84 | city_city_45 | -0.019214 |
| 95 | city_city_64 | -0.019522 |
| 9 | city_city_105 | -0.020172 |
| 107 | city_city_77 | -0.020649 |
| 91 | city_city_57 | -0.020956 |
| 145 | experience_10 | -0.021133 |
| 1 | training_hours | -0.021179 |
| 154 | experience_19 | -0.022050 |
| 123 | city_city_98 | -0.022057 |
| 3 | city_city_10 | -0.022491 |
| 6 | city_city_102 | -0.023017 |
| 147 | experience_12 | -0.023773 |
| 176 | company_type_NGO | -0.024453 |
| 96 | city_city_65 | -0.024939 |
| 57 | city_city_173 | -0.025485 |
| 152 | experience_17 | -0.025806 |
| 153 | experience_18 | -0.027003 |
| 172 | company_size_5000-9999 | -0.027311 |
| 149 | experience_14 | -0.027454 |
| 101 | city_city_71 | -0.028219 |
| 87 | city_city_50 | -0.028243 |
| 122 | city_city_97 | -0.029486 |
| 33 | city_city_138 | -0.030509 |
| 137 | education_level_Primary School | -0.034393 |
| 76 | city_city_36 | -0.035708 |
| 93 | city_city_61 | -0.036076 |
| 67 | city_city_23 | -0.036596 |
| 150 | experience_15 | -0.036696 |
| 136 | education_level_Phd | -0.037587 |
| 171 | company_size_500-999 | -0.038555 |
| 72 | city_city_28 | -0.039877 |
| 151 | experience_16 | -0.040968 |
| 97 | city_city_67 | -0.041123 |
| 134 | education_level_High School | -0.042942 |
| 105 | city_city_75 | -0.043489 |
| 135 | education_level_Masters | -0.043751 |
| 7 | city_city_103 | -0.044411 |
| 8 | city_city_104 | -0.045717 |
| 169 | company_size_10000+ | -0.046135 |
| 173 | company_size_<10 | -0.048945 |
| 175 | company_type_Funded Startup | -0.059407 |
| 32 | city_city_136 | -0.059726 |
| 168 | company_size_1000-4999 | -0.061808 |
| 184 | last_new_job_>4 | -0.069695 |
| 170 | company_size_50-99 | -0.073381 |
| 126 | gender_Male | -0.074494 |
| 167 | company_size_100-500 | -0.080041 |
| 50 | city_city_16 | -0.090270 |
| 15 | city_city_114 | -0.094945 |
| 165 | experience_>20 | -0.100802 |
| 128 | relevent_experience_Has relevent experience | -0.127994 |
| 132 | enrolled_university_no_enrollment | -0.140493 |
| 179 | company_type_Pvt Ltd | -0.161860 |
| 0 | city_development_index | -0.341331 |
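For reference, the same ranking can be produced in one line with pandas, which should match the helper above since all remaining columns are numeric after get_dummies:

df_train.drop(columns=['target']).corrwith(df_train['target']).sort_values(ascending=False)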

Model Training

# train-test split
train, test = train_test_split(df_train, test_size=0.2)
# remove target label from training dataset
X_cols = list(set(train.columns) - set(['target']))

def split_xy(train_in, test_in, X_cols, target):
    train = train_in.copy()
    test = test_in.copy()

    X_train = train[X_cols].copy()
    X_test = test[X_cols].copy()

    y_train = train[target]
    y_test = test[target]

    return X_train, X_test, y_train, y_test

X_train, X_test, y_train, y_test = split_xy(train, test, X_cols, 'target')
# verify the 80/20 split: the training share should be ~80%
print(X_train.shape[0] / (X_train.shape[0] + X_test.shape[0]) * 100)

79.99790663596399
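Since the target is imbalanced, stratifying the hold-out split would keep the class ratio identical in train and test; a sketch (the random_state is arbitrary):

train, test = train_test_split(df_train, test_size=0.2, stratify=df_train['target'], random_state=77)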

# create shuffled splits
cv_split = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=77)

# create search params
param_grid = {
    'learning_rate': [0.05],
    'n_estimators': [400],
    'max_depth': [5, 7],
    'min_child_samples': [50, 200, 600],
    'num_leaves': [31, 63]
}

def fit_lgb(param_grid, X, y, cv):
    gbm = lgb.LGBMClassifier(
        objective='binary',
        feature_fraction=0.9
    )

    grid_search = GridSearchCV(
        gbm,
        param_grid=param_grid,
        scoring='roc_auc',
        cv=cv,
        verbose=1,
        n_jobs=-1
    )

    grid_search.fit(X=X, y=y.values.ravel())

    print(grid_search.best_score_)
    print(grid_search.best_estimator_)
    print(grid_search.best_params_)

    return grid_search

grid_search = fit_lgb(param_grid, X_train, y_train, cv_split)

# 0.8067351303514142
# LGBMClassifier(feature_fraction=0.9, learning_rate=0.05, max_depth=5,
# min_child_samples=50, n_estimators=400, objective='binary')
# {'learning_rate': 0.05, 'max_depth': 5, 'min_child_samples': 50, 'n_estimators': 400, 'num_leaves': 31}
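To see how the other candidates fared, the full grid-search results can be inspected as a DataFrame:

cv_results = pd.DataFrame(grid_search.cv_results_)
cv_results[['params', 'mean_test_score', 'std_test_score', 'rank_test_score']].sort_values('rank_test_score')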

Model Evaluation

# pass test data to trained model
y_pred = grid_search.best_estimator_.predict_proba(X_test)[:,1]

[LightGBM] [Warning] feature_fraction is set=0.9, colsample_bytree=1.0 will be ignored. Current value: feature_fraction=0.9
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).

auc = metrics.roc_auc_score(y_true=y_test, y_score=y_pred)
auc

0.7834230893204323
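To visualize the trade-off behind this score, the ROC curve can be plotted from the same predictions:

# ROC curve for the tuned model
fpr, tpr, _ = metrics.roc_curve(y_true=y_test, y_score=y_pred)
plt.plot(fpr, tpr, label=f'LightGBM (AUC = {auc:.3f})')
plt.plot([0, 1], [0, 1], linestyle='--', label='chance')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()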

feature_importance = pd.Series(grid_search.best_estimator_.feature_importances_, index=X_cols).sort_values(ascending=False)
feature_importance

| feature | importance |
|---|---:|
| training_hours | 1530 |
| city_development_index | 651 |
| last_new_job_1 | 208 |
| company_type_Pvt Ltd | 205 |
| education_level_Graduate | 126 |
| company_size_50-99 | 112 |
| relevent_experience_No relevent experience | 109 |
| city_city_103 | 102 |
| enrolled_university_no_enrollment | 85 |
| major_discipline_STEM | 80 |
| experience_3 | 79 |
| last_new_job_never | 77 |
| city_city_21 | 73 |
| company_type_Funded Startup | 68 |
| gender_Male | 66 |
| company_type_Public Sector | 64 |
| company_size_1000-4999 | 62 |
| company_size_<10 | 58 |
| experience_>20 | 56 |
| company_size_100-500 | 56 |
| gender_Female | 56 |
| enrolled_university_Full time course | 54 |
| company_type_Early Stage Startup | 54 |
| education_level_Masters | 53 |
| experience_4 | 52 |
| last_new_job_2 | 51 |
| experience_<1 | 49 |
| last_new_job_4 | 47 |
| company_size_10000+ | 45 |
| company_type_NGO | 44 |
| relevent_experience_Has relevent experience | 42 |
| company_size_500-999 | 40 |
| experience_2 | 39 |
| experience_8 | 30 |
| enrolled_university_Part time course | 30 |
| city_city_100 | 29 |
| experience_9 | 27 |
| company_size_5000-9999 | 25 |
| experience_17 | 25 |
| company_size_10-49 | 24 |
| city_city_16 | 23 |
| experience_11 | 23 |
| major_discipline_Humanities | 23 |
| city_city_160 | 23 |
| major_discipline_No Major | 22 |
| experience_14 | 22 |
| major_discipline_Business Degree | 21 |
| city_city_67 | 20 |
| experience_10 | 19 |
| city_city_136 | 19 |
| last_new_job_>4 | 18 |
| major_discipline_Other | 18 |
| education_level_High School | 17 |
| last_new_job_3 | 17 |
| experience_13 | 16 |
| experience_16 | 16 |
| experience_12 | 14 |
| city_city_114 | 13 |
| education_level_Primary School | 13 |
| experience_15 | 10 |
| experience_6 | 10 |
| city_city_102 | 9 |
| experience_7 | 9 |
| experience_5 | 6 |
| city_city_104 | 6 |
| experience_19 | 3 |
| city_city_11 | 3 |
| experience_1 | 3 |
| city_city_142 | 0 |
| city_city_74 | 0 |
| city_city_141 | 0 |
| city_city_144 | 0 |
| city_city_9 | 0 |
| city_city_138 | 0 |
| city_city_175 | 0 |
| city_city_26 | 0 |
| city_city_109 | 0 |
| city_city_70 | 0 |
| city_city_77 | 0 |
| city_city_94 | 0 |
| city_city_149 | 0 |
| city_city_2 | 0 |
| city_city_75 | 0 |
| city_city_64 | 0 |
| city_city_159 | 0 |
| education_level_Phd | 0 |
| city_city_18 | 0 |
| city_city_166 | 0 |
| city_city_81 | 0 |
| city_city_45 | 0 |
| city_city_13 | 0 |
| city_city_55 | 0 |
| city_city_146 | 0 |
| city_city_37 | 0 |
| city_city_28 | 0 |
| city_city_123 | 0 |
| city_city_116 | 0 |
| city_city_76 | 0 |
| city_city_72 | 0 |
| major_discipline_Arts | 0 |
| city_city_57 | 0 |
| city_city_73 | 0 |
| city_city_12 | 0 |
| city_city_33 | 0 |
| city_city_59 | 0 |
| city_city_126 | 0 |
| city_city_25 | 0 |
| city_city_7 | 0 |
| city_city_118 | 0 |
| city_city_107 | 0 |
| city_city_139 | 0 |
| city_city_176 | 0 |
| city_city_179 | 0 |
| city_city_39 | 0 |
| experience_18 | 0 |
| city_city_143 | 0 |
| city_city_127 | 0 |
| city_city_69 | 0 |
| city_city_105 | 0 |
| city_city_80 | 0 |
| city_city_71 | 0 |
| city_city_117 | 0 |
| city_city_140 | 0 |
| city_city_50 | 0 |
| city_city_131 | 0 |
| city_city_150 | 0 |
| city_city_171 | 0 |
| city_city_93 | 0 |
| city_city_120 | 0 |
| city_city_24 | 0 |
| city_city_53 | 0 |
| city_city_41 | 0 |
| city_city_91 | 0 |
| city_city_1 | 0 |
| city_city_152 | 0 |
| experience_20 | 0 |
| city_city_54 | 0 |
| city_city_133 | 0 |
| city_city_145 | 0 |
| city_city_173 | 0 |
| city_city_82 | 0 |
| city_city_106 | 0 |
| city_city_78 | 0 |
| city_city_98 | 0 |
| city_city_27 | 0 |
| city_city_157 | 0 |
| city_city_180 | 0 |
| city_city_79 | 0 |
| city_city_99 | 0 |
| city_city_8 | 0 |
| city_city_40 | 0 |
| city_city_31 | 0 |
| city_city_19 | 0 |
| city_city_46 | 0 |
| city_city_42 | 0 |
| city_city_128 | 0 |
| city_city_20 | 0 |
| city_city_97 | 0 |
| city_city_36 | 0 |
| city_city_101 | 0 |
| city_city_44 | 0 |
| city_city_134 | 0 |
| city_city_10 | 0 |
| gender_Other | 0 |
| city_city_115 | 0 |
| company_type_Other | 0 |
| city_city_158 | 0 |
| city_city_162 | 0 |
| city_city_129 | 0 |
| city_city_61 | 0 |
| city_city_30 | 0 |
| city_city_84 | 0 |
| city_city_90 | 0 |
| city_city_83 | 0 |
| city_city_165 | 0 |
| city_city_48 | 0 |
| city_city_65 | 0 |
| city_city_23 | 0 |
| city_city_43 | 0 |
| city_city_155 | 0 |
| city_city_111 | 0 |
| city_city_121 | 0 |
| city_city_14 | 0 |
| city_city_62 | 0 |
| city_city_167 | 0 |
| city_city_89 | 0 |

plt.figure(figsize=(48, 5))
plt.plot(feature_importance)
plt.xticks(rotation=45)

plt.savefig('assets/Feature_Importance.png', bbox_inches='tight')

[Figure: Feature_Importance.png - features sorted by importance in descending order]

Model Explainability

# fit the explainer
explainer = shap.Explainer(grid_search.best_estimator_.predict, X_test)
# calculate shap values
shap_values = explainer(X_test)

The waterfall plot below shows how individual features contribute to pushing the model output from the base value (the average model output over the dataset we passed to the explainer) to the final prediction. Features pushing the prediction higher are shown in red, those pushing it lower in blue.

# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[0])

[Figure: SHAP waterfall plot for the first prediction]
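A quick sanity check of the additivity just described. Note that because the explainer wraps predict, the SHAP values here explain the hard 0/1 label; wrapping predict_proba would explain the predicted probability instead:

# base value + sum of SHAP values should equal the model output for this sample
print(shap_values.base_values[0] + shap_values.values[0].sum())
print(grid_search.best_estimator_.predict(X_test)[0])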

The mean absolute value of the SHAP values for each feature:

# show mean absolute effect
shap.plots.bar(shap_values)

[Figure: bar chart of the mean absolute SHAP value per feature]
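These bar values can be reproduced by hand from the raw SHAP values:

# mean absolute SHAP value per feature, largest first
mean_abs_shap = pd.Series(np.abs(shap_values.values).mean(axis=0), index=X_test.columns)
mean_abs_shap.sort_values(ascending=False).head(10)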

To get an overview of which features are most important to the model, we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples and uses SHAP values to show the distribution of the impact each feature has on the model output. The color represents the feature value (red high, blue low). This reveals, for example, that a high city_development_index lowers the predicted target, i.e. the chance that an employee is searching for a new job.

shap.summary_plot(shap_values, plot_size=(12,12), max_display=30)

[Figure: SHAP summary (beeswarm) plot of the top 30 features]
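To drill further into the strongest feature, a SHAP scatter (dependence) plot relates the value of city_development_index to its attribution:

shap.plots.scatter(shap_values[:, 'city_development_index'])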