R/model_parts.R
model_parts.surv_explainer.Rd
This function calculates variable importance as the change in the loss function after the values of a variable are permuted.
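The idea of permutation-based importance can be sketched outside of survival analysis. The block below is an illustration only, not the survex implementation: it assumes a toy linear model, a hypothetical `mse` loss, and made-up variables `x1` and `x2`. It permutes one column at a time and reports the resulting change in loss.

```r
# Illustration only: permutation importance for a plain linear model,
# not the survex implementation.
set.seed(42)

# Toy data: y depends strongly on x1 and not at all on x2.
n <- 200
dat <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
dat$y <- 3 * dat$x1 + rnorm(n, sd = 0.5)

fit <- lm(y ~ x1 + x2, data = dat)

# Hypothetical loss function: mean squared error.
mse <- function(y_true, y_pred) mean((y_true - y_pred)^2)

# Loss of the model on the unpermuted data.
loss_full <- mse(dat$y, predict(fit, newdata = dat))

# Permute a single column and return the change in loss.
permutation_importance <- function(variable) {
  permuted <- dat
  permuted[[variable]] <- sample(permuted[[variable]])
  mse(dat$y, predict(fit, newdata = permuted)) - loss_full
}

importance <- sapply(c("x1", "x2"), permutation_importance)
print(importance)
```

A variable the model relies on (here `x1`) should show a much larger increase in loss than one it ignores (`x2`).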
model_parts(explainer, ...)
# S3 method for surv_explainer
model_parts(
explainer,
loss_function = survex::loss_brier_score,
...,
type = "difference",
output_type = "survival",
N = 1000
)
explainer
an explainer object: a model preprocessed by the explain() function
...
Arguments passed on to surv_feature_importance, surv_integrated_feature_importance
B
numeric, number of permutations to be calculated
variables
a character vector, names of variables to be included in the calculation
variable_groups
a list of character vectors of names of explanatory variables. For each vector, a single variable-importance measure is computed for the joint effect of the variables whose names are provided in the vector. By default, variable_groups = NULL, in which case variable-importance measures are computed separately for all variables indicated in the variables argument
label
label of the model; if provided, overrides x$label
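The pass-through arguments listed above might be used as in the following sketch. It assumes the survival and survex packages are installed and reuses the veteran data from the examples on this page; the group names `status_at_entry` and `treatment` are hypothetical and chosen only for illustration.

```r
library(survival)
library(survex)

set.seed(1)  # permutations are random, so fix the seed for repeatability

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Restrict the calculation to two variables and use 5 permutations.
mp_subset <- model_parts(cph_exp, B = 5, variables = c("karno", "age"))
print(colnames(mp_subset$result))

# Measure the joint effect of two hypothetical groups of variables.
mp_groups <- model_parts(
  cph_exp,
  B = 5,
  variable_groups = list(
    status_at_entry = c("karno", "diagtime"),
    treatment = c("trt", "prior")
  )
)
print(colnames(mp_groups$result))
```

A smaller B runs faster but gives noisier importance estimates, so it is best treated as a setting for quick exploration.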
loss_function
a function used to assess variable importance; by default loss_brier_score for survival models. A custom function can be supplied, but it must have the named parameters (y_true, risk, surv, times), where y_true is the survival::Surv object with observed times and statuses, risk is the risk score calculated by the model, and surv is the survival function for each observation evaluated at times.
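A custom loss with the required named parameters could look like the sketch below. The function `loss_naive_squared_error` is hypothetical and deliberately simple: it ignores censoring, so it is not the Brier score used by default. It also assumes that a time-dependent loss returns one value per element of times, as the per-time-point rows in the printed results suggest.

```r
library(survival)

# Hypothetical loss: squared error between "still event-free after time t" and
# the predicted survival probability, averaged over observations for each t.
# Censoring is ignored, so this is NOT the default Brier score.
loss_naive_squared_error <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {
  observed_time <- as.numeric(y_true[, "time"])
  # alive[i, j] is 1 when observation i has an observed time greater than times[j]
  alive <- outer(observed_time, times, ">") * 1
  colMeans((alive - surv)^2)
}

# Made-up inputs: three observations, three evaluation times.
y <- Surv(time = c(2, 5, 8), event = c(1, 0, 1))
times <- c(1, 4, 9)
# One row per observation, one column per time point.
surv <- matrix(rep(c(0.9, 0.5, 0.1), each = 3), nrow = 3)

loss <- loss_naive_squared_error(y_true = y, surv = surv, times = times)
print(loss)
```

Under the assumption above, such a function would be supplied as `loss_function = loss_naive_squared_error`; the `risk` parameter is unused here but kept so that the signature has all four named parameters.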
type
a character string: if "raw", the results are the losses after permutation; if "ratio", the results are of the form loss/loss_full_model; and if "difference", the results are of the form loss - loss_full_model. Defaults to "difference".
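The three result forms differ only in arithmetic. The sketch below uses made-up loss values for a single variable at three time points (they are not output of any model) to show how each form is derived.

```r
# Made-up losses for one variable at three time points (not real output).
loss_full_model <- c(0.10, 0.20, 0.25)  # loss on the original data
loss            <- c(0.13, 0.30, 0.25)  # loss after permuting the variable

raw        <- loss                    # type = "raw"
ratio      <- loss / loss_full_model  # type = "ratio"
difference <- loss - loss_full_model  # type = "difference"

print(rbind(raw, ratio, difference))
```

With "difference", an unimportant variable sits near 0 and the full model itself is exactly 0, which is consistent with the `_full_model_` column in the printed examples on this page; with "ratio", the corresponding reference value is 1.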
output_type
either "survival" or "risk": the type of survival model output that should be used for explanations. If "survival", the explanations are based on the survival function. Otherwise the scalar risk predictions are passed to the DALEX::model_parts function.
N
number of observations to sample for the calculation of variable importance. If NULL, variable importance is calculated on the whole dataset.
An object of class c("model_parts_survival", "surv_feature_importance"). It is a list with the explanations in the result element.
Note: This function can be run within progressr::with_progress() to display a progress bar, as execution can take a long time, especially on large datasets.
# \donttest{
library(survival)
library(survex)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_ranger <- ranger::ranger(Surv(time, status) ~ .,
    data = veteran,
    respect.unordered.factors = TRUE,
    num.trees = 100,
    mtry = 3,
    max.depth = 5
)
cph_exp <- explain(cph)
#> Preparation of a new explainer is initiated
#> -> model label : coxph ( default )
#> -> data : 137 rows 6 cols ( extracted from the model )
#> -> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : predict.coxph with type = 'risk' will be used ( default )
#> -> predict survival function : predictSurvProb.coxph will be used ( default )
#> -> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
#> -> model_info : package survival , ver. 3.7.0 , task survival ( default )
#> A new explainer has been created!
rsf_ranger_exp <- explain(rsf_ranger,
    data = veteran[, -c(3, 4)],
    y = Surv(veteran$time, veteran$status)
)
#> Preparation of a new explainer is initiated
#> -> model label : ranger ( default )
#> -> data : 137 rows 6 cols
#> -> target variable : 137 values ( 128 events and 9 censored )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
#> -> predict survival function : stepfun based on predict.ranger()$survival will be used ( default )
#> -> predict cumulative hazard function : stepfun based on predict.ranger()$chf will be used ( default )
#> -> model_info : package ranger , ver. 0.16.0 , task survival ( default )
#> A new explainer has been created!
cph_model_parts_brier <- model_parts(cph_exp)
print(head(cph_model_parts_brier$result))
#> _times_ _full_model_ trt celltype karno diagtime
#> 1 1.5 0 3.414508e-05 -0.0004391802 0.0001930351 -3.052674e-08
#> 12 4.0 0 -9.083787e-05 -0.0002643615 0.0020820016 7.604845e-07
#> 23 7.0 0 4.752907e-04 0.0013203051 0.0052460287 4.063020e-07
#> 35 8.0 0 2.247208e-04 0.0031117431 0.0102433899 5.050466e-06
#> 48 10.0 0 -3.788362e-04 0.0019621813 0.0137033300 7.451684e-06
#> 57 12.0 0 -1.265335e-03 0.0017288758 0.0149247381 8.220809e-06
#> age prior _baseline_ _permutation_ label
#> 1 1.540466e-05 9.416094e-06 0.0001320184 0 coxph
#> 12 5.021007e-04 4.424765e-05 0.0021482684 0 coxph
#> 23 5.190239e-04 -3.183107e-05 0.0066581972 0 coxph
#> 35 5.216642e-04 1.909002e-04 0.0129284258 0 coxph
#> 48 5.817038e-04 2.948293e-04 0.0158831815 0 coxph
#> 57 5.084629e-04 5.295008e-04 0.0187495665 0 coxph
plot(cph_model_parts_brier)
rsf_ranger_model_parts <- model_parts(rsf_ranger_exp)
print(head(rsf_ranger_model_parts$result))
#> _times_ _full_model_ trt celltype karno diagtime
#> 1 1.5 0 0.0001627702 0.001154185 0.002194825 0.0007448171
#> 12 4.0 0 0.0010166134 0.002589573 0.009145478 0.0026006903
#> 23 7.0 0 0.0014333630 0.003611515 0.016965397 0.0040812810
#> 35 8.0 0 0.0019726986 0.006026738 0.025968757 0.0091819261
#> 48 10.0 0 0.0022475016 0.005699407 0.038798780 0.0110994975
#> 57 12.0 0 0.0025773798 0.007753140 0.046387337 0.0150343326
#> age prior _baseline_ _permutation_ label
#> 1 0.002751134 0.0002786377 0.005572104 0 ranger
#> 12 0.011073665 0.0006944964 0.018361833 0 ranger
#> 23 0.014504751 0.0003502657 0.025381473 0 ranger
#> 35 0.014323920 0.0009249015 0.039236130 0 ranger
#> 48 0.015248906 0.0012535640 0.051156514 0 ranger
#> 57 0.015342076 0.0023220928 0.060363730 0 ranger
plot(cph_model_parts_brier, rsf_ranger_model_parts)
# }