Black-box models have vastly different structures. explain_survival() returns an explainer object that can be further processed to create prediction explanations and their visualizations. Use this function to manually create explainers for models not covered by the survex package. For selected models, the required information can be extracted automatically: call the explain() function on survival models from the mlr3proba, censored, randomForestSRC, ranger and survival packages, or on any other model with a pec::predictSurvProb() method.
explain_survival(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL,
  times = NULL,
  times_generation = "survival_quantiles",
  predict_survival_function = NULL,
  predict_cumulative_hazard_function = NULL
)

explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)

# S3 method for default
explain(
  model,
  data = NULL,
  y = NULL,
  predict_function = NULL,
  predict_function_target_column = NULL,
  residual_function = NULL,
  weights = NULL,
  ...,
  label = NULL,
  verbose = TRUE,
  colorize = !isTRUE(getOption("knitr.in.progress")),
  model_info = NULL,
  type = NULL
)
model
    object - a survival model to be explained.
data
    data.frame - data used to calculate the explanations. If not provided, it is extracted from the model when possible. It should not contain the target columns. NOTE: if the target variable is present in data, some functionality breaks.
y
    survival::Surv object containing the event/censoring times and statuses corresponding to data.
predict_function
    function taking 2 arguments, model and newdata, and returning a single number for each observation: the risk score. Observations with a higher score are more likely to experience the event sooner.
predict_function_target_column
    unused, left for compatibility with DALEX.
residual_function
    unused, left for compatibility with DALEX.
weights
    unused, left for compatibility with DALEX.
...
    additional arguments, passed to DALEX::explain().
label
    character - the name of the model. Used to tell models apart in visualizations with multiple explainers. By default it is extracted from the 'class' attribute of the model if possible.
verbose
    logical, if TRUE (default) then diagnostic messages will be printed.
colorize
    logical, if TRUE (default) then WARNINGS, ERRORS and NOTES are colorized. Works only in the R console. By default it is FALSE while knitting and TRUE otherwise.
model_info
    a named list (package, version, type) containing information about the model. If NULL, survex will look for this information on its own.
type
    type of a model, by default "survival".
times
    numeric, a vector of times at which the survival function and cumulative hazard function should be evaluated for calculations.
times_generation
    either "survival_quantiles", "uniform" or "quantiles". Sets how the vector of times is generated from the times provided in the y parameter. If "survival_quantiles", the vector contains the unique time points out of 50 uniformly distributed survival quantiles based on the Kaplan-Meier estimator, plus an additional time point equal to the median survival time (if possible); if "uniform", the vector contains 50 equally spaced time points between the minimum and maximum observed times; if "quantiles", the vector contains the unique time points out of 50 time points between the 0th and 98th percentiles of observed times. Ignored if times is not NULL.
predict_survival_function
    function taking 3 arguments, model, newdata and times, and returning a matrix in which each row is the survival function evaluated at times for one observation from newdata.
predict_cumulative_hazard_function
    function taking 3 arguments, model, newdata and times, and returning a matrix in which each row is the cumulative hazard function evaluated at times for one observation from newdata.
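The three times_generation schemes can be approximated in a few lines of R. The block below is a rough sketch of the idea, not survex's internal code: the toy y values and all variable names are made up for illustration, and the "survival_quantiles" branch is only one plausible reading of the description above (it leaves out the additional median survival time point).

```r
library(survival)

# Toy right-censored sample (hypothetical values, for illustration only)
y <- Surv(
  c(5, 8, 12, 20, 33, 47, 60, 75, 90, 120), # observed times
  c(1, 1, 0, 1, 1, 0, 1, 1, 1, 0)           # 1 = event, 0 = censored
)
observed <- y[, 1] # first column of a Surv object holds the times

# "uniform": 50 equally spaced points between min and max observed time
times_uniform <- seq(min(observed), max(observed), length.out = 50)

# "quantiles": unique values among the 0%, 2%, ..., 98% quantiles
probs <- seq(0, 98, by = 2) / 100
times_quantiles <- unique(unname(quantile(observed, probs = probs)))

# "survival_quantiles" (assumed reading): for 50 evenly spaced survival
# levels, take the earliest time at which the Kaplan-Meier estimate
# drops to or below that level, then keep the unique time points
km <- survfit(y ~ 1)
surv_levels <- seq(0.99, 0.01, length.out = 50)
idx <- vapply(surv_levels, function(p) {
  hit <- which(km$surv <= p)
  if (length(hit) == 0) NA_integer_ else hit[1]
}, integer(1))
times_survq <- sort(unique(km$time[idx[!is.na(idx)]]))
```

With heavy censoring the Kaplan-Meier curve may never reach the lowest levels, which is why this sketch drops levels with no matching time instead of failing.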
The returned explainer is a list containing the following elements:
model - the explained model.
data - the dataset used for training.
y - response for observations from data.
residuals - calculated residuals.
predict_function - function that may be used for model predictions; it shall return a single numerical value for each observation.
residual_function - function that returns residuals; it shall return a single numerical value for each observation.
class - class/classes of a model.
label - label of the explainer.
model_info - named list containing basic information about the model, such as the package, the package version and the type.
times - a vector of times used by default for evaluating the survival function and the cumulative hazard function.
predict_survival_function - function used for model predictions in the form of a survival function.
predict_cumulative_hazard_function - function used for model predictions in the form of a cumulative hazard function.
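As a brief illustration of how these elements might be used, the sketch below builds an explainer for a Cox model and reads a few fields from it. It assumes that the explainer behaves as an ordinary list, as described above, and that the stored predict_survival_function follows the three-argument contract (model, newdata, times). The names exp_obj and surv_mat are arbitrary.

```r
library(survival)
library(survex)

# Fit a Cox model on the veteran data shipped with the survival package
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE)

# verbose = FALSE suppresses the diagnostic messages
exp_obj <- explain(cph, verbose = FALSE)

# Plain list access to the documented elements
exp_obj$label
head(exp_obj$times)

# Evaluate the stored survival prediction function for the first three
# observations on the explainer's default time grid
surv_mat <- exp_obj$predict_survival_function(
  exp_obj$model,
  exp_obj$data[1:3, ],
  exp_obj$times
)
dim(surv_mat) # rows: observations, columns: time points
```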
# \donttest{
library(survival)
library(survex)
cph <- survival::coxph(survival::Surv(time, status) ~ .,
  data = veteran,
  model = TRUE, x = TRUE
)
cph_exp <- explain(cph)
#> Preparation of a new explainer is initiated
#> -> model label : coxph ( default )
#> -> data : 137 rows 6 cols ( extracted from the model )
#> -> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : predict.coxph with type = 'risk' will be used ( default )
#> -> predict survival function : predictSurvProb.coxph will be used ( default )
#> -> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
#> -> model_info : package survival , ver. 3.7.0 , task survival ( default )
#> A new explainer has been created!
rsf_ranger <- ranger::ranger(survival::Surv(time, status) ~ .,
  data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)
# Pass the predictors and the target explicitly;
# columns 3 and 4 of veteran are time and status
rsf_ranger_exp <- explain(rsf_ranger,
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)
#> Preparation of a new explainer is initiated
#> -> model label : ranger ( default )
#> -> data : 137 rows 6 cols
#> -> target variable : 137 values ( 128 events and 9 censored )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
#> -> predict survival function : stepfun based on predict.ranger()$survival will be used ( default )
#> -> predict cumulative hazard function : stepfun based on predict.ranger()$chf will be used ( default )
#> -> model_info : package ranger , ver. 0.16.0 , task survival ( default )
#> A new explainer has been created!
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)
rsf_src_exp <- explain(rsf_src)
#> Preparation of a new explainer is initiated
#> -> model label : rfsrc ( default )
#> -> data : 137 rows 6 cols ( extracted from the model )
#> -> target variable : 137 values ( 128 events and 9 censored , censoring rate = 0.066 ) ( extracted from the model )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : sum over the predict_cumulative_hazard_function will be used ( default )
#> -> predict survival function : stepfun based on predict.rfsrc()$survival will be used ( default )
#> -> predict cumulative hazard function : stepfun based on predict.rfsrc()$chf will be used ( default )
#> -> model_info : package randomForestSRC , ver. 3.2.3 , task survival ( default )
#> A new explainer has been created!
library(censored, quietly = TRUE)
bt <- parsnip::boost_tree() %>%
  parsnip::set_engine("mboost") %>%
  parsnip::set_mode("censored regression") %>%
  generics::fit(survival::Surv(time, status) ~ ., data = veteran)
bt_exp <- explain(bt, data = veteran[, -c(3, 4)], y = Surv(veteran$time, veteran$status))
#> Preparation of a new explainer is initiated
#> -> model label : model_fit_blackboost ( default )
#> -> data : 137 rows 6 cols
#> -> target variable : 137 values ( 128 events and 9 censored )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : predict.model_fit with type = 'linear_pred' will be used ( default )
#> -> predict survival function : predict.model_fit with type = 'survival' will be used ( default )
#> -> predict cumulative hazard function : -log(predict_survival_function) will be used ( default )
#> -> model_info : package parsnip , ver. 1.2.1 , task survival ( default )
#> A new explainer has been created!
###### explain_survival() ######
cph <- coxph(Surv(time, status) ~ ., data = veteran)
# Predictors without the target columns (3 = time, 4 = status)
veteran_data <- veteran[, -c(3, 4)]
veteran_y <- Surv(veteran$time, veteran$status)
# Risk score: a single number per observation
risk_pred <- function(model, newdata) predict(model, newdata, type = "risk")
# Survival function: one row per observation, one column per time point
surv_pred <- function(model, newdata, times) pec::predictSurvProb(model, newdata, times)
# Cumulative hazard derived from survival via H(t) = -log(S(t))
chf_pred <- function(model, newdata, times) -log(surv_pred(model, newdata, times))
manual_cph_explainer <- explain_survival(
  model = cph,
  data = veteran_data,
  y = veteran_y,
  predict_function = risk_pred,
  predict_survival_function = surv_pred,
  predict_cumulative_hazard_function = chf_pred,
  label = "manual coxph"
)
#> Preparation of a new explainer is initiated
#> -> model label : manual coxph
#> -> data : 137 rows 6 cols
#> -> target variable : 137 values ( 128 events and 9 censored )
#> -> times : 50 unique time points , min = 1.5 , median survival time = 80 , max = 999
#> -> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
#> -> predict function : risk_pred
#> -> predict survival function : surv_pred
#> -> predict cumulative hazard function : chf_pred
#> -> model_info : package survival , ver. 3.7.0 , task survival ( default )
#> A new explainer has been created!
# }