Data for Titanic survival

Let’s see an example of using the DALEX package with classification models, applied to the problem of predicting survival on the Titanic. Here we use the titanic_imputed dataset available in the DALEX package. Note that this data was copied from the stablelearner package.

library("DALEX")
head(titanic_imputed)
#>   gender age class    embarked  fare sibsp parch survived
#> 1   male  42   3rd Southampton  7.11     0     0        0
#> 2   male  13   3rd Southampton 20.05     0     2        0
#> 3   male  16   3rd Southampton 20.05     1     1        0
#> 4 female  39   3rd Southampton 20.05     1     1        1
#> 5 female  16   3rd Southampton  7.13     0     0        1
#> 6   male  25   3rd Southampton  7.13     0     0        1

Model for Titanic survival

OK, now it’s time to create a model. Let’s use a random forest from the ranger package.

# prepare model
library("ranger")
model_titanic_rf <- ranger(survived ~ gender + age + class + embarked +
                                   fare + sibsp + parch,  data = titanic_imputed,
                           classification = TRUE)
model_titanic_rf
#> Ranger result
#> 
#> Call:
#>  ranger(survived ~ gender + age + class + embarked + fare + sibsp +      parch, data = titanic_imputed, classification = TRUE) 
#> 
#> Type:                             Classification 
#> Number of trees:                  500 
#> Sample size:                      2207 
#> Number of independent variables:  7 
#> Mtry:                             2 
#> Target node size:                 1 
#> Variable importance mode:         none 
#> Splitrule:                        gini 
#> OOB prediction error:             19.57 %

Explainer for Titanic survival

The third step (optional but useful) is to create a DALEX explainer for the random forest model.

library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf, 
                      data = titanic_imputed,
                      y = titanic_imputed$survived, 
                      label = "Random Forest v7",
                      colorize = FALSE)
#> Preparation of a new explainer is initiated
#>   -> model label       :  Random Forest v7 
#>   -> data              :  2207  rows  8  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  yhat.ranger  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package ranger , ver. 0.13.1 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0 , mean =  0.2070684 , max =  1  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -1 , mean =  0.1150884 , max =  1  
#>   A new explainer has been created!

Variable importance plots

Use the model_parts() function on the explainer to compute the importance of particular features. Note that type = "difference" normalizes the dropout losses so that they all start at 0.

vi_rf <- model_parts(explain_titanic_rf)
head(vi_rf)
#>       variable mean_dropout_loss            label
#> 1 _full_model_         0.2206452 Random Forest v7
#> 2     survived         0.2204889 Random Forest v7
#> 3        parch         0.2372033 Random Forest v7
#> 4        sibsp         0.2384049 Random Forest v7
#> 5     embarked         0.2412000 Random Forest v7
#> 6         fare         0.2627507 Random Forest v7
plot(vi_rf)
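
The type = "difference" variant mentioned above can be requested explicitly. A minimal sketch, assuming the explain_titanic_rf explainer created earlier:

```r
# dropout losses reported relative to the full model,
# so every variable's bar starts at 0
vi_rf_diff <- model_parts(explain_titanic_rf, type = "difference")
plot(vi_rf_diff)
```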

Variable effects

As we can see, the most important feature is gender. The next three most important features are class, age, and fare. Let’s examine the link between the model response and these features.

Such univariate relations can be calculated with model_profile().

Age

Children aged 5 and younger have a much higher survival probability.

vr_age  <- model_profile(explain_titanic_rf, variables =  "age")
head(vr_age)
#> $cp_profiles
#> Top profiles    : 
#>        gender        age            class embarked fare sibsp parch survived
#> 2054     male  0.1666667 engineering crew  Belfast    0     0     0        0
#> 2054.1   male  2.0000000 engineering crew  Belfast    0     0     0        0
#> 2054.2   male  4.0000000 engineering crew  Belfast    0     0     0        0
#> 2054.3   male  7.0000000 engineering crew  Belfast    0     0     0        0
#> 2054.4   male  9.0000000 engineering crew  Belfast    0     0     0        0
#> 2054.5   male 13.0000000 engineering crew  Belfast    0     0     0        0
#>        _yhat_ _vname_ _ids_          _label_
#> 2054        0     age  2054 Random Forest v7
#> 2054.1      0     age  2054 Random Forest v7
#> 2054.2      0     age  2054 Random Forest v7
#> 2054.3      0     age  2054 Random Forest v7
#> 2054.4      0     age  2054 Random Forest v7
#> 2054.5      0     age  2054 Random Forest v7
#> 
#> 
#> Top observations:
#>      gender age            class    embarked    fare sibsp parch survived
#> 2054   male  31 engineering crew     Belfast  0.0000     0     0        0
#> 2011   male  28 victualling crew Southampton  0.0000     0     0        0
#> 450    male  31              2nd Southampton 13.0000     0     0        0
#> 280    male  20              3rd Southampton  8.0302     0     0        0
#> 1702   male  53        deck crew Southampton  0.0000     0     0        1
#> 771    male  23              3rd Southampton  7.1711     0     0        0
#>      _yhat_          _label_ _ids_
#> 2054      0 Random Forest v7     1
#> 2011      0 Random Forest v7     2
#> 450       0 Random Forest v7     3
#> 280       0 Random Forest v7     4
#> 1702      0 Random Forest v7     5
#> 771       0 Random Forest v7     6
#> 
#> $agr_profiles
#> Top profiles    : 
#>   _vname_          _label_        _x_ _yhat_ _ids_
#> 1     age Random Forest v7  0.1666667   0.32     0
#> 2     age Random Forest v7  2.0000000   0.54     0
#> 3     age Random Forest v7  4.0000000   0.58     0
#> 4     age Random Forest v7  7.0000000   0.39     0
#> 5     age Random Forest v7  9.0000000   0.34     0
#> 6     age Random Forest v7 13.0000000   0.14     0
#> 
#> $color
#> [1] "#4378bf"
plot(vr_age)

Passenger class

Passengers in the first class have a much higher survival probability.

vr_class  <- model_profile(explain_titanic_rf, variables =  "class")
plot(vr_class)

Fare

Very cheap tickets are linked with lower chances of survival.

vr_fare  <- model_profile(explain_titanic_rf, variables = "fare")
plot(vr_fare)

Embarked

Passengers who embarked from Cherbourg have the highest survival rate.

vr_embarked  <- model_profile(explain_titanic_rf, variables =  "embarked")
plot(vr_embarked)

Instance level explanations

Let’s see a break-down explanation of the model prediction for an 8-year-old male from the 1st class who embarked from Southampton.

new_passanger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)

sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)
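
By default predict_parts() returns break-down contributions; other decompositions can be requested through the type argument. A sketch, assuming the same explainer and observation as above:

```r
# SHAP-style attributions, averaged over random orderings of the features
sp_rf_shap <- predict_parts(explain_titanic_rf, new_passanger, type = "shap")
plot(sp_rf_shap)
```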

It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite his being male.

More models

Let’s train more models for survival.

Logistic regression

library("rms")
model_titanic_lmr <- lrm(survived ~ class + gender + rcs(age) + sibsp +
                   parch + fare + embarked, titanic_imputed)
explain_titanic_lmr <- explain(model_titanic_lmr, data = titanic_imputed, 
                   y = titanic_imputed$survived, 
                   predict_function = function(m,x) 
                            predict(m, x, type = "fitted"),
                   label = "Logistic regression")
#> Preparation of a new explainer is initiated
#>   -> model label       :  Logistic regression 
#>   -> data              :  2207  rows  8  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  function(m, x) predict(m, x, type = "fitted") 
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package rms , ver. 6.2.0 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0.002671631 , mean =  0.3221568 , max =  0.9845724  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.9845724 , mean =  -2.491758e-09 , max =  0.9715125  
#>   A new explainer has been created!

Generalized Boosted Models (GBM)

library("gbm")
model_titanic_gbm <- gbm(survived  ~ class + gender + age + sibsp +
                     parch + fare + embarked, data = titanic_imputed, 
                     n.trees = 15000)
#> Distribution not specified, assuming bernoulli ...
explain_titanic_gbm <- explain(model_titanic_gbm, data = titanic_imputed, 
                       y = titanic_imputed$survived, 
                       predict_function = function(m,x) 
                            predict(m, x, n.trees = 15000, type = "response"),
                       label = "Generalized Boosted Models",
                       colorize = FALSE)
#> Preparation of a new explainer is initiated
#>   -> model label       :  Generalized Boosted Models 
#>   -> data              :  2207  rows  8  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  function(m, x) predict(m, x, n.trees = 15000, type = "response") 
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package gbm , ver. 2.1.8 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0.0002495798 , mean =  0.3188276 , max =  0.9987135  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.962244 , mean =  0.003329132 , max =  0.9933339  
#>   A new explainer has been created!

Support Vector Machines (SVM)

library("e1071")
model_titanic_svm <- svm(survived ~ class + gender + age + sibsp +
                     parch + fare + embarked, data = titanic_imputed, 
             type = "C-classification", probability = TRUE)
explain_titanic_svm <- explain(model_titanic_svm, data = titanic_imputed, 
                       y = titanic_imputed$survived, 
                       label = "Support Vector Machines",
                       colorize = FALSE)
#> Preparation of a new explainer is initiated
#>   -> model label       :  Support Vector Machines 
#>   -> data              :  2207  rows  8  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  yhat.svm  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package e1071 , ver. 1.7.9 , task classification (  default  ) 
#>   -> predicted values  :  numerical, min =  0.08750959 , mean =  0.3222721 , max =  0.9616319  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8634896 , mean =  -0.0001153752 , max =  0.9124904  
#>   A new explainer has been created!

k-Nearest Neighbors (kNN)

library("caret")
model_titanic_knn <- knn3(survived ~ class + gender + age + sibsp +
                     parch + fare + embarked, data = titanic_imputed, k = 5)
explain_titanic_knn <- explain(model_titanic_knn, data = titanic_imputed, 
                       y = titanic_imputed$survived, 
                       predict_function = function(m,x) predict(m, x)[,2],
                       label = "k-Nearest Neighbors",
                       colorize = FALSE)
#> Preparation of a new explainer is initiated
#>   -> model label       :  k-Nearest Neighbors 
#>   -> data              :  2207  rows  8  cols 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  function(m, x) predict(m, x)[, 2] 
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package Model of class: knn3 package unrecognized , ver. Unknown , task regression (  default  ) 
#>   -> predicted values  :  numerical, min =  0 , mean =  0.3061413 , max =  1  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8333333 , mean =  0.01601551 , max =  0.9285714  
#>   A new explainer has been created!

Variable importance

vi_rf <- model_parts(explain_titanic_rf)
vi_lmr <- model_parts(explain_titanic_lmr)
vi_gbm <- model_parts(explain_titanic_gbm)
vi_svm <- model_parts(explain_titanic_svm)
vi_knn <- model_parts(explain_titanic_knn)

plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)
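
Explainers also make it easy to compare the overall performance of the models before digging into variable importance. A sketch using model_performance() with two of the explainers created above:

```r
# residual-based performance summaries for two models;
# plotting them together overlays their ROC curves
mp_rf  <- model_performance(explain_titanic_rf)
mp_gbm <- model_performance(explain_titanic_gbm)
plot(mp_rf, mp_gbm, geom = "roc")
```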

Single variable

vr_age_rf   <- model_profile(explain_titanic_rf, variables = "age")
vr_age_lmr  <- model_profile(explain_titanic_lmr, variables = "age")
vr_age_gbm  <- model_profile(explain_titanic_gbm, variables = "age")
vr_age_svm  <- model_profile(explain_titanic_svm, variables = "age")
vr_age_knn  <- model_profile(explain_titanic_knn, variables = "age")
plot(vr_age_rf$agr_profiles, 
     vr_age_lmr$agr_profiles, 
     vr_age_gbm$agr_profiles, 
     vr_age_svm$agr_profiles, 
     vr_age_knn$agr_profiles)

Instance level explanations

sp_rf <- predict_parts(explain_titanic_rf, new_passanger)
plot(sp_rf)

sp_lmr <- predict_parts(explain_titanic_lmr, new_passanger)
plot(sp_lmr)

sp_gbm <- predict_parts(explain_titanic_gbm, new_passanger)
plot(sp_gbm)

sp_svm <- predict_parts(explain_titanic_svm, new_passanger)
plot(sp_svm)

sp_knn <- predict_parts(explain_titanic_knn, new_passanger)
plot(sp_knn)

Session info

#> R version 4.1.2 (2021-11-01)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Catalina 10.15.7
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] caret_6.0-90     e1071_1.7-9      gbm_2.1.8        rms_6.2-0       
#>  [5] SparseM_1.81     Hmisc_4.6-0      ggplot2_3.3.5    Formula_1.2-4   
#>  [9] survival_3.2-13  lattice_0.20-45  ranger_0.13.1    DALEX_2.3.0.9000
#> 
#> loaded via a namespace (and not attached):
#>   [1] TH.data_1.1-0        colorspace_2.0-2     ellipsis_0.3.2      
#>   [4] class_7.3-19         rprojroot_2.0.2      htmlTable_2.3.0     
#>   [7] base64enc_0.1-3      fs_1.5.0             proxy_0.4-26        
#>  [10] rstudioapi_0.13      listenv_0.8.0        farver_2.1.0        
#>  [13] MatrixModels_0.5-0   prodlim_2019.11.13   fansi_0.5.0         
#>  [16] mvtnorm_1.1-3        lubridate_1.8.0      iBreakDown_2.0.1    
#>  [19] codetools_0.2-18     splines_4.1.2        cachem_1.0.6        
#>  [22] knitr_1.36           pROC_1.18.0          cluster_2.1.2       
#>  [25] png_0.1-7            compiler_4.1.2       backports_1.3.0     
#>  [28] Matrix_1.3-4         fastmap_1.1.0        htmltools_0.5.2     
#>  [31] quantreg_5.86        tools_4.1.2          gtable_0.3.0        
#>  [34] glue_1.5.0           reshape2_1.4.4       dplyr_1.0.7         
#>  [37] Rcpp_1.0.7           jquerylib_0.1.4      pkgdown_1.6.1       
#>  [40] vctrs_0.3.8          nlme_3.1-153         conquer_1.2.1       
#>  [43] iterators_1.0.13     timeDate_3043.102    gower_0.2.2         
#>  [46] xfun_0.28            stringr_1.4.0        globals_0.14.0      
#>  [49] lifecycle_1.0.1      future_1.23.0        polspline_1.1.19    
#>  [52] MASS_7.3-54          zoo_1.8-9            scales_1.1.1        
#>  [55] ipred_0.9-12         ragg_1.2.0           parallel_4.1.2      
#>  [58] sandwich_3.0-1       RColorBrewer_1.1-2   yaml_2.2.1          
#>  [61] memoise_2.0.0        gridExtra_2.3        rpart_4.1-15        
#>  [64] latticeExtra_0.6-29  stringi_1.7.5        highr_0.9           
#>  [67] desc_1.4.0           foreach_1.5.1        checkmate_2.0.0     
#>  [70] lava_1.6.10          rlang_0.4.12         pkgconfig_2.0.3     
#>  [73] systemfonts_1.0.3    matrixStats_0.61.0   evaluate_0.14       
#>  [76] purrr_0.3.4          recipes_0.1.17       htmlwidgets_1.5.4   
#>  [79] labeling_0.4.2       tidyselect_1.1.1     parallelly_1.28.1   
#>  [82] plyr_1.8.6           magrittr_2.0.1       R6_2.5.1            
#>  [85] generics_0.1.1       multcomp_1.4-17      pillar_1.6.4        
#>  [88] foreign_0.8-81       withr_2.4.2          nnet_7.3-16         
#>  [91] tibble_3.1.5         future.apply_1.8.1   crayon_1.4.2        
#>  [94] utf8_1.2.2           rmarkdown_2.11       jpeg_0.1-9          
#>  [97] ingredients_2.2.0    grid_4.1.2           data.table_1.14.2   
#> [100] ModelMetrics_1.2.2.2 digest_0.6.28        textshaping_0.3.6   
#> [103] stats4_4.1.2         munsell_0.5.0