vignettes/vignette_titanic.Rmd
Let’s see an example of using the DALEX package to explain classification models built for the survival problem on the Titanic data. Here
we are using the titanic_imputed dataset available in the
DALEX package. Note that this data was copied from the
stablelearner package.
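The first rows of the data are shown below; they can be printed, for example, with head():
head(titanic_imputed)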
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
Ok, now it’s time to create a model. Let’s use the Random Forest model.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch, data = titanic_imputed,
classification = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, classification = TRUE)
#>
#> Type: Classification
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 1
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error: 20.03 %
The third step (optional but useful) is to create a
DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed,
y = titanic_imputed$survived,
label = "Random Forest v7",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.14.1 , task classification ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.213865 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -1 , mean = 0.1082918 , max = 1
#> A new explainer has been created!
Use the model_parts() function to assess the
importance of particular features. Note that
type = "difference" normalizes the dropout losses, so they all
start at 0.
vi_rf <- model_parts(explain_titanic_rf)
head(vi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.2164656 Random Forest v7
#> 2 survived 0.2164564 Random Forest v7
#> 3 parch 0.2309825 Random Forest v7
#> 4 sibsp 0.2360850 Random Forest v7
#> 5 embarked 0.2387941 Random Forest v7
#> 6 fare 0.2621238 Random Forest v7
plot(vi_rf)
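To get the normalized version mentioned above, pass type = "difference" to model_parts(); a minimal sketch, assuming a DALEX version that supports this value of the type argument:
# dropout losses relative to the full model, so all bars start at 0
vi_rf_diff <- model_parts(explain_titanic_rf, type = "difference")
plot(vi_rf_diff)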
As we can see, the most important feature is gender. The next
three most important features are class, age and
fare. Let’s see the link between the model response and these
features.
Such univariate relations can be calculated with
model_profile().
Kids 5 years old and younger have a much higher survival probability.
vr_age <- model_profile(explain_titanic_rf, variables = "age")
head(vr_age)
#> $cp_profiles
#> Top profiles :
#> gender age class embarked fare sibsp parch survived
#> 2068 male 0.1666667 victualling crew Southampton 0 0 0 0
#> 2068.1 male 2.0000000 victualling crew Southampton 0 0 0 0
#> 2068.2 male 4.0000000 victualling crew Southampton 0 0 0 0
#> 2068.3 male 7.0000000 victualling crew Southampton 0 0 0 0
#> 2068.4 male 9.0000000 victualling crew Southampton 0 0 0 0
#> 2068.5 male 13.0000000 victualling crew Southampton 0 0 0 0
#> _yhat_ _vname_ _ids_ _label_
#> 2068 0 age 2068 Random Forest v7
#> 2068.1 0 age 2068 Random Forest v7
#> 2068.2 0 age 2068 Random Forest v7
#> 2068.3 0 age 2068 Random Forest v7
#> 2068.4 0 age 2068 Random Forest v7
#> 2068.5 0 age 2068 Random Forest v7
#>
#>
#> Top observations:
#> gender age class embarked fare sibsp parch survived
#> 2068 male 33 victualling crew Southampton 0.0000 0 0 0
#> 689 male 38 3rd Southampton 56.0911 0 0 1
#> 886 male 41 3rd Southampton 7.0206 0 0 0
#> 1177 male 74 3rd Southampton 7.1506 0 0 0
#> 1871 male 42 victualling crew Southampton 0.0000 0 0 0
#> 957 male 21 3rd Southampton 7.1806 0 0 0
#> _yhat_ _label_ _ids_
#> 2068 0 Random Forest v7 1
#> 689 1 Random Forest v7 2
#> 886 0 Random Forest v7 3
#> 1177 0 Random Forest v7 4
#> 1871 0 Random Forest v7 5
#> 957 0 Random Forest v7 6
#>
#> $agr_profiles
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 age Random Forest v7 0.1666667 0.50 0
#> 2 age Random Forest v7 2.0000000 0.56 0
#> 3 age Random Forest v7 4.0000000 0.56 0
#> 4 age Random Forest v7 7.0000000 0.49 0
#> 5 age Random Forest v7 9.0000000 0.40 0
#> 6 age Random Forest v7 13.0000000 0.19 0
#>
#> $color
#> [1] "#4378bf"
plot(vr_age)
Passengers in the first class have a much higher survival probability.
vr_class <- model_profile(explain_titanic_rf, variables = "class")
plot(vr_class)
Very cheap tickets are linked with lower chances of survival.
vr_fare <- variable_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)
Passengers that embarked in Cherbourg have the highest survival rate.
vr_embarked <- model_profile(explain_titanic_rf, variables = "embarked")
plot(vr_embarked)
Let’s see a break-down explanation of the model prediction for an 8-year-old boy travelling in the 1st class who embarked in Southampton.
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
It looks like the most important features for this passenger are
age and gender. Overall, his odds of survival
are higher than for the average passenger, mainly because of his young
age and despite being male.
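We can check this by comparing the model prediction for this passenger with the average prediction over the whole dataset; a minimal sketch using the predict() method that DALEX provides for explainers:
# prediction for the new passenger vs. the mean prediction on the training data
predict(explain_titanic_rf, new_passanger)
mean(predict(explain_titanic_rf, titanic_imputed))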
Let’s train a few more models for the survival problem.
library("rms")
model_titanic_lmr <- lrm(survived ~ class + gender + rcs(age) + sibsp +
parch + fare + embarked, titanic_imputed)
explain_titanic_lmr <- explain(model_titanic_lmr, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, type = "fitted"),
label = "Logistic regression")#> Preparation of a new explainer is initiated
#> -> model label : Logistic regression
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, type = "fitted")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package rms , ver. 6.5.0 , task classification ( default )
#> -> predicted values : numerical, min = 0.002671631 , mean = 0.3221568 , max = 0.9845724
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9845724 , mean = -2.491758e-09 , max = 0.9715125
#> A new explainer has been created!
library("gbm")
model_titanic_gbm <- gbm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
n.trees = 15000)
#> Distribution not specified, assuming bernoulli ...
explain_titanic_gbm <- explain(model_titanic_gbm, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, n.trees = 15000, type = "response"),
label = "Generalized Boosted Models",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Generalized Boosted Models
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, n.trees = 15000, type = "response")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package gbm , ver. 2.1.8.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.0004911074 , mean = 0.3229161 , max = 0.9985837
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9657513 , mean = -0.0007592975 , max = 0.9865362
#> A new explainer has been created!
library("e1071")
model_titanic_svm <- svm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
type = "C-classification", probability = TRUE)
explain_titanic_svm <- explain(model_titanic_svm, data = titanic_imputed,
y = titanic_imputed$survived,
label = "Support Vector Machines",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Support Vector Machines
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.svm will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package e1071 , ver. 1.7.13 , task classification ( default )
#> -> predicted values : numerical, min = 0.0890025 , mean = 0.3261746 , max = 0.9632474
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8679337 , mean = -0.004017876 , max = 0.9109975
#> A new explainer has been created!
library("caret")
model_titanic_knn <- knn3(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed, k = 5)
explain_titanic_knn <- explain(model_titanic_knn, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x) predict(m, x)[,2],
label = "k-Nearest Neighbors",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : k-Nearest Neighbors
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x)[, 2]
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package Model of class: knn3 package unrecognized , ver. Unknown , task regression ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.3061413 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8333333 , mean = 0.01601551 , max = 0.9285714
#> A new explainer has been created!
vi_rf <- model_parts(explain_titanic_rf)
vi_lmr <- model_parts(explain_titanic_lmr)
vi_gbm <- model_parts(explain_titanic_gbm)
vi_svm <- model_parts(explain_titanic_svm)
vi_knn <- model_parts(explain_titanic_knn)
plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)
vr_age_rf <- model_profile(explain_titanic_rf, variables = "age")
vr_age_lmr <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn <- model_profile(explain_titanic_knn, variables = "age")
plot(vr_age_rf$agr_profiles,
vr_age_lmr$agr_profiles,
vr_age_gbm$agr_profiles,
vr_age_svm$agr_profiles,
vr_age_knn$agr_profiles)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
sp_lmr <- predict_parts(explain_titanic_lmr, new_passanger)
plot(sp_lmr)
sp_gbm <- predict_parts(explain_titanic_gbm, new_passanger)
plot(sp_gbm)
sp_svm <- predict_parts(explain_titanic_svm, new_passanger)
plot(sp_svm)
sp_knn <- predict_parts(explain_titanic_knn, new_passanger)
plot(sp_knn)
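Finally, the session information used to render this vignette (output of sessionInfo()):
sessionInfo()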
#> R version 4.2.3 (2023-03-15)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Big Sur ... 10.16
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] caret_6.0-94 e1071_1.7-13 gbm_2.1.8.1 rms_6.5-0
#> [5] SparseM_1.81 ggplot2_3.4.1 lattice_0.20-45 survival_3.5-3
#> [9] Hmisc_5.0-1 ranger_0.14.1 DALEX_2.5.1
#>
#> loaded via a namespace (and not attached):
#> [1] nlme_3.1-162 fs_1.6.1 lubridate_1.9.2
#> [4] rprojroot_2.0.3 tools_4.2.3 backports_1.4.1
#> [7] bslib_0.4.2 utf8_1.2.3 R6_2.5.1
#> [10] rpart_4.1.19 colorspace_2.1-0 nnet_7.3-18
#> [13] withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3
#> [16] compiler_4.2.3 textshaping_0.3.6 cli_3.6.0
#> [19] quantreg_5.94 htmlTable_2.4.1 desc_1.4.2
#> [22] sandwich_3.0-2 labeling_0.4.2 sass_0.4.5
#> [25] scales_1.2.1 checkmate_2.1.0 polspline_1.1.22
#> [28] mvtnorm_1.1-3 proxy_0.4-27 pkgdown_2.0.7
#> [31] systemfonts_1.0.4 stringr_1.5.0 digest_0.6.31
#> [34] foreign_0.8-84 ingredients_2.3.0 rmarkdown_2.20
#> [37] iBreakDown_2.0.1 base64enc_0.1-3 pkgconfig_2.0.3
#> [40] htmltools_0.5.4 parallelly_1.34.0 fastmap_1.1.1
#> [43] highr_0.10 htmlwidgets_1.6.2 rlang_1.1.0
#> [46] rstudioapi_0.14 jquerylib_0.1.4 farver_2.1.1
#> [49] generics_0.1.3 zoo_1.8-11 jsonlite_1.8.4
#> [52] ModelMetrics_1.2.2.2 dplyr_1.1.1 magrittr_2.0.3
#> [55] Formula_1.2-5 Matrix_1.5-3 Rcpp_1.0.10
#> [58] munsell_0.5.0 fansi_1.0.4 lifecycle_1.0.3
#> [61] pROC_1.18.0 stringi_1.7.12 multcomp_1.4-23
#> [64] yaml_2.3.7 MASS_7.3-58.2 plyr_1.8.8
#> [67] recipes_1.0.5 grid_4.2.3 parallel_4.2.3
#> [70] listenv_0.9.0 splines_4.2.3 knitr_1.42
#> [73] pillar_1.9.0 stats4_4.2.3 reshape2_1.4.4
#> [76] future.apply_1.10.0 codetools_0.2-19 glue_1.6.2
#> [79] evaluate_0.20 data.table_1.14.8 vctrs_0.6.1
#> [82] foreach_1.5.2 MatrixModels_0.5-1 gtable_0.3.3
#> [85] purrr_1.0.1 future_1.32.0 cachem_1.0.7
#> [88] gower_1.0.1 xfun_0.37 prodlim_2019.11.13
#> [91] ragg_1.2.5 class_7.3-21 timeDate_4022.108
#> [94] tibble_3.2.1 iterators_1.0.14 hardhat_1.2.0
#> [97] memoise_2.0.1 lava_1.7.2.1 cluster_2.1.4
#> [100] timechange_0.2.0 globals_0.16.2 TH.data_1.1-1
#> [103] ipred_0.9-14