Convert your ranger model into a standardized representation.
The returned representation is easy to interpret and can be passed directly as an argument to the treeshap() function.
ranger_surv.unify(
rf_model,
data,
type = c("risk", "survival", "chf"),
times = NULL
)
rf_model
An object of ranger class. At the moment, models built on data with categorical features are not supported; please encode them before training.
data
Reference dataset. A data.frame or matrix with the same columns as the training set of the model. Usually the dataset used to train the model.
type
A character string defining the type of model prediction to use. Either "risk" (default), which uses the risk score calculated as the sum of cumulative hazard function values, "survival", which uses the survival probability at certain time points for each observation, or "chf", which uses the cumulative hazard values at certain time points for each observation.
times
A numeric vector of unique death times at which the prediction should be evaluated. By default, unique.death.times from the model are used.
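The three prediction types can be illustrated on a toy matrix of cumulative hazard values. This is only a sketch in base R: the numbers are made up, the variable names are hypothetical, and it mirrors the description above rather than the package internals.

```r
# Toy cumulative hazard values: 2 observations (rows) x 3 time points (columns).
# The numbers are invented for illustration only.
chf <- matrix(
  c(0.10, 0.30, 0.60,
    0.05, 0.20, 0.25),
  nrow = 2, byrow = TRUE,
  dimnames = list(c("obs1", "obs2"), c("23", "50", "73"))
)

# type = "risk": one score per observation, the sum of its CHF values
risk <- rowSums(chf)

# type = "chf": the cumulative hazard at each requested time point
chf_at_times <- chf

# type = "survival": survival probability at each time point,
# obtained as SF(t) = exp(-CHF(t))
surv_at_times <- exp(-chf)

print(risk)
print(round(surv_at_times, 3))
```

With type = "risk" there is a single output per observation, while "survival" and "chf" give one output per time point, which is why they lead to one unified model per time point.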
For type = "risk", a unified model representation is returned: a model_unified.object object. For type = "survival" or type = "chf", a model_unified_multioutput.object object is returned, which is a list containing a unified model representation (a model_unified.object object) for each time point. In this case, the list names are the time points at which the survival function or CHF was evaluated.
The survival forest implemented in the ranger package stores cumulative hazard functions (CHFs) in the leaves of survival trees, as proposed for Random Survival Forests (Ishwaran et al. 2008). The final model prediction is made by averaging these CHFs over all the trees. To provide explanations in the form of a survival function, the CHFs from the leaves are converted into survival functions (SFs) using the formula SF(t) = exp(-CHF(t)).
However, averaging these leaf-level SFs does not yield the correct model prediction: the model's survival prediction is obtained by first averaging the CHFs and then applying the same transformation, and the average of exp(-CHF(t)) is in general not equal to exp of the negative average CHF(t).
Therefore, explanations based on the survival function are only proxies and may not be fully consistent with the model predictions obtained using, for example, the predict function.
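The gap between the two ways of averaging can be checked with a few numbers. The sketch below uses made-up CHF values for a single observation at one time point in three hypothetical trees; it relies only on base R and is not taken from the package code.

```r
# Toy CHF values at one time point, taken from the leaves that a single
# observation reaches in three trees (invented numbers).
chf_per_tree <- c(0.1, 0.5, 1.5)

# Consistent with the model: average the CHFs first, then transform
sf_from_mean_chf <- exp(-mean(chf_per_tree))

# Proxy used for survival explanations: transform each leaf CHF, then average
mean_of_sf <- mean(exp(-chf_per_tree))

# exp(-x) is convex, so by Jensen's inequality the proxy is never smaller
gap <- mean_of_sf - sf_from_mean_chf
print(c(model = sf_from_mean_chf, proxy = mean_of_sf, gap = gap))
```

The gap vanishes only when all trees agree on the CHF value, so the discrepancy tends to grow with the spread of leaf values across trees.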
ranger.unify
for regression and classification ranger models
lightgbm.unify
for LightGBM models
gbm.unify
for GBM models
library(ranger)
library(treeshap)

# Colon cancer data: keep death records (etype == 2) and drop incomplete rows
data_colon <- data.table::data.table(survival::colon)
data_colon <- na.omit(data_colon[get("etype") == 2, ])
surv_cols <- c("status", "time", "rx")

# Numeric design matrix; the factor rx is expanded into indicator columns
feature_cols <- colnames(data_colon)[3:(ncol(data_colon) - 1)]
train_x <- model.matrix(
  ~ -1 + .,
  data_colon[, .SD, .SDcols = setdiff(feature_cols, surv_cols[1:2])]
)

# Right-censored survival outcome
train_y <- survival::Surv(
  event = (data_colon[, get("status")] |>
    as.character() |>
    as.integer()),
  time = data_colon[, get("time")],
  type = "right"
)

# Small survival forest, kept shallow so the example runs quickly
rf <- ranger::ranger(
  x = train_x,
  y = train_y,
  data = data_colon,
  max.depth = 10,
  num.trees = 10
)

# compute shaps for the risk score
unified_model_risk <- ranger_surv.unify(rf, train_x, type = "risk")
shaps <- treeshap(unified_model_risk, train_x[1:2, ])
# compute shaps for 3 selected time points
unified_model_surv <- ranger_surv.unify(rf, train_x, type = "survival", times = c(23, 50, 73))
shaps_surv <- treeshap(unified_model_surv, train_x[1:2,])