vignettes/vignette_titanic.Rmd
Let’s see an example of using the DALEX package with classification models for the survival problem on the Titanic data. Here we use the titanic_imputed dataset available in the DALEX package, an imputed version of the titanic data, which was originally copied from the stablelearner package.
library("DALEX")
head(titanic_imputed)
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
OK, now it’s time to create a model. Let’s use a random forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch, data = titanic_imputed,
classification = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, classification = TRUE)
#>
#> Type: Classification
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 1
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error: 19.30 %
The third step (optional, but useful) is to create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed,
y = titanic_imputed$survived,
label = "Random Forest v7",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.13.1 , task classification ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.2152243 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -1 , mean = 0.1078387 , max = 1
#> A new explainer has been created!
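Once created, the explainer can be used like a regular model: DALEX provides a `predict()` method for explainer objects, so predictions are obtained in a uniform way regardless of the underlying package. A minimal sketch:

```r
# predictions go through the explainer's unified predict function,
# so the same call works whether the model comes from ranger, gbm, e1071, etc.
predict(explain_titanic_rf, titanic_imputed[1:6, ])
```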
Use the model_parts() function to present the importance of particular features. Note that type = "difference"
normalizes the dropout losses so that they all start at 0.
vi_rf <- model_parts(explain_titanic_rf)
head(vi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.2183236 Random Forest v7
#> 2 survived 0.2189447 Random Forest v7
#> 3 parch 0.2317355 Random Forest v7
#> 4 sibsp 0.2332329 Random Forest v7
#> 5 embarked 0.2368054 Random Forest v7
#> 6 fare 0.2604267 Random Forest v7
plot(vi_rf)
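The type = "difference" normalization mentioned above is passed directly to model_parts(). A sketch (the B argument, which controls the number of permutation rounds, is optional):

```r
# dropout losses relative to the full model; all bars start at 0
vi_rf_diff <- model_parts(explain_titanic_rf,
                          type = "difference",
                          B = 10)  # more permutation rounds for more stable estimates
plot(vi_rf_diff)
```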
As we can see, the most important feature is gender. The next three important features are class, age and fare. Let’s look at the link between the model response and these features.
Such univariate relations can be calculated with model_profile().
Kids 5 years old and younger have a much higher survival probability.
vr_age <- model_profile(explain_titanic_rf, variables = "age")
head(vr_age)
#> $cp_profiles
#> Top profiles :
#> gender age class embarked fare sibsp parch survived _yhat_
#> 128 male 0.1666667 1st Cherbourg 31 0 0 1 1
#> 128.1 male 2.0000000 1st Cherbourg 31 0 0 1 1
#> 128.2 male 4.0000000 1st Cherbourg 31 0 0 1 1
#> 128.3 male 7.0000000 1st Cherbourg 31 0 0 1 1
#> 128.4 male 9.0000000 1st Cherbourg 31 0 0 1 1
#> 128.5 male 13.0000000 1st Cherbourg 31 0 0 1 1
#> _vname_ _ids_ _label_
#> 128 age 128 Random Forest v7
#> 128.1 age 128 Random Forest v7
#> 128.2 age 128 Random Forest v7
#> 128.3 age 128 Random Forest v7
#> 128.4 age 128 Random Forest v7
#> 128.5 age 128 Random Forest v7
#>
#>
#> Top observations:
#> gender age class embarked fare sibsp parch survived
#> 128 male 39 1st Cherbourg 31.0000 0 0 1
#> 2070 female 22 victualling crew Southampton 0.0000 0 0 0
#> 894 male 60 3rd Southampton 6.0409 0 0 0
#> 1973 male 50 victualling crew Belfast 0.0000 0 0 0
#> 683 male 31 2nd Southampton 10.1000 0 0 0
#> 624 male 46 1st Southampton 26.0000 0 0 0
#> _yhat_ _label_ _ids_
#> 128 0 Random Forest v7 1
#> 2070 1 Random Forest v7 2
#> 894 0 Random Forest v7 3
#> 1973 0 Random Forest v7 4
#> 683 0 Random Forest v7 5
#> 624 0 Random Forest v7 6
#>
#> $agr_profiles
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 age Random Forest v7 0.1666667 0.44 0
#> 2 age Random Forest v7 2.0000000 0.55 0
#> 3 age Random Forest v7 4.0000000 0.55 0
#> 4 age Random Forest v7 7.0000000 0.44 0
#> 5 age Random Forest v7 9.0000000 0.40 0
#> 6 age Random Forest v7 13.0000000 0.30 0
#>
#> $color
#> [1] "#4378bf"
plot(vr_age)
Passengers in the first class have a much higher survival probability.
vr_class <- model_profile(explain_titanic_rf, variables = "class")
plot(vr_class)
Very cheap tickets are linked with lower survival chances.
vr_fare <- model_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)
Passengers who embarked in Cherbourg have the highest survival rate.
vr_embarked <- model_profile(explain_titanic_rf, variables = "embarked")
plot(vr_embarked)
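Profiles can also be compared across subgroups: model_profile() accepts a groups argument, so, for example, the effect of age can be contrasted between genders. A sketch:

```r
# separate age profiles computed within each gender group
vr_age_gender <- model_profile(explain_titanic_rf,
                               variables = "age",
                               groups = "gender")
plot(vr_age_gender)
```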
Let’s see a break-down explanation of the model prediction for an 8-year-old boy travelling in the 1st class who embarked in Southampton.
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite him being male.
Let’s train a few more models for the survival problem and compare their explanations.
library("rms")
model_titanic_lmr <- lrm(survived ~ class + gender + rcs(age) + sibsp +
parch + fare + embarked, titanic_imputed)
explain_titanic_lmr <- explain(model_titanic_lmr, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, type = "fitted"),
label = "Logistic regression")
#> Preparation of a new explainer is initiated
#> -> model label : Logistic regression
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, type = "fitted")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package rms , ver. 6.3.0 , task classification ( default )
#> -> predicted values : numerical, min = 0.002671631 , mean = 0.3221568 , max = 0.9845724
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9845724 , mean = -2.491758e-09 , max = 0.9715125
#> A new explainer has been created!
library("gbm")
model_titanic_gbm <- gbm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
n.trees = 15000)
#> Distribution not specified, assuming bernoulli ...
explain_titanic_gbm <- explain(model_titanic_gbm, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, n.trees = 15000, type = "response"),
label = "Generalized Boosted Models",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Generalized Boosted Models
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, n.trees = 15000, type = "response")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package gbm , ver. 2.1.8 , task classification ( default )
#> -> predicted values : numerical, min = 0.0003164416 , mean = 0.3201063 , max = 0.9989325
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9658806 , mean = 0.002050489 , max = 0.9903356
#> A new explainer has been created!
library("e1071")
model_titanic_svm <- svm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
type = "C-classification", probability = TRUE)
explain_titanic_svm <- explain(model_titanic_svm, data = titanic_imputed,
y = titanic_imputed$survived,
label = "Support Vector Machines",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Support Vector Machines
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.svm will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package e1071 , ver. 1.7.9 , task classification ( default )
#> -> predicted values : numerical, min = 0.08688823 , mean = 0.3245923 , max = 0.9643056
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8697901 , mean = -0.002435489 , max = 0.9131118
#> A new explainer has been created!
library("caret")
model_titanic_knn <- knn3(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed, k = 5)
explain_titanic_knn <- explain(model_titanic_knn, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x) predict(m, x)[,2],
label = "k-Nearest Neighbors",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : k-Nearest Neighbors
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x)[, 2]
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package Model of class: knn3 package unrecognized , ver. Unknown , task regression ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.3061413 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8333333 , mean = 0.01601551 , max = 0.9285714
#> A new explainer has been created!
vi_rf <- model_parts(explain_titanic_rf)
vi_lmr <- model_parts(explain_titanic_lmr)
vi_gbm <- model_parts(explain_titanic_gbm)
vi_svm <- model_parts(explain_titanic_svm)
vi_knn <- model_parts(explain_titanic_knn)
plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)
vr_age_rf <- model_profile(explain_titanic_rf, variables = "age")
vr_age_lmr <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn <- model_profile(explain_titanic_knn, variables = "age")
plot(vr_age_rf$agr_profiles,
vr_age_lmr$agr_profiles,
vr_age_gbm$agr_profiles,
vr_age_svm$agr_profiles,
vr_age_knn$agr_profiles)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
sp_lmr <- predict_parts(explain_titanic_lmr, new_passanger)
plot(sp_lmr)
sp_gbm <- predict_parts(explain_titanic_gbm, new_passanger)
plot(sp_gbm)
sp_svm <- predict_parts(explain_titanic_svm, new_passanger)
plot(sp_svm)
sp_knn <- predict_parts(explain_titanic_knn, new_passanger)
plot(sp_knn)
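Beyond explanations, the explainers also make it easy to compare overall model quality: model_performance() computes residual-based measures, and several results can be drawn on a single plot. A sketch for three of the models:

```r
# residual distributions for several models, shown side by side as boxplots
mp_rf  <- model_performance(explain_titanic_rf)
mp_lmr <- model_performance(explain_titanic_lmr)
mp_gbm <- model_performance(explain_titanic_gbm)
plot(mp_rf, mp_lmr, mp_gbm, geom = "boxplot")
```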
#> R version 4.2.0 (2022-04-22)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Big Sur/Monterey 10.16
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] caret_6.0-92 e1071_1.7-9 gbm_2.1.8 rms_6.3-0
#> [5] SparseM_1.81 Hmisc_4.7-0 ggplot2_3.3.6 Formula_1.2-4
#> [9] survival_3.3-1 lattice_0.20-45 ranger_0.13.1 DALEX_2.4.1
#>
#> loaded via a namespace (and not attached):
#> [1] TH.data_1.1-1 colorspace_2.0-3 ellipsis_0.3.2
#> [4] class_7.3-20 rprojroot_2.0.3 htmlTable_2.4.0
#> [7] base64enc_0.1-3 fs_1.5.2 rstudioapi_0.13
#> [10] proxy_0.4-26 listenv_0.8.0 farver_2.1.0
#> [13] MatrixModels_0.5-0 prodlim_2019.11.13 fansi_1.0.3
#> [16] mvtnorm_1.1-3 lubridate_1.8.0 iBreakDown_2.0.1
#> [19] codetools_0.2-18 splines_4.2.0 cachem_1.0.6
#> [22] knitr_1.39 jsonlite_1.8.0 pROC_1.18.0
#> [25] cluster_2.1.3 png_0.1-7 compiler_4.2.0
#> [28] backports_1.4.1 Matrix_1.4-1 fastmap_1.1.0
#> [31] cli_3.3.0 htmltools_0.5.2 quantreg_5.93
#> [34] tools_4.2.0 gtable_0.3.0 glue_1.6.2
#> [37] reshape2_1.4.4 dplyr_1.0.9 Rcpp_1.0.8.3
#> [40] jquerylib_0.1.4 pkgdown_2.0.3 vctrs_0.4.1
#> [43] nlme_3.1-157 iterators_1.0.14 timeDate_3043.102
#> [46] xfun_0.31 gower_1.0.0 stringr_1.4.0
#> [49] globals_0.15.0 lifecycle_1.0.1 future_1.25.0
#> [52] polspline_1.1.20 MASS_7.3-56 zoo_1.8-10
#> [55] scales_1.2.0 ipred_0.9-12 ragg_1.2.2
#> [58] parallel_4.2.0 sandwich_3.0-1 RColorBrewer_1.1-3
#> [61] yaml_2.3.5 memoise_2.0.1 gridExtra_2.3
#> [64] sass_0.4.1 rpart_4.1.16 latticeExtra_0.6-29
#> [67] stringi_1.7.6 highr_0.9 desc_1.4.1
#> [70] foreach_1.5.2 checkmate_2.1.0 hardhat_0.2.0
#> [73] lava_1.6.10 rlang_1.0.2 pkgconfig_2.0.3
#> [76] systemfonts_1.0.4 evaluate_0.15 purrr_0.3.4
#> [79] recipes_0.2.0 htmlwidgets_1.5.4 labeling_0.4.2
#> [82] tidyselect_1.1.2 parallelly_1.31.1 plyr_1.8.7
#> [85] magrittr_2.0.3 R6_2.5.1 generics_0.1.2
#> [88] multcomp_1.4-19 pillar_1.7.0 foreign_0.8-82
#> [91] withr_2.5.0 nnet_7.3-17 tibble_3.1.7
#> [94] future.apply_1.9.0 crayon_1.5.1 utf8_1.2.2
#> [97] rmarkdown_2.14 jpeg_0.1-9 ingredients_2.2.0
#> [100] grid_4.2.0 data.table_1.14.2 ModelMetrics_1.2.2.2
#> [103] digest_0.6.29 textshaping_0.3.6 stats4_4.2.0
#> [106] munsell_0.5.0 bslib_0.3.1