Convert your ranger model into a standardized representation. The returned representation is easy for the user to interpret and ready to be passed as an argument to the treeshap() function.

ranger_surv.unify(
  rf_model,
  data,
  type = c("risk", "survival", "chf"),
  times = NULL
)

Arguments

rf_model

An object of class ranger (a survival forest). At the moment, models built on data with categorical features are not supported; please encode them (for example, with one-hot encoding) before training.

data

Reference dataset. A data.frame or matrix with the same columns as the training set of the model. Usually, this is the dataset used to train the model.

type

A character string defining the type of model prediction to use: "risk" (default), the risk score calculated as the sum of cumulative hazard function values; "survival", the survival probability at selected time points for each observation; or "chf", the cumulative hazard value at selected time points for each observation.

times

A numeric vector of unique death times at which the prediction should be evaluated. By default, the unique.death.times stored in the model are used.
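The arguments above can be illustrated with a minimal sketch. It assumes the survival::veteran data, which is not used elsewhere on this page, and relies only on ranger's own predict() output to show the quantities each type refers to. The risk score is computed here with rowSums() over the CHF, following the description of type = "risk" above; this illustrates the definition and is not treeshap's internal code.

```r
library(ranger)
library(survival)

# veteran contains a factor column (celltype). Categorical features are not
# supported, so one-hot encode them with model.matrix() before training.
vet <- survival::veteran
train_x <- model.matrix(~ -1 + trt + celltype + karno + diagtime + age + prior, data = vet)
train_y <- survival::Surv(time = vet$time, event = vet$status)

rf <- ranger::ranger(x = train_x, y = train_y, num.trees = 10, seed = 1)

# Default evaluation times used when times = NULL
head(rf$unique.death.times)

pred <- predict(rf, data = train_x[1:2, ])

# "chf": cumulative hazard of each observation (row) at each unique death time (column)
chf <- pred$chf
# "survival": survival probability, i.e. exp(-CHF) of the forest-averaged CHF
surv <- pred$survival
# "risk": one score per observation, the sum of its CHF values over all time points
risk <- rowSums(chf)
risk
```

Passing the encoded train_x to ranger_surv.unify() as data keeps the reference dataset's columns identical to the training columns.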

Value

For type = "risk", a unified model representation is returned: a model_unified.object object. For type = "survival" or type = "chf", a model_unified_multioutput.object object is returned: a list containing a unified model representation (a model_unified.object object) for each time point. In this case, the list names are the time points at which the prediction was evaluated.
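A minimal sketch of the two return shapes follows. It assumes the treeshap, ranger and survival packages are installed and uses only the numeric columns of survival::veteran (an assumption; any dataset without categorical features would do). The time points 30, 90 and 180 are arbitrary values chosen for illustration.

```r
library(ranger)
library(survival)
library(treeshap)

# Numeric features only, so no encoding step is needed
vet <- survival::veteran
train_x <- as.matrix(vet[, c("trt", "karno", "diagtime", "age", "prior")])
train_y <- survival::Surv(time = vet$time, event = vet$status)
rf <- ranger::ranger(x = train_x, y = train_y, num.trees = 10, max.depth = 5, seed = 1)

# type = "risk": a single unified model
risk_model <- ranger_surv.unify(rf, train_x, type = "risk")

# type = "survival": a list holding one unified model per requested time point
surv_model <- ranger_surv.unify(rf, train_x, type = "survival", times = c(30, 90, 180))
length(surv_model)
names(surv_model)

# Each element is a regular unified model that can be inspected on its own
model_at_first_time <- surv_model[[1]]
```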

Details

The survival forest implemented in the ranger package stores cumulative hazard functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs over all trees. To provide explanations in the form of a survival function, the CHFs from the leaves are converted into survival functions (SFs) using the formula SF(t) = exp(-CHF(t)). However, averaging these SFs does not yield the model's survival prediction, which is obtained by applying the same transformation to the averaged CHF; because exp(-x) is nonlinear, the average of exp(-CHF) generally differs from exp(-average CHF). Therefore, explanations based on the survival function are only proxies and may not be fully consistent with the model predictions obtained using, for example, the predict() function.
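The inconsistency can be seen in a small worked example using hypothetical CHF values for one observation in a two-tree forest (the numbers are made up for illustration). Because exp(-x) is convex, Jensen's inequality guarantees that the average of the per-tree SFs is never lower than exp(-average CHF). The risk score, being a plain sum of CHF values, is linear, so the order of averaging and summing does not matter for it.

```r
# Hypothetical CHF values of one observation: rows = trees, columns = time points
chf_trees <- rbind(
  tree1 = c(0.1, 0.4, 0.9),
  tree2 = c(0.3, 0.8, 1.5)
)

# Model prediction: average the CHFs over trees first, then transform
sf_model <- exp(-colMeans(chf_trees))
# Proxy used for survival explanations: transform each tree's CHF, then average
sf_proxy <- colMeans(exp(-chf_trees))

round(sf_model, 4)  # 0.8187 0.5488 0.3012
round(sf_proxy, 4)  # 0.8228 0.5598 0.3148

# Risk score (sum of CHF values): averaging first or summing first gives the same value
risk_avg_first <- sum(colMeans(chf_trees))
risk_sum_first <- mean(rowSums(chf_trees))
```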

Examples


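library(ranger)
library(treeshap)

# Keep only the death records (etype == 2) and drop rows with missing values
data_colon <- data.table::data.table(survival::colon)
data_colon <- na.omit(data_colon[get("etype") == 2, ])
surv_cols <- c("status", "time", "rx")

# Columns from "rx" to "time" (everything except "id", "study" and "etype")
feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]

# One-hot encode the features (e.g. the factor "rx"), excluding "status" and "time"
train_x <- model.matrix(
  ~ -1 + .,
  data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)
# Right-censored survival outcome built from the event indicator and follow-up time
train_y <- survival::Surv(
  event = (data_colon[, get("status")] |>
             as.character() |>
             as.integer()),
  time = data_colon[, get("time")],
  type = "right"
)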

rf <- ranger::ranger(
  x = train_x,
  y = train_y,
  data = data_colon,
  max.depth = 10,
  num.trees = 10
)
# Explain the risk score for the first two observations
unified_model_risk <- ranger_surv.unify(rf, train_x, type = "risk")
shaps <- treeshap(unified_model_risk, train_x[1:2,])

# Compute SHAP values of the survival function at 3 selected time points
unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73))
shaps_surv <- treeshap(unified_model_surv, train_x[1:2,])