Data for Titanic survival

Let’s see an example of using the DALEX package with a classification model for the Titanic survival problem. Here we use the titanic dataset available in the DALEX package. Note that this data was copied from the stablelearner package.

library("DALEX")
head(titanic)
#>   gender age class    embarked       country  fare sibsp parch survived
#> 1   male  42   3rd Southampton United States  7.11     0     0       no
#> 2   male  13   3rd Southampton United States 20.05     0     2       no
#> 3   male  16   3rd Southampton United States 20.05     1     1       no
#> 4 female  39   3rd Southampton       England 20.05     1     1      yes
#> 5 female  16   3rd Southampton        Norway  7.13     0     0      yes
#> 6   male  25   3rd Southampton United States  7.13     0     0      yes

Model for Titanic survival

Ok, now it’s time to create a model. Let’s use a random forest model.
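The code that produced the output below was omitted from this chunk. A minimal sketch consistent with the printed call follows; the `na.omit()` step is an assumption (randomForest() does not accept missing values, and the titanic data contains some):

```r
library("randomForest")

# drop rows with missing values (assumption; randomForest() cannot handle NAs)
titanic <- na.omit(titanic)

# survived == "yes" converts the factor target to 0/1,
# which is why a regression forest is fitted below
model_titanic_rf <- randomForest(survived == "yes" ~ gender + age + class +
                                   embarked + fare + sibsp + parch,
                                 data = titanic)
model_titanic_rf
```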

#> 
#> Call:
#>  randomForest(formula = survived == "yes" ~ gender + age + class +      embarked + fare + sibsp + parch, data = titanic) 
#>                Type of random forest: regression
#>                      Number of trees: 500
#> No. of variables tried at each split: 2
#> 
#>           Mean of squared residuals: 0.1430171
#>                     % Var explained: 34.75

Explainer for Titanic survival

The third step (optional, but useful) is to create a DALEX explainer for the random forest model.
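The call that created the explainer is not shown in this chunk. A sketch consistent with the log below; the exact arguments, in particular the column selection and the label, are assumptions:

```r
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic[, -9],           # all columns except survived
                              y = titanic$survived == "yes",  # 0/1 target
                              label = "Random Forest v7")
```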

#> Preparation of a new explainer is initiated
#>   -> model label       :  Random Forest v7 
#>   -> data              :  2099  rows  8  cols 
#>   -> target variable   :  2099  values 
#>   -> predict function  :  yhat.randomForest  will be used ( default )
#>   -> predicted values  :  numerical, min =  0.008617165 , mean =  0.3243438 , max =  0.9908865  
#>   -> residual function :  difference between y and yhat ( default )
#>   -> residuals         :  numerical, min =  -0.7965849 , mean =  9.642274e-05 , max =  0.9089863  
#>   -> model_info        :  package randomForest , ver. 4.6.14 , task regression ( default ) 
#>  A new explainer has been created!

Model Level Feature Importance

Use the feature_importance() function to assess the importance of particular features. Note that type = "difference" normalizes the dropout losses so that they all start at 0.

library("ingredients")

fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
#>       variable dropout_loss            label
#> 1 _full_model_    0.3331537 Random Forest v7
#> 2      country    0.3331537 Random Forest v7
#> 3        parch    0.3440463 Random Forest v7
#> 4        sibsp    0.3453025 Random Forest v7
#> 5     embarked    0.3488268 Random Forest v7
#> 6         fare    0.3741211 Random Forest v7
plot(fi_rf)
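To obtain the normalized variant mentioned above, pass type = "difference"; a short sketch (here each value is the change in loss relative to the full model, so the full model baseline is 0):

```r
fi_rf_diff <- feature_importance(explain_titanic_rf, type = "difference")
plot(fi_rf_diff)
```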

Feature effects

As we can see, the most important feature is gender. The next three important features are class, age and fare. Let’s examine the link between the model response and these features.

Such a univariate relation can be calculated with partial_dependency().

age

Kids 5 years old and younger have a much higher survival probability.

Partial Dependency Profiles

pp_age  <- partial_dependency(explain_titanic_rf, variables =  c("age", "fare"))
head(pp_age)
#> Top profiles    : 
#>   _vname_          _label_       _x_    _yhat_ _ids_
#> 1    fare Random Forest v7 0.0000000 0.3240270     0
#> 2     age Random Forest v7 0.1666667 0.5139869     0
#> 3     age Random Forest v7 2.0000000 0.5447859     0
#> 4     age Random Forest v7 4.0000000 0.5524095     0
#> 5    fare Random Forest v7 6.1904000 0.3102971     0
#> 6     age Random Forest v7 7.0000000 0.5185248     0
plot(pp_age)

Conditional Dependency Profiles

cp_age  <- conditional_dependency(explain_titanic_rf, variables =  c("age", "fare"))
plot(cp_age)

Accumulated Local Effect Profiles

ap_age  <- accumulated_dependency(explain_titanic_rf, variables =  c("age", "fare"))
plot(ap_age)

Instance level explanations

Let’s see a break-down explanation of the model prediction for an 8-year-old male from 1st class who embarked in Southampton.

First, Ceteris Paribus profiles for numerical variables.

new_passenger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)

sp_rf <- ceteris_paribus(explain_titanic_rf, new_passenger)
plot(sp_rf) +
  show_observations(sp_rf)

And for selected categorical variables. Note that sibsp is numerical, but here it is presented as a categorical variable.

plot(sp_rf,
     variables = c("class", "embarked", "gender", "sibsp"),
     variable_type = "categorical")

It looks like the most important features for this passenger are age and gender. Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite his being male.

Profile clustering

passengers <- select_sample(titanic, n = 100)

sp_rf <- ceteris_paribus(explain_titanic_rf, passengers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles    : 
#>   _vname_            _label_       _x_ _cluster_    _yhat_ _ids_
#> 1    fare Random Forest v7_1 0.0000000         1 0.1961479     0
#> 2   sibsp Random Forest v7_1 0.0000000         1 0.1689586     0
#> 3   parch Random Forest v7_1 0.0000000         1 0.1723925     0
#> 4     age Random Forest v7_1 0.1666667         1 0.4572311     0
#> 5   parch Random Forest v7_1 0.2800000         1 0.1723925     0
#> 6   sibsp Random Forest v7_1 1.0000000         1 0.1706048     0
plot(sp_rf, alpha = 0.1) +
  show_aggregated_profiles(clust_rf, color = "_label_", size = 2)

Session info

#> R version 3.6.1 (2019-07-05)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Mojave 10.14.4
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ingredients_0.4     randomForest_4.6-14 DALEX_0.4.9        
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.2       compiler_3.6.1   pillar_1.4.2     tools_3.6.1     
#>  [5] digest_0.6.21    evaluate_0.14    memoise_1.1.0    tibble_2.1.3    
#>  [9] gtable_0.3.0     pkgconfig_2.0.3  rlang_0.4.0      rstudioapi_0.10 
#> [13] yaml_2.2.0       pkgdown_1.4.1    xfun_0.10        stringr_1.4.0   
#> [17] dplyr_0.8.3      knitr_1.25       desc_1.2.0       fs_1.3.1        
#> [21] rprojroot_1.3-2  grid_3.6.1       tidyselect_0.2.5 glue_1.3.1      
#> [25] R6_2.4.0         rmarkdown_1.16   ggplot2_3.2.1    purrr_0.3.3     
#> [29] magrittr_1.5     backports_1.1.5  scales_1.0.0     htmltools_0.4.0 
#> [33] MASS_7.3-51.4    assertthat_0.2.1 colorspace_1.4-1 labeling_0.3    
#> [37] stringi_1.4.3    lazyeval_0.2.2   munsell_0.5.0    crayon_1.3.4