Data for Titanic survival

Let’s see an example of using the DALEX package with a classification model for the Titanic survival problem. Here we use the titanic dataset available in the DALEX package. Note that this data was copied from the stablelearner package.

library("DALEX")
head(titanic)
#>   gender age class    embarked       country  fare sibsp parch survived
#> 1   male  42   3rd Southampton United States  7.11     0     0       no
#> 2   male  13   3rd Southampton United States 20.05     0     2       no
#> 3   male  16   3rd Southampton United States 20.05     1     1       no
#> 4 female  39   3rd Southampton       England 20.05     1     1      yes
#> 5 female  16   3rd Southampton        Norway  7.13     0     0      yes
#> 6   male  25   3rd Southampton United States  7.13     0     0      yes

Model for Titanic survival

OK, now it’s time to create a model. Let’s use a random forest.
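The fit can be sketched as below. The formula is taken from the call echoed in the printed output; the seed and the `na.omit()` step are assumptions (the explainers later report 2099 rows, which suggests rows with missing values were dropped first):

```r
library("randomForest")

# Assumption: drop rows with missing values (the explainer reports 2099 rows)
titanic <- na.omit(titanic)

set.seed(1313)  # assumption, for reproducibility only

# survived == "yes" turns the target into 0/1, so randomForest
# treats this as a regression task (as in the printed output)
model_titanic_rf <- randomForest(
  survived == "yes" ~ gender + age + class + embarked + fare + sibsp + parch,
  data = titanic
)
model_titanic_rf
```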

#> 
#> Call:
#>  randomForest(formula = survived == "yes" ~ gender + age + class +      embarked + fare + sibsp + parch, data = titanic) 
#>                Type of random forest: regression
#>                      Number of trees: 500
#> No. of variables tried at each split: 2
#> 
#>           Mean of squared residuals: 0.1431399
#>                     % Var explained: 34.69

Explainer for Titanic survival

The third step (optional but useful) is to create a DALEX explainer for the random forest model.
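A sketch of the `explain()` call that produces the output below; the label matches the printed one, while dropping the target column by position is an assumption:

```r
# Create an explainer: data without the target column, y as a 0/1 vector
explain_titanic_rf <- explain(
  model_titanic_rf,
  data  = titanic[, -9],               # drop the 'survived' column (assumption)
  y     = titanic$survived == "yes",
  label = "Random Forest v7"
)
```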

#> Preparation of a new explainer is initiated
#>   -> model label       :  Random Forest v7 
#>   -> data              :  2099  rows  8  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  yhat.randomForest  will be used (  default  )
#>   -> predicted values  :  numerical, min =  0.0083465 , mean =  0.3243103 , max =  0.9954697  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.7914133 , mean =  0.0001299514 , max =  0.9000925  
#>   -> model_info        :  package randomForest , ver. 4.6.14 , task regression (  default  ) 
#>   A new explainer has been created!

Variable importance plots

Use the variable_importance() function to assess the importance of particular features. Note that type = "difference" normalizes the dropout loss, so that all bars start at 0.

vi_rf <- variable_importance(explain_titanic_rf)
head(vi_rf)
#>       variable dropout_loss            label
#> 1 _full_model_     105.9405 Random Forest v7
#> 2      country     105.9405 Random Forest v7
#> 3        parch     112.8628 Random Forest v7
#> 4        sibsp     114.8572 Random Forest v7
#> 5     embarked     116.8550 Random Forest v7
#> 6         fare     133.4371 Random Forest v7
plot(vi_rf)
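The type = "difference" variant mentioned above can be requested explicitly (the default shown above is the raw dropout loss):

```r
# Dropout loss relative to the full model: the full model starts at 0
vi_rf_diff <- variable_importance(explain_titanic_rf, type = "difference")
plot(vi_rf_diff)
```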

Variable effects

As we can see, the most important feature is gender. The next three important features are class, age and fare. Let’s see the link between the model response and these features.

Such univariate relation can be calculated with variable_response().

Age

Kids 5 years old and younger have much higher survival probability.

vr_age <- variable_response(explain_titanic_rf, variable = "age")
head(vr_age)
#>           x            y var type            label
#> 1 0.1666667 0.531829.... age  pdp Random Forest v7
#> 2 1.6433333 0.557607.... age  pdp Random Forest v7
#> 3 3.1200000 0.571812.... age  pdp Random Forest v7
#> 4 4.5966667 0.544413.... age  pdp Random Forest v7
#> 5 6.0733333 0.528195.... age  pdp Random Forest v7
#> 6 7.5500000 0.529058.... age  pdp Random Forest v7
plot(vr_age, use_facets = TRUE)

Passenger class

Passengers in the first class have a much higher survival probability.

vr_class <- variable_response(explain_titanic_rf, variable = "class")
plot(vr_class)

Fare

Very cheap tickets are linked with lower chances of survival.

vr_fare <- variable_response(explain_titanic_rf, variable = "fare")
plot(vr_fare, use_facets = TRUE)

Embarked

Passengers who embarked in Cherbourg have the highest survival probability.

vr_embarked <- variable_response(explain_titanic_rf, variable = "embarked")
plot(vr_embarked)

Instance level explanations

Let’s see a break-down explanation of the model prediction for an 8-year-old male from the 1st class who embarked in Southampton.

new_passanger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)

sp_rf <- single_prediction(explain_titanic_rf, new_passanger)
plot(sp_rf)

It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite him being male.

More models

Let’s train more models for survival.

Logistic regression
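The output below comes from a logistic regression fitted with the rms package (rms appears in the session info, and the "fitted" predict type matches rms::lrm); the exact formula, in particular the rcs(age) spline term, is an assumption:

```r
library("rms")

# Logistic regression via rms::lrm; rcs(age) (restricted cubic spline)
# is an assumption about how non-linearity in age was modelled
model_titanic_lmr <- lrm(
  survived == "yes" ~ class + gender + rcs(age) + sibsp + parch + fare + embarked,
  data = titanic
)
explain_titanic_lmr <- explain(
  model_titanic_lmr,
  data  = titanic,
  y     = titanic$survived == "yes",
  predict_function = function(m, x) predict(m, x, type = "fitted"),
  label = "Logistic regression"
)
```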

#> Preparation of a new explainer is initiated
#>   -> model label       :  Logistic regression 
#>   -> data              :  2099  rows  9  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  function(m, x) predict(m, x, type = "fitted") 
#>   -> predicted values  :  numerical, min =  0.003695743 , mean =  0.3244402 , max =  0.9827164  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.9827164 , mean =  -2.869171e-09 , max =  0.9716889  
#>   -> model_info        :  package stats , ver. 3.6.1 , task regression (  default  ) 
#>   A new explainer has been created!

Generalized Boosted Models (GBM)
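A sketch of the GBM fit; n.trees = 15000 is taken from the predict function echoed below, and gbm infers the bernoulli distribution from the 0/1 target (hence the first message):

```r
library("gbm")

model_titanic_gbm <- gbm(
  survived == "yes" ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic, n.trees = 15000
)
explain_titanic_gbm <- explain(
  model_titanic_gbm,
  data  = titanic,
  y     = titanic$survived == "yes",
  predict_function = function(m, x) predict(m, x, n.trees = 15000, type = "response"),
  label = "Generalized Boosted Models"
)
```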

#> Distribution not specified, assuming bernoulli ...
#> Preparation of a new explainer is initiated
#>   -> model label       :  Generalized Boosted Models 
#>   -> data              :  2099  rows  9  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  function(m, x) predict(m, x, n.trees = 15000, type = "response") 
#>   -> predicted values  :  numerical, min =  0.0001421762 , mean =  0.3250518 , max =  0.9987503  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.9244986 , mean =  -0.0006115575 , max =  0.9628557  
#>   -> model_info        :  package gbm , ver. 2.1.5 , task classification (  default  ) 
#>   A new explainer has been created!

Support Vector Machines (SVM)
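The SVM explainer uses the default yhat.svm predict function, which requires probability estimates to be enabled when fitting; the exact call is a sketch:

```r
library("e1071")

# probability = TRUE is needed so that yhat.svm can return class probabilities
model_titanic_svm <- svm(
  survived == "yes" ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic, type = "C-classification", probability = TRUE
)
explain_titanic_svm <- explain(
  model_titanic_svm,
  data  = titanic,
  y     = titanic$survived == "yes",
  label = "Support Vector Machines"
)
```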

#> Preparation of a new explainer is initiated
#>   -> model label       :  Support Vector Machines 
#>   -> data              :  2099  rows  9  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  yhat.svm  will be used (  default  )
#>   -> predicted values  :  numerical, min =  0.08568572 , mean =  0.3255739 , max =  0.9571349  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8537957 , mean =  -0.001133667 , max =  0.8913856  
#>   -> model_info        :  package e1071 , ver. 1.7.2 , task classification (  default  ) 
#>   A new explainer has been created!

k-Nearest Neighbours (kNN)
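The output below reports an unrecognized model of class knn3, which comes from the caret package; the predict function takes the second column of the class-probability matrix. The value of k is an assumption:

```r
library("caret")

# knn3 needs a factor outcome; k = 5 is an assumption
model_titanic_knn <- knn3(
  factor(survived == "yes") ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic, k = 5
)
explain_titanic_knn <- explain(
  model_titanic_knn,
  data  = titanic,
  y     = titanic$survived == "yes",
  predict_function = function(m, x) predict(m, x)[, 2],
  label = "k-Nearest Neighbours"
)
```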

#> Preparation of a new explainer is initiated
#>   -> model label       :  k-Nearest Neighbours 
#>   -> data              :  2099  rows  9  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  function(m, x) predict(m, x)[, 2] 
#>   -> predicted values  :  numerical, min =  0 , mean =  0.3110276 , max =  1  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -0.8 , mean =  0.01341259 , max =  0.9285714  
#>   -> model_info        :  package Model of class: knn3 package unrecognized , ver. Unknown , task regression (  default  ) 
#>   A new explainer has been created!

Variable performance

vi_rf <- variable_importance(explain_titanic_rf)
vi_lmr <- variable_importance(explain_titanic_lmr)
vi_gbm <- variable_importance(explain_titanic_gbm)
vi_svm <- variable_importance(explain_titanic_svm)
vi_knn <- variable_importance(explain_titanic_knn)

plot(vi_rf, vi_lmr, vi_gbm, vi_svm, vi_knn, bar_width = 4)

Single variable

vr_age_rf <- variable_response(explain_titanic_rf, variable = "age")
vr_age_lmr <- variable_response(explain_titanic_lmr, variable = "age")
vr_age_gbm <- variable_response(explain_titanic_gbm, variable = "age")
vr_age_svm <- variable_response(explain_titanic_svm, variable = "age")
vr_age_knn <- variable_response(explain_titanic_knn, variable = "age")
plot(vr_age_rf, vr_age_lmr, vr_age_gbm, vr_age_svm, vr_age_knn)

plot(vr_age_rf, vr_age_lmr, vr_age_gbm, vr_age_svm, vr_age_knn, use_facets = TRUE)

Instance level explanations

sp_rf <- single_prediction(explain_titanic_rf, new_passanger)
sp_lmr <- single_prediction(explain_titanic_lmr, new_passanger)
sp_gbm <- single_prediction(explain_titanic_gbm, new_passanger)
sp_svm <- single_prediction(explain_titanic_svm, new_passanger)
sp_knn <- single_prediction(explain_titanic_knn, new_passanger)
plot(sp_rf, sp_lmr, sp_gbm, sp_svm, sp_knn)

Session info

#> R version 3.6.1 (2019-07-05)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.4
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] caret_6.0-84        e1071_1.7-2         gbm_2.1.5          
#>  [4] rms_5.1-3.1         SparseM_1.77        Hmisc_4.2-0        
#>  [7] ggplot2_3.2.1       Formula_1.2-3       survival_2.44-1.1  
#> [10] lattice_0.20-38     randomForest_4.6-14 DALEX_0.4.9        
#> 
#> loaded via a namespace (and not attached):
#>   [1] TH.data_1.0-10      colorspace_1.4-1    ggsignif_0.6.0     
#>   [4] deldir_0.1-23       class_7.3-15        rprojroot_1.3-2    
#>   [7] htmlTable_1.13.2    base64enc_0.1-3     fs_1.3.1           
#>  [10] rstudioapi_0.10     proxy_0.4-23        ggpubr_0.2.3       
#>  [13] MatrixModels_0.4-1  lubridate_1.7.4     prodlim_2018.04.18 
#>  [16] mvtnorm_1.0-11      codetools_0.2-16    splines_3.6.1      
#>  [19] knitr_1.25          cluster_2.1.0       shiny_1.4.0        
#>  [22] compiler_3.6.1      backports_1.1.5     assertthat_0.2.1   
#>  [25] Matrix_1.2-17       fastmap_1.0.1       lazyeval_0.2.2     
#>  [28] later_1.0.0         acepack_1.4.1       htmltools_0.4.0    
#>  [31] quantreg_5.51       tools_3.6.1         coda_0.19-3        
#>  [34] gtable_0.3.0        agricolae_1.3-1     glue_1.3.1         
#>  [37] reshape2_1.4.3      dplyr_0.8.3         gmodels_2.18.1     
#>  [40] Rcpp_1.0.2          pkgdown_1.4.1       spdep_1.1-3        
#>  [43] gdata_2.18.0        nlme_3.1-140        iterators_1.0.12   
#>  [46] timeDate_3043.102   gower_0.2.1         xfun_0.10          
#>  [49] stringr_1.4.0       mime_0.7            miniUI_0.1.1.1     
#>  [52] breakDown_0.1.6     gtools_3.8.1        polspline_1.1.16   
#>  [55] zoo_1.8-6           LearnBayes_2.15.1   MASS_7.3-51.4      
#>  [58] scales_1.0.0        ipred_0.9-9         promises_1.1.0     
#>  [61] sandwich_2.5-1      expm_0.999-4        RColorBrewer_1.1-2 
#>  [64] yaml_2.2.0          memoise_1.1.0       gridExtra_2.3      
#>  [67] rpart_4.1-15        latticeExtra_0.6-28 stringi_1.4.3      
#>  [70] highr_0.8           klaR_0.6-14         AlgDesign_1.1-7.3  
#>  [73] desc_1.2.0          foreach_1.4.7       checkmate_1.9.4    
#>  [76] boot_1.3-22         lava_1.6.6          spData_0.3.2       
#>  [79] rlang_0.4.0         pkgconfig_2.0.3     evaluate_0.14      
#>  [82] purrr_0.3.3         sf_0.8-0            recipes_0.1.7      
#>  [85] htmlwidgets_1.5.1   labeling_0.3        cowplot_1.0.0      
#>  [88] tidyselect_0.2.5    factorMerger_0.4.0  plyr_1.8.4         
#>  [91] magrittr_1.5        R6_2.4.0            generics_0.0.2     
#>  [94] multcomp_1.4-10     combinat_0.0-8      DBI_1.0.0          
#>  [97] pillar_1.4.2        foreign_0.8-71      withr_2.1.2        
#> [100] units_0.6-5         sp_1.3-1            nnet_7.3-12        
#> [103] tibble_2.1.3        crayon_1.3.4        questionr_0.7.0    
#> [106] KernSmooth_2.23-15  rmarkdown_1.16      grid_3.6.1         
#> [109] data.table_1.12.2   ModelMetrics_1.2.2  digest_0.6.21      
#> [112] classInt_0.4-1      pdp_0.7.0           xtable_1.8-4       
#> [115] httpuv_1.5.2        stats4_3.6.1        munsell_0.5.0