vignettes/vignette_titanic.Rmd
Let’s see an example of using the DALEX package to explain classification models for the survival problem on the Titanic dataset. Here we use the titanic_imputed dataset available in the DALEX package, a version of the titanic data with imputed missing values. Note that this data was copied from the stablelearner package.
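The preview printed below can be reproduced with a short call (a minimal sketch; it assumes the DALEX package is installed):

```r
# load DALEX and preview the first rows of the imputed Titanic data
library("DALEX")
head(titanic_imputed)
```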
#> gender age class embarked fare sibsp parch survived
#> 1 male 42 3rd Southampton 7.11 0 0 0
#> 2 male 13 3rd Southampton 20.05 0 2 0
#> 3 male 16 3rd Southampton 20.05 1 1 0
#> 4 female 39 3rd Southampton 20.05 1 1 1
#> 5 female 16 3rd Southampton 7.13 0 0 1
#> 6 male 25 3rd Southampton 7.13 0 0 1
OK, now it’s time to create a model. Let’s use a random forest.
# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
fare + sibsp + parch, data = titanic_imputed,
classification = TRUE)
model_titanic_rf
#> Ranger result
#>
#> Call:
#> ranger(survived ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic_imputed, classification = TRUE)
#>
#> Type: Classification
#> Number of trees: 500
#> Sample size: 2207
#> Number of independent variables: 7
#> Mtry: 2
#> Target node size: 1
#> Variable importance mode: none
#> Splitrule: gini
#> OOB prediction error: 19.62 %
The third step (optional, but useful) is to create a DALEX explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
data = titanic_imputed,
y = titanic_imputed$survived,
label = "Random Forest v7",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.ranger will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package ranger , ver. 0.14.1 , task classification ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.2134119 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -1 , mean = 0.1087449 , max = 1
#> A new explainer has been created!
Use the model_parts() function to present the importance of particular features. Note that with type = "difference" the dropout losses are normalized, so they all start at 0.
vi_rf <- model_parts(explain_titanic_rf)
head(vi_rf)
#> variable mean_dropout_loss label
#> 1 _full_model_ 0.2124315 Random Forest v7
#> 2 survived 0.2121247 Random Forest v7
#> 3 parch 0.2276852 Random Forest v7
#> 4 sibsp 0.2308135 Random Forest v7
#> 5 embarked 0.2361444 Random Forest v7
#> 6 fare 0.2579629 Random Forest v7
plot(vi_rf)
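The text above mentions type = "difference"; the call below is a sketch of how to request it with the same explainer, so that every bar starts at 0:

```r
# dropout losses relative to the full model; all bars start at 0
vi_rf_diff <- model_parts(explain_titanic_rf, type = "difference")
plot(vi_rf_diff)
```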
As we can see, the most important feature is gender. The next three important features are class, age and fare. Let’s examine the link between the model response and these features. Such univariate relations can be calculated with model_profile().
Kids 5 years old and younger have a much higher probability of survival.
vr_age <- model_profile(explain_titanic_rf, variables = "age")
head(vr_age)
#> $cp_profiles
#> Top profiles :
#> gender age class embarked fare sibsp parch survived _yhat_
#> 1190 male 0.1666667 2nd Southampton 22 0 0 0 1
#> 1190.1 male 2.0000000 2nd Southampton 22 0 0 0 1
#> 1190.2 male 4.0000000 2nd Southampton 22 0 0 0 1
#> 1190.3 male 7.0000000 2nd Southampton 22 0 0 0 0
#> 1190.4 male 9.0000000 2nd Southampton 22 0 0 0 0
#> 1190.5 male 13.0000000 2nd Southampton 22 0 0 0 0
#> _vname_ _ids_ _label_
#> 1190 age 1190 Random Forest v7
#> 1190.1 age 1190 Random Forest v7
#> 1190.2 age 1190 Random Forest v7
#> 1190.3 age 1190 Random Forest v7
#> 1190.4 age 1190 Random Forest v7
#> 1190.5 age 1190 Random Forest v7
#>
#>
#> Top observations:
#> gender age class embarked fare sibsp parch survived _yhat_
#> 1190 male 40 2nd Southampton 22.00 0 0 0 0
#> 1695 male 23 engineering crew Southampton 0.00 0 0 0 0
#> 1856 male 28 engineering crew Southampton 0.00 0 0 0 0
#> 11 male 30 3rd Southampton 7.05 0 0 0 0
#> 571 male 63 2nd Southampton 26.00 1 0 0 0
#> 1258 female 12 2nd Southampton 15.15 0 0 1 1
#> _label_ _ids_
#> 1190 Random Forest v7 1
#> 1695 Random Forest v7 2
#> 1856 Random Forest v7 3
#> 11 Random Forest v7 4
#> 571 Random Forest v7 5
#> 1258 Random Forest v7 6
#>
#> $agr_profiles
#> Top profiles :
#> _vname_ _label_ _x_ _yhat_ _ids_
#> 1 age Random Forest v7 0.1666667 0.55 0
#> 2 age Random Forest v7 2.0000000 0.58 0
#> 3 age Random Forest v7 4.0000000 0.59 0
#> 4 age Random Forest v7 7.0000000 0.48 0
#> 5 age Random Forest v7 9.0000000 0.46 0
#> 6 age Random Forest v7 13.0000000 0.36 0
#>
#> $color
#> [1] "#4378bf"
plot(vr_age)
Passengers in first class have a much higher survival probability.
vr_class <- model_profile(explain_titanic_rf, variables = "class")
plot(vr_class)
Very cheap tickets are linked with lower chances of survival.
vr_fare <- model_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)
Passengers that embarked in Cherbourg have the highest survival rate.
vr_embarked <- model_profile(explain_titanic_rf, variables = "embarked")
plot(vr_embarked)
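Profiles can also be split by another variable; the sketch below uses the groups argument of model_profile() to compare the age effect separately for each gender (an illustrative call, not part of the original analysis):

```r
# partial-dependence profiles for age, one curve per gender
vr_age_gender <- model_profile(explain_titanic_rf,
                               variables = "age", groups = "gender")
plot(vr_age_gender)
```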
Let’s see a break-down explanation of the model prediction for an 8-year-old boy from the 1st class who embarked in Southampton.
new_passanger <- data.frame(
class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
gender = factor("male", levels = c("female", "male")),
age = 8,
sibsp = 0,
parch = 0,
fare = 72,
embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite being male.
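Break-down is the default attribution; predict_parts() can also compute SHAP values for the same observation via the type argument (a sketch using the same explainer and passenger as above):

```r
# Shapley-based attributions for the same passenger
sp_rf_shap <- predict_parts(explain_titanic_rf, new_passanger, type = "shap")
plot(sp_rf_shap)
```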
Let’s train a few more models for the survival problem.
library("rms")
model_titanic_lmr <- lrm(survived ~ class + gender + rcs(age) + sibsp +
parch + fare + embarked, titanic_imputed)
explain_titanic_lmr <- explain(model_titanic_lmr, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, type = "fitted"),
label = "Logistic regression")
#> Preparation of a new explainer is initiated
#> -> model label : Logistic regression
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, type = "fitted")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package rms , ver. 6.4.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.002671631 , mean = 0.3221568 , max = 0.9845724
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9845724 , mean = -2.491758e-09 , max = 0.9715125
#> A new explainer has been created!
library("gbm")
model_titanic_gbm <- gbm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
n.trees = 15000)
#> Distribution not specified, assuming bernoulli ...
explain_titanic_gbm <- explain(model_titanic_gbm, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x)
predict(m, x, n.trees = 15000, type = "response"),
label = "Generalized Boosted Models",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Generalized Boosted Models
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x, n.trees = 15000, type = "response")
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package gbm , ver. 2.1.8.1 , task classification ( default )
#> -> predicted values : numerical, min = 0.0003540365 , mean = 0.3225163 , max = 0.9986777
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.9601845 , mean = -0.0003595284 , max = 0.9934245
#> A new explainer has been created!
library("e1071")
model_titanic_svm <- svm(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed,
type = "C-classification", probability = TRUE)
explain_titanic_svm <- explain(model_titanic_svm, data = titanic_imputed,
y = titanic_imputed$survived,
label = "Support Vector Machines",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : Support Vector Machines
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : yhat.svm will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package e1071 , ver. 1.7.12 , task classification ( default )
#> -> predicted values : numerical, min = 0.08787756 , mean = 0.3244583 , max = 0.963078
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.867124 , mean = -0.002301487 , max = 0.9121224
#> A new explainer has been created!
library("caret")
model_titanic_knn <- knn3(survived ~ class + gender + age + sibsp +
parch + fare + embarked, data = titanic_imputed, k = 5)
explain_titanic_knn <- explain(model_titanic_knn, data = titanic_imputed,
y = titanic_imputed$survived,
predict_function = function(m,x) predict(m, x)[,2],
label = "k-Nearest Neighbors",
colorize = FALSE)
#> Preparation of a new explainer is initiated
#> -> model label : k-Nearest Neighbors
#> -> data : 2207 rows 8 cols
#> -> target variable : 2207 values
#> -> predict function : function(m, x) predict(m, x)[, 2]
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package Model of class: knn3 package unrecognized , ver. Unknown , task regression ( default )
#> -> predicted values : numerical, min = 0 , mean = 0.3061413 , max = 1
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8333333 , mean = 0.01601551 , max = 0.9285714
#> A new explainer has been created!
vi_rf <- model_parts(explain_titanic_rf)
vi_lmr <- model_parts(explain_titanic_lmr)
vi_gbm <- model_parts(explain_titanic_gbm)
vi_svm <- model_parts(explain_titanic_svm)
vi_knn <- model_parts(explain_titanic_knn)
plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)
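Beyond variable importance, the explainers can also be compared on overall performance with model_performance(); the sketch below contrasts residual distributions for two of the models:

```r
# residual-based performance comparison of two explainers
mp_rf  <- model_performance(explain_titanic_rf)
mp_lmr <- model_performance(explain_titanic_lmr)
plot(mp_rf, mp_lmr)
```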
vr_age_rf <- model_profile(explain_titanic_rf, variables = "age")
vr_age_lmr <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn <- model_profile(explain_titanic_knn, variables = "age")
plot(vr_age_rf$agr_profiles,
vr_age_lmr$agr_profiles,
vr_age_gbm$agr_profiles,
vr_age_svm$agr_profiles,
vr_age_knn$agr_profiles)
sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
sp_lmr <- predict_parts(explain_titanic_lmr, new_passanger)
plot(sp_lmr)
sp_gbm <- predict_parts(explain_titanic_gbm, new_passanger)
plot(sp_gbm)
sp_svm <- predict_parts(explain_titanic_svm, new_passanger)
plot(sp_svm)
sp_knn <- predict_parts(explain_titanic_knn, new_passanger)
plot(sp_knn)
#> R version 4.2.2 (2022-10-31)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Big Sur ... 10.16
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] caret_6.0-93 e1071_1.7-12 gbm_2.1.8.1 rms_6.4-1
#> [5] SparseM_1.81 Hmisc_4.7-2 ggplot2_3.4.0 Formula_1.2-4
#> [9] survival_3.4-0 lattice_0.20-45 ranger_0.14.1 DALEX_2.5.0
#>
#> loaded via a namespace (and not attached):
#> [1] TH.data_1.1-1 colorspace_2.1-0 deldir_1.0-6
#> [4] class_7.3-20 rprojroot_2.0.3 htmlTable_2.4.1
#> [7] base64enc_0.1-3 fs_1.6.0 rstudioapi_0.14
#> [10] proxy_0.4-27 listenv_0.9.0 farver_2.1.1
#> [13] MatrixModels_0.5-1 prodlim_2019.11.13 fansi_1.0.4
#> [16] mvtnorm_1.1-3 lubridate_1.9.1 iBreakDown_2.0.1
#> [19] codetools_0.2-18 splines_4.2.2 cachem_1.0.6
#> [22] knitr_1.42 jsonlite_1.8.4 pROC_1.18.0
#> [25] cluster_2.1.4 png_0.1-8 compiler_4.2.2
#> [28] backports_1.4.1 Matrix_1.5-1 fastmap_1.1.0
#> [31] cli_3.6.0 htmltools_0.5.4 quantreg_5.94
#> [34] tools_4.2.2 gtable_0.3.1 glue_1.6.2
#> [37] reshape2_1.4.4 dplyr_1.0.10 Rcpp_1.0.10
#> [40] jquerylib_0.1.4 pkgdown_2.0.7 vctrs_0.5.2
#> [43] nlme_3.1-160 iterators_1.0.14 timeDate_4022.108
#> [46] xfun_0.36 gower_1.0.1 stringr_1.5.0
#> [49] globals_0.16.2 timechange_0.2.0 lifecycle_1.0.3
#> [52] future_1.30.0 polspline_1.1.22 MASS_7.3-58.1
#> [55] zoo_1.8-11 scales_1.2.1 ipred_0.9-13
#> [58] ragg_1.2.5 parallel_4.2.2 sandwich_3.0-2
#> [61] RColorBrewer_1.1-3 yaml_2.3.7 memoise_2.0.1
#> [64] gridExtra_2.3 sass_0.4.5 rpart_4.1.19
#> [67] latticeExtra_0.6-30 stringi_1.7.12 highr_0.10
#> [70] desc_1.4.2 foreach_1.5.2 checkmate_2.1.0
#> [73] hardhat_1.2.0 lava_1.7.1 rlang_1.0.6
#> [76] pkgconfig_2.0.3 systemfonts_1.0.4 evaluate_0.20
#> [79] purrr_1.0.1 recipes_1.0.4 htmlwidgets_1.6.1
#> [82] labeling_0.4.2 tidyselect_1.2.0 parallelly_1.34.0
#> [85] plyr_1.8.8 magrittr_2.0.3 R6_2.5.1
#> [88] generics_0.1.3 multcomp_1.4-20 pillar_1.8.1
#> [91] foreign_0.8-83 withr_2.5.0 nnet_7.3-18
#> [94] tibble_3.1.8 future.apply_1.10.0 interp_1.1-3
#> [97] utf8_1.2.2 rmarkdown_2.20 jpeg_0.1-10
#> [100] ingredients_2.3.0 grid_4.2.2 data.table_1.14.6
#> [103] ModelMetrics_1.2.2.2 digest_0.6.31 textshaping_0.3.6
#> [106] stats4_4.2.2 munsell_0.5.0 bslib_0.4.2