This function calculates variable importance as the change in the loss function after the values of each variable are permuted.

model_parts(explainer, ...)

# S3 method for surv_explainer
model_parts(
  explainer,
  loss_function = survex::loss_brier_score,
  ...,
  type = "difference",
  output_type = "survival",
  N = 1000
)

Arguments

explainer

an explainer object - model preprocessed by the explain() function

...

Arguments passed on to surv_feature_importance, surv_integrated_feature_importance

B

numeric, number of permutations to be calculated

variables

a character vector, names of variables to be included in the calculation

variable_groups

a list of character vectors of names of explanatory variables. For each vector, a single variable-importance measure is computed for the joint effect of the variables whose names are provided in the vector. By default, variable_groups = NULL, in which case variable-importance measures are computed separately for all variables indicated in the variables argument.

label

label of the model; if provided, it overrides x$label

loss_function

a function that will be used to assess variable importance, by default loss_brier_score for survival models. The function can be supplied manually, but it must have these named parameters: (y_true, risk, surv, times). Here y_true is the survival::Surv object with observed times and statuses, risk is the risk score calculated by the model, surv is the survival function for each observation evaluated at times, and times is the vector of time points at which surv is evaluated.

type

a character string. If "raw", the results are the losses after the permutation. If "ratio", the results are of the form loss/loss_full_model. If "difference", the results are of the form loss - loss_full_model. Defaults to "difference".

output_type

either "survival" or "risk", the type of survival model output that should be used for explanations. If "survival", the explanations are based on the survival function. Otherwise, the scalar risk predictions are used by the DALEX::model_parts function.

N

number of observations that should be sampled for calculation of variable importance. If NULL then variable importance will be calculated on the whole dataset.

Value

An object of class c("model_parts_survival", "surv_feature_importance"). It is a list with the explanations in the result element.

Details

Note: This function can be run within progressr::with_progress() to display a progress bar, as the execution can take a long time, especially on large datasets.

Examples

# \donttest{
library(survival)
library(survex)

# Fit two survival models on the veteran data: a Cox proportional hazards
# model and a random survival forest (ranger).
# model = TRUE, x = TRUE, y = TRUE store the data in the coxph object so that
# explain() can extract data and target from it (as its output below reports).
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE,
  num.trees = 100,
  mtry = 3,
  max.depth = 5
)

# Wrap the Cox model in a survex explainer. Data and target are taken from the
# model, and the evaluation times are generated from y (see output).
cph_exp <- explain(cph)
#> Preparation of a new explainer is initiated 
#>   -> model label       :  coxph (  default  ) 
#>   -> data              :  137  rows  6  cols (  extracted from the model  ) 
#>   -> target variable   :  137  values ( 128 events and 9 censored , censoring rate = 0.066 ) (  extracted from the model  ) 
#>   -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
#>   -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
#>   -> predict function  :  predict.coxph with type = 'risk' will be used (  default  ) 
#>   -> predict survival function  :  predictSurvProb.coxph will be used (  default  ) 
#>   -> predict cumulative hazard function  :  -log(predict_survival_function) will be used (  default  ) 
#>   -> model_info        :  package survival , ver. 3.5.8 , task survival (  default  ) 
#>   A new explainer has been created!  

# For the ranger model, the predictors and the Surv response are passed
# explicitly. Columns 3 and 4 (time and status) are dropped so that `data`
# holds only the six explanatory variables.
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
#> Preparation of a new explainer is initiated 
#>   -> model label       :  ranger (  default  ) 
#>   -> data              :  137  rows  6  cols 
#>   -> target variable   :  137  values ( 128 events and 9 censored ) 
#>   -> times             :  50 unique time points , min = 1.5 , median survival time = 80 , max = 999 
#>   -> times             :  (  generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator  ) 
#>   -> predict function  :  sum over the predict_cumulative_hazard_function will be used (  default  ) 
#>   -> predict survival function  :  stepfun based on predict.ranger()$survival will be used (  default  ) 
#>   -> predict cumulative hazard function  :  stepfun based on predict.ranger()$chf will be used (  default  ) 
#>   -> model_info        :  package ranger , ver. 0.16.0 , task survival (  default  ) 
#>   A new explainer has been created!  

# Permutation-based variable importance with the default settings: Brier score
# loss, type = "difference", output_type = "survival". Each variable column is
# loss(permuted) - loss(full model) at one time point. This is why _full_model_
# is 0, and a positive value means permuting that variable increased the loss.
# NOTE(review): no seed is set, so re-running may give slightly different
# numbers from those recorded below.
cph_model_parts_brier <- model_parts(cph_exp)
print(head(cph_model_parts_brier$result))
#>    _times_ _full_model_           trt      celltype        karno      diagtime
#> 1      1.5            0  3.414508e-05 -0.0004391802 0.0001930351 -3.052674e-08
#> 12     4.0            0 -9.083787e-05 -0.0002643615 0.0020820016  7.604845e-07
#> 23     7.0            0  4.752907e-04  0.0013203051 0.0052460287  4.063020e-07
#> 35     8.0            0  2.247208e-04  0.0031117431 0.0102433899  5.050466e-06
#> 48    10.0            0 -3.788362e-04  0.0019621813 0.0137033300  7.451684e-06
#> 57    12.0            0 -1.265335e-03  0.0017288758 0.0149247381  8.220809e-06
#>             age         prior   _baseline_ _permutation_ label
#> 1  1.540466e-05  9.416094e-06 0.0001320184             0 coxph
#> 12 5.021007e-04  4.424765e-05 0.0021482684             0 coxph
#> 23 5.190239e-04 -3.183107e-05 0.0066581972             0 coxph
#> 35 5.216642e-04  1.909002e-04 0.0129284258             0 coxph
#> 48 5.817038e-04  2.948293e-04 0.0158831815             0 coxph
#> 57 5.084629e-04  5.295008e-04 0.0187495665             0 coxph
plot(cph_model_parts_brier)



# Same computation for the random survival forest explainer.
rsf_ranger_model_parts <- model_parts(rsf_ranger_exp)
print(head(rsf_ranger_model_parts$result))
#>    _times_ _full_model_          trt    celltype       karno     diagtime
#> 1      1.5            0 0.0001627702 0.001154185 0.002194825 0.0007448171
#> 12     4.0            0 0.0010166134 0.002589573 0.009145478 0.0026006903
#> 23     7.0            0 0.0014333630 0.003611515 0.016965397 0.0040812810
#> 35     8.0            0 0.0019726986 0.006026738 0.025968757 0.0091819261
#> 48    10.0            0 0.0022475016 0.005699407 0.038798780 0.0110994975
#> 57    12.0            0 0.0025773798 0.007753140 0.046387337 0.0150343326
#>            age        prior  _baseline_ _permutation_  label
#> 1  0.002751134 0.0002786377 0.005572104             0 ranger
#> 12 0.011073665 0.0006944964 0.018361833             0 ranger
#> 23 0.014504751 0.0003502657 0.025381473             0 ranger
#> 35 0.014323920 0.0009249015 0.039236130             0 ranger
#> 48 0.015248906 0.0012535640 0.051156514             0 ranger
#> 57 0.015342076 0.0023220928 0.060363730             0 ranger
# Pass both explanations to plot() to compare the two models side by side.
plot(cph_model_parts_brier, rsf_ranger_model_parts)

# }