vignettes/vignette_titanic.Rmd
Let’s see an example of using the DALEX package with classification models for the survival problem on the Titanic dataset. Here we use the titanic_imputed dataset available in the DALEX package (a version of the titanic data with missing values imputed). Note that this data was copied from the stablelearner package.
head(titanic_imputed)
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
Ok, now it’s time to create a model. Let’s use a random forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch, data = titanic_imputed,
classification = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, classification = TRUE)
#>
#> Type: Classification
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 1
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error: 19.76 %
The third step (optional, but useful) is to create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed,
y = titanic_imputed$survived,
label = "Random Forest v7",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.12.1 , task classification ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.2093339 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -1 , mean = 0.1128228 , max = 1
#> A new explainer has been created!
Use the model_parts() function to compute and plot the importance of particular features. Note that type = "difference" normalizes dropout losses so that they all start at 0.
vi_rf <- model_parts(explain_titanic_rf)
head(vi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.2084722 Random Forest v7
#> 2 survived 0.2084245 Random Forest v7
#> 3 sibsp 0.2255415 Random Forest v7
#> 4 parch 0.2260147 Random Forest v7
#> 5 embarked 0.2302427 Random Forest v7
#> 6 fare 0.2537347 Random Forest v7
plot(vi_rf)
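To see the type = "difference" behaviour mentioned above, the same computation can be rerun with that argument. A short sketch, assuming the type argument of model_parts() in current DALEX releases:

```r
# recompute importance as differences from the full-model loss,
# so every variable's dropout loss starts at 0
vi_rf_diff <- model_parts(explain_titanic_rf, type = "difference")
plot(vi_rf_diff)
```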
As we see, the most important feature is gender. The next three important features are class, age and fare. Let’s see the link between the model response and these features. Such a univariate relation can be calculated with model_profile().
Kids 5 years old and younger have a much higher survival probability.
vr_age <- model_profile(explain_titanic_rf, variables = "age")
head(vr_age)
#> $cp_profiles
#> Top profiles :
#> gender age class embarked fare sibsp parch survived _yhat_
#> 877 male 0.1666667 2nd Southampton 36.15 1 1 0 1
#> 877.1 male 2.0000000 2nd Southampton 36.15 1 1 0 1
#> 877.2 male 4.0000000 2nd Southampton 36.15 1 1 0 1
#> 877.3 male 7.0000000 2nd Southampton 36.15 1 1 0 1
#> 877.4 male 9.0000000 2nd Southampton 36.15 1 1 0 1
#> 877.5 male 13.0000000 2nd Southampton 36.15 1 1 0 0
#> _vname_ _ids_ _label_
#> 877 age 877 Random Forest v7
#> 877.1 age 877 Random Forest v7
#> 877.2 age 877 Random Forest v7
#> 877.3 age 877 Random Forest v7
#> 877.4 age 877 Random Forest v7
#> 877.5 age 877 Random Forest v7
#>
#>
#> Top observations:
#> gender age class embarked fare sibsp parch survived
#> 877 male 19 2nd Southampton 36.1500 1 1 0
#> 1956 male 30 engineering crew Southampton 0.0000 0 0 1
#> 1125 male 42 3rd Southampton 7.1711 0 0 0
#> 727 male 21 3rd Queenstown 7.1707 0 0 0
#> 1381 female 33 victualling crew Southampton 0.0000 0 0 1
#> 76 female 3 3rd Cherbourg 19.0502 2 1 1
#> _yhat_ _label_ _ids_
#> 877 0 Random Forest v7 1
#> 1956 0 Random Forest v7 2
#> 1125 0 Random Forest v7 3
#> 727 0 Random Forest v7 4
#> 1381 1 Random Forest v7 5
#> 76 1 Random Forest v7 6
#>
#> $agr_profiles
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 age Random Forest v7 0.1666667 0.29 0
#> 2 age Random Forest v7 2.0000000 0.52 0
#> 3 age Random Forest v7 4.0000000 0.52 0
#> 4 age Random Forest v7 7.0000000 0.42 0
#> 5 age Random Forest v7 9.0000000 0.34 0
#> 6 age Random Forest v7 13.0000000 0.20 0
#>
#> $color
#> [1] "#4378bf"
plot(vr_age)
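The age effect likely differs between genders, and profiles can be computed within groups to show this. A sketch, assuming the groups argument of model_profile():

```r
# partial-dependence profiles for age, computed separately for each gender
vr_age_gender <- model_profile(explain_titanic_rf,
                               variables = "age",
                               groups = "gender")
plot(vr_age_gender)
```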
Passengers in the first class have a much higher survival probability.
vr_class <- model_profile(explain_titanic_rf, variables = "class")
plot(vr_class)
Very cheap tickets are linked with lower chances of survival.
vr_fare <- model_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)
Passengers that embarked in Cherbourg have the highest survival rate.
vr_embarked <- model_profile(explain_titanic_rf, variables = "embarked")
plot(vr_embarked)
Let’s see a break-down explanation of the model prediction for an 8-year-old boy travelling in the 1st class who embarked in Southampton.
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for an average passenger, mainly because of his young age and despite being male.
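Break-down attributions depend on the order in which variables are added; a SHAP-style average over many orderings is more stable when features interact. A sketch, assuming the type argument of predict_parts():

```r
# SHAP attributions: break-down contributions averaged over random orderings
sp_rf_shap <- predict_parts(explain_titanic_rf, new_passanger, type = "shap")
plot(sp_rf_shap)
```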
Let’s train a few more models for the survival problem and compare their explanations.
library("rms")
model_titanic_lmr <- lrm(survived ~ class + gender + rcs(age) + sibsp +
parch + fare + embarked, titanic_imputed)
explain_titanic_lmr <- explain(model_titanic_lmr, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, type = "fitted"),
label = "Logistic regression")
#> Preparation of a new explainer is initiated
#> -> model label : Logistic regression
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, type = "fitted")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package rms , ver. 6.1.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.002671631 , mean = 0.3221568 , max = 0.9845724
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9845724 , mean = -2.491758e-09 , max = 0.9715125
#> A new explainer has been created!
library("gbm")
model_titanic_gbm <- gbm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
n.trees = 15000)
#> Distribution not specified, assuming bernoulli ...
explain_titanic_gbm <- explain(model_titanic_gbm, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, n.trees = 15000, type = "response"),
label = "Generalized Boosted Models",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Generalized Boosted Models
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, n.trees = 15000, type = "response")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package gbm , ver. 2.1.8 , task classification ( default )
#> -> predicted values : numerical, min = 0.0006051669 , mean = 0.3229494 , max = 0.9988031
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9617458 , mean = -0.000792653 , max = 0.9916582
#> A new explainer has been created!
library("e1071")
model_titanic_svm <- svm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
type = "C-classification", probability = TRUE)
explain_titanic_svm <- explain(model_titanic_svm, data = titanic_imputed,
y = titanic_imputed$survived,
label = "Support Vector Machines",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Support Vector Machines
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.svm will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package e1071 , ver. 1.7.4 , task classification ( default )
#> -> predicted values : numerical, min = 0.08791044 , mean = 0.3238389 , max = 0.9625149
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8657667 , mean = -0.001682127 , max = 0.9120896
#> A new explainer has been created!
library("caret")
model_titanic_knn <- knn3(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed, k = 5)
explain_titanic_knn <- explain(model_titanic_knn, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x) predict(m, x)[,2],
label = "k-Nearest Neighbors",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : k-Nearest Neighbors
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x)[, 2]
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package Model of class: knn3 package unrecognized , ver. Unknown , task regression ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.3061413 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8333333 , mean = 0.01601551 , max = 0.9285714
#> A new explainer has been created!
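Before comparing explanations, it is worth comparing the raw performance of the five models. A sketch, assuming model_performance() and its ROC plot geometry from current DALEX releases:

```r
# residual/performance summaries for each explainer
mp_rf  <- model_performance(explain_titanic_rf)
mp_lmr <- model_performance(explain_titanic_lmr)
mp_gbm <- model_performance(explain_titanic_gbm)
mp_svm <- model_performance(explain_titanic_svm)
mp_knn <- model_performance(explain_titanic_knn)
# overlay ROC curves for all five models
plot(mp_rf, mp_lmr, mp_gbm, mp_svm, mp_knn, geom = "roc")
```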
vi_rf <- model_parts(explain_titanic_rf)
vi_lmr <- model_parts(explain_titanic_lmr)
vi_gbm <- model_parts(explain_titanic_gbm)
vi_svm <- model_parts(explain_titanic_svm)
vi_knn <- model_parts(explain_titanic_knn)
plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)
vr_age_rf <- model_profile(explain_titanic_rf, variables = "age")
vr_age_lmr <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn <- model_profile(explain_titanic_knn, variables = "age")
plot(vr_age_rf$agr_profiles,
vr_age_lmr$agr_profiles,
vr_age_gbm$agr_profiles,
vr_age_svm$agr_profiles,
vr_age_knn$agr_profiles)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
sp_lmr <- predict_parts(explain_titanic_lmr, new_passanger)
plot(sp_lmr)
sp_gbm <- predict_parts(explain_titanic_gbm, new_passanger)
plot(sp_gbm)
sp_svm <- predict_parts(explain_titanic_svm, new_passanger)
plot(sp_svm)
sp_knn <- predict_parts(explain_titanic_knn, new_passanger)
plot(sp_knn)
sessionInfo()
#> R version 4.0.4 (2021-02-15)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Catalina 10.15.7
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] caret_6.0-86 e1071_1.7-4 gbm_2.1.8 rms_6.1-1
#> [5] SparseM_1.81 Hmisc_4.4-2 ggplot2_3.3.3 Formula_1.2-4
#> [9] survival_3.2-7 lattice_0.20-41 ranger_0.12.1 DALEX_2.1.1
#>
#> loaded via a namespace (and not attached):
#> [1] nlme_3.1-152 matrixStats_0.58.0 fs_1.5.0
#> [4] lubridate_1.7.9.2 RColorBrewer_1.1-2 rprojroot_2.0.2
#> [7] tools_4.0.4 backports_1.2.1 R6_2.5.0
#> [10] rpart_4.1-15 colorspace_2.0-0 nnet_7.3-15
#> [13] withr_2.4.1 tidyselect_1.1.0 gridExtra_2.3
#> [16] compiler_4.0.4 textshaping_0.3.0 quantreg_5.83
#> [19] htmlTable_2.1.0 desc_1.2.0 sandwich_3.0-0
#> [22] labeling_0.4.2 scales_1.1.1 checkmate_2.0.0
#> [25] polspline_1.1.19 mvtnorm_1.1-1 pkgdown_1.6.1
#> [28] systemfonts_1.0.1 stringr_1.4.0 digest_0.6.27
#> [31] foreign_0.8-81 ingredients_2.0.1 rmarkdown_2.7
#> [34] iBreakDown_1.3.1 base64enc_0.1-3 jpeg_0.1-8.1
#> [37] pkgconfig_2.0.3 htmltools_0.5.1.1 fastmap_1.1.0
#> [40] highr_0.8 htmlwidgets_1.5.3 rlang_0.4.10
#> [43] rstudioapi_0.13 farver_2.0.3 generics_0.1.0
#> [46] zoo_1.8-8 ModelMetrics_1.2.2.2 dplyr_1.0.4
#> [49] magrittr_2.0.1 Matrix_1.3-2 Rcpp_1.0.6
#> [52] munsell_0.5.0 lifecycle_1.0.0 pROC_1.17.0.1
#> [55] stringi_1.5.3 multcomp_1.4-16 yaml_2.2.1
#> [58] MASS_7.3-53 plyr_1.8.6 recipes_0.1.15
#> [61] grid_4.0.4 crayon_1.4.1 splines_4.0.4
#> [64] knitr_1.31 pillar_1.4.7 stats4_4.0.4
#> [67] reshape2_1.4.4 codetools_0.2-18 glue_1.4.2
#> [70] evaluate_0.14 latticeExtra_0.6-29 data.table_1.13.6
#> [73] png_0.1-7 vctrs_0.3.6 foreach_1.5.1
#> [76] MatrixModels_0.4-1 gtable_0.3.0 purrr_0.3.4
#> [79] assertthat_0.2.1 cachem_1.0.4 xfun_0.21
#> [82] gower_0.2.2 prodlim_2019.11.13 ragg_1.1.0
#> [85] class_7.3-18 timeDate_3043.102 tibble_3.0.6
#> [88] conquer_1.0.2 iterators_1.0.13 memoise_2.0.0
#> [91] lava_1.6.8.1 cluster_2.1.0 TH.data_1.0-10
#> [94] ellipsis_0.3.1 ipred_0.9-9