Source: R/predict_profile.R (documentation: predict_profile.surv_explainer.Rd)
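This function calculates ceteris paribus profiles for a single observation, that is, how the model's prediction for that observation changes when one variable is varied while all other variables are held fixed. The profiles can optionally take the time dimension into account.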
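The idea behind a time-dependent ceteris paribus profile can be sketched in base R. This is a minimal illustration, not the survex implementation: it assumes a toy Cox-type model with hypothetical coefficients and a hypothetical baseline cumulative hazard H0(t) = 0.01 * t, and the helpers toy_survival() and ceteris_paribus_survival() are hypothetical names. All features of a made-up observation are fixed except karno, which is varied over a grid, and the survival function is evaluated at several time points.

```r
# Toy Cox-type survival model (hypothetical coefficients and baseline hazard):
# S(t | x) = exp(-H0(t) * exp(beta' x)) with H0(t) = 0.01 * t
beta <- c(karno = -0.03, age = 0.01)
toy_survival <- function(newdata, times) {
  lp <- as.matrix(newdata[, names(beta)]) %*% beta  # linear predictor, one per row
  H0 <- 0.01 * times                                 # baseline cumulative hazard
  exp(-outer(as.vector(exp(lp)), H0))                # rows = observations, cols = times
}

# Ceteris paribus: replicate one observation and vary a single variable over a grid
ceteris_paribus_survival <- function(obs, variable, grid, times, predict_fun) {
  profile_data <- obs[rep(1, length(grid)), , drop = FALSE]
  profile_data[[variable]] <- grid
  surv <- predict_fun(profile_data, times)
  dimnames(surv) <- list(paste0(variable, "=", grid), paste0("t=", times))
  surv
}

obs <- data.frame(karno = 60, age = 69)  # made-up observation
cp <- ceteris_paribus_survival(obs, "karno", grid = c(20, 60, 100),
                               times = c(10, 50, 100), predict_fun = toy_survival)
round(cp, 3)
```

Each row corresponds to one grid value and each column to one time point, so the profile is a surface over (karno, time) rather than a single curve; this is what the time dimension adds compared with classical ceteris paribus profiles.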
predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)

# S3 method for surv_explainer
predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)
explainer: an explainer object, i.e. a model preprocessed by the explain() function.
new_observation: a new observation for which the prediction needs to be explained.
variables: a character vector containing the names of the variables to be explained.
categorical_variables: a character vector of names of additional variables that should be treated as categorical (factors are automatically treated as categorical). If it contains variable names not present in the variables argument, they are added at the end.
...: additional parameters passed to DALEX::predict_profile if output_type == "risk".
type: character; only "ceteris_paribus" is implemented.
output_type: either "survival", "chf" or "risk"; the type of survival model output that should be considered for explanations. If "survival", the explanations are based on the survival function; if "chf", they are based on the cumulative hazard function; otherwise, the scalar risk predictions are used by the DALEX::predict_profile function.
variable_splits_type: character; decides how the variable grids are calculated. Use "quantiles" for percentiles or "uniform" (default) for a uniform grid of points.
center: logical (default FALSE); should the profiles be centered around the average prediction?
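The two grid strategies accepted by variable_splits_type can be illustrated with a small base-R helper. This is a simplified sketch, not survex's internal code: the helper make_variable_grid() and its n_points argument are hypothetical, and the real implementation may, for example, use a different number of grid points or also include the observed value of the variable.

```r
# Hypothetical helper illustrating the two grid strategies for a numeric variable
make_variable_grid <- function(x, type = c("uniform", "quantiles"), n_points = 5) {
  type <- match.arg(type)
  if (type == "uniform") {
    # evenly spaced points between the observed minimum and maximum
    grid <- seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n_points)
  } else {
    # empirical quantiles, so the grid is denser where the data are denser
    grid <- quantile(x, probs = seq(0, 1, length.out = n_points),
                     na.rm = TRUE, names = FALSE)
  }
  unique(grid)
}

# A skewed toy variable: most values are small, a few are large
x <- c(1, 2, 2, 3, 3, 3, 4, 5, 50, 100)
make_variable_grid(x, "uniform")    # 1.00 25.75 50.50 75.25 100.00
make_variable_grid(x, "quantiles")  # 1.00 2.25 3.00 4.75 100.00
```

For a skewed variable, the uniform grid spends most of its points in the sparsely populated upper range, while the quantile grid concentrates them where most observations lie.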
Returns an object of class c("predict_profile_survival", "surv_ceteris_paribus"): a list whose result element contains the final result.
# \donttest{
library(survival)
library(survex)
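# Fit a Cox proportional hazards model and a random survival forest to the veteran data
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
# Wrap both models in survex explainers; data, times and prediction functions are extracted automatically
cph_exp <- explain(cph)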
#> Preparation of a new explainer is initiated
#> -> model label : coxph ( default )
#> -> data : 137 rows 6 cols ( extracted from the model )
#> -> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : predict.coxph with type = 'risk' will be used ( default )
#> -> predict survival function : predictSurvProb.coxph will be used ( default )
#> -> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
#> -> model_info : package survival , ver. 3.7.0 , task survival ( default )
#> A new explainer has been created!
rsf_src_exp <- explain(rsf_src)
#> Preparation of a new explainer is initiated
#> -> model label : rfsrc ( default )
#> -> data : 137 rows 6 cols ( extracted from the model )
#> -> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
#> -> predict survival function : stepfun based on predict.rfsrc()$survival will be used ( default )
#> -> predict cumulative hazard function : stepfun based on predict.rfsrc()$chf will be used ( default )
#> -> model_info : package randomForestSRC , ver. 3.2.3 , task survival ( default )
#> A new explainer has been created!
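# Profiles for the second patient; columns 3 and 4 (time, status) are dropped.
# trt is coded numerically (1/2), so it is declared categorical explicitly.
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
  variables = c("trt", "celltype", "karno", "age"),
  categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)
# Profile of karno only for the fifth patient, explained with the random survival forest
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")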
# }
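The three output_type options are linked. The explainer printouts in the example above report that, by default, the cumulative hazard function is obtained as -log of the survival function (coxph explainer), and that the scalar risk prediction is a sum over the cumulative hazard function (rfsrc explainer). A minimal base-R sketch of these two conversions, assuming a hypothetical survival matrix with one row per observation and one column per evaluated time point:

```r
# Hypothetical survival probabilities: rows = observations, columns = time points
surv <- rbind(c(0.90, 0.70, 0.40),
              c(0.95, 0.85, 0.60))

# Cumulative hazard function from the survival function: H(t) = -log S(t)
chf <- -log(surv)

# Scalar risk score: sum of the cumulative hazard over the evaluated time points
risk <- rowSums(chf)

risk  # the first observation, with lower survival, gets the larger risk
```

Because the risk score collapses the time dimension into one number per observation, output_type = "risk" yields classical (time-independent) ceteris paribus profiles.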