Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017) and Covert and Lee (2021), abbreviated by CL21. For up to \(p=8\) features, the resulting Kernel SHAP values are exact regarding the selected background data. For larger \(p\), an almost exact hybrid algorithm combining exact calculations and iterative sampling is used, see Details.
Note that (exact) Kernel SHAP is only an approximation of (exact) permutation SHAP. Thus, for up to eight features, we recommend permshap(). For more features, permshap() is slow compared to the optimized hybrid strategy of our Kernel SHAP implementation.
kernelshap(object, ...)

# Default S3 method
kernelshap(
  object,
  X,
  bg_X = NULL,
  pred_fun = stats::predict,
  feature_names = colnames(X),
  bg_w = NULL,
  bg_n = 200L,
  exact = length(feature_names) <= 8L,
  hybrid_degree = 1L + length(feature_names) %in% 4:16,
  paired_sampling = TRUE,
  m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
  tol = 0.005,
  max_iter = 100L,
  parallel = FALSE,
  parallel_args = NULL,
  verbose = TRUE,
  ...
)

# S3 method for class 'ranger'
kernelshap(
  object,
  X,
  bg_X = NULL,
  pred_fun = NULL,
  feature_names = colnames(X),
  bg_w = NULL,
  bg_n = 200L,
  exact = length(feature_names) <= 8L,
  hybrid_degree = 1L + length(feature_names) %in% 4:16,
  paired_sampling = TRUE,
  m = 2L * length(feature_names) * (1L + 3L * (hybrid_degree == 0L)),
  tol = 0.005,
  max_iter = 100L,
  parallel = FALSE,
  parallel_args = NULL,
  verbose = TRUE,
  survival = c("chf", "prob"),
  ...
)
object: Fitted model object.
...: Additional arguments passed to pred_fun(object, X, ...).
X: \((n \times p)\) matrix or data.frame with rows to be explained. The columns should only represent model features, not the response (but see feature_names on how to overrule this).
bg_X: Background data used to integrate out "switched off" features, often a subset of the training data (typically 50 to 500 rows). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If no bg_X is passed (the default) and if X is sufficiently large, a random sample of bg_n rows from X serves as background data.
pred_fun: Prediction function of the form function(object, X, ...), providing \(K \ge 1\) predictions per row. Its first argument represents the model object, its second argument a data structure like X. Additional (named) arguments are passed via the ... argument. The default, stats::predict(), will work in most cases.
feature_names: Optional vector of column names in X used to calculate SHAP values. By default, this equals colnames(X). Not supported if X is a matrix.
bg_w: Optional vector of case weights for each row of bg_X. If bg_X = NULL, must be of same length as X. Set to NULL for no weights.
bg_n: If bg_X = NULL: size of the background data to be sampled from X.
exact: If TRUE, the algorithm will produce exact Kernel SHAP values with respect to the background data. In this case, the arguments hybrid_degree, m, paired_sampling, tol, and max_iter are ignored. The default is TRUE up to eight features, and FALSE otherwise.
hybrid_degree: Integer controlling the exactness of the hybrid strategy. For \(4 \le p \le 16\), the default is 2, otherwise it is 1. Ignored if exact = TRUE.

0: Pure sampling strategy not involving any exact part. It is strictly worse than the hybrid strategy and should therefore only be used for studying properties of the Kernel SHAP algorithm.

1: Uses all \(2p\) on-off vectors \(z\) with \(\sum z \in \{1, p-1\}\) for the exact part, which covers at least 75% of the mass of the Kernel weight distribution. The remaining mass is covered by random sampling.

2: Uses all \(p(p+1)\) on-off vectors \(z\) with \(\sum z \in \{1, 2, p-2, p-1\}\). This covers at least 92% of the mass of the Kernel weight distribution. The remaining mass is covered by sampling. Convergence usually happens in the minimal possible number of iterations (two).

k>2: Uses all on-off vectors with \(\sum z \in \{1, \dots, k, p-k, \dots, p-1\}\).
paired_sampling: Logical flag indicating whether to do the sampling in a paired manner. This means that with every on-off vector \(z\), also \(1-z\) is considered. CL21 shows its superiority compared to standard sampling; therefore, the default (TRUE) should usually not be changed except for studying properties of Kernel SHAP algorithms. Ignored if exact = TRUE.
m: Even number of on-off vectors sampled during one iteration. The default is \(2p\), except when hybrid_degree == 0, in which case it is set to \(8p\). Ignored if exact = TRUE.
tol: Tolerance determining when to stop. Following CL21, the algorithm keeps iterating until \(\textrm{max}(\sigma_n)/(\textrm{max}(\beta_n) - \textrm{min}(\beta_n)) < \textrm{tol}\), where the \(\beta_n\) are the SHAP values of a given observation, and \(\sigma_n\) their standard errors. For multidimensional predictions, the criterion must be satisfied for each dimension separately. The stopping criterion uses the fact that standard errors and SHAP values are all on the same scale. Ignored if exact = TRUE.
max_iter: If the stopping criterion (see tol) is not reached after max_iter iterations, the algorithm stops. Ignored if exact = TRUE.
parallel: If TRUE, use parallel foreach::foreach() to loop over rows to be explained. The backend must be registered beforehand, e.g., via the 'doFuture' package; see the README for an example. Parallelization automatically disables the progress bar.
parallel_args: Named list of arguments passed to foreach::foreach(). Ideally, this is NULL (default). Only relevant if parallel = TRUE. Example on Windows: if object is a GAM fitted with package 'mgcv', then one might need to set parallel_args = list(.packages = "mgcv").
verbose: Set to FALSE to suppress messages and the progress bar.
survival: Should cumulative hazards ("chf", default) or survival probabilities ("prob") per time be predicted? Only used in ranger() survival models.
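The pred_fun argument described above can be illustrated with a small wrapper around a base-R lm() fit (a minimal sketch of our own; kernelshap() itself is not called here, and the wrapper name is ours):

```r
# Fit a small linear model on the built-in mtcars data
fit <- lm(mpg ~ cyl + disp + hp, data = mtcars)

# A custom prediction function: first argument is the model,
# second a data structure like X; extra arguments arrive via `...`
pred_fun <- function(object, X, ...) {
  as.numeric(predict(object, newdata = X, ...))
}

preds <- pred_fun(fit, mtcars[1:3, ])
length(preds)  # 3: one numeric prediction per row, so K = 1
```

A multiclass classifier would instead return a matrix with \(K > 1\) columns per row; the same function signature applies.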
An object of class "kernelshap" with the following components:

S: \((n \times p)\) matrix with SHAP values or, if the model output has dimension \(K > 1\), a list of \(K\) such matrices.

X: Same as input argument X.

baseline: Vector of length K representing the average prediction on the background data.

bg_X: The background data.

bg_w: The background case weights.

SE: Standard errors corresponding to S (and organized like S).

n_iter: Integer vector of length n providing the number of iterations per row of X.

converged: Logical vector of length n indicating convergence per row of X.

m: Integer providing the effective number of sampled on-off vectors used per iteration.

m_exact: Integer providing the effective number of exact on-off vectors used per iteration.

prop_exact: Proportion of the Kernel SHAP weight distribution covered by exact calculations.

exact: Logical flag indicating whether calculations are exact or not.

txt: Summary text.

predictions: \((n \times K)\) matrix with predictions of X.

algorithm: "kernelshap".
The pure iterative Kernel SHAP sampling as in Covert and Lee (2021) works like this:

1. A binary "on-off" vector \(z\) is drawn from \(\{0, 1\}^p\) such that its sum follows the SHAP Kernel weight distribution (normalized to the range \(\{1, \dots, p-1\}\)).
2. For each \(j\) with \(z_j = 1\), the \(j\)-th column of the original background data is replaced by the corresponding feature value \(x_j\) of the observation to be explained.
3. The average prediction \(v_z\) on the data of Step 2 is calculated, and the average prediction \(v_0\) on the background data is subtracted.
4. Steps 1 to 3 are repeated \(m\) times. This produces a binary \(m \times p\) matrix \(Z\) (each row equals one of the \(z\)) and a vector \(v\) of shifted predictions.
5. \(v\) is regressed onto \(Z\) under the constraint that the sum of the coefficients equals \(v_1 - v_0\), where \(v_1\) is the prediction of the observation to be explained. The resulting coefficients are the Kernel SHAP values.

This is repeated multiple times until convergence, see CL21 for details.
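For a single on-off vector \(z\), Steps 2 and 3 can be sketched as follows (a toy illustration with a base-R lm() model; all variable names are our own):

```r
fit <- lm(mpg ~ cyl + disp + hp, data = mtcars)
x   <- mtcars[1, c("cyl", "disp", "hp")]   # observation to be explained
bg  <- mtcars[-1, c("cyl", "disp", "hp")]  # background data

# One on-off vector z with sum(z) in {1, ..., p - 1}
z <- c(cyl = 1, disp = 0, hp = 1)

# Step 2: where z_j = 1, overwrite the background column with x_j
masked <- bg
for (j in names(z)[z == 1]) {
  masked[[j]] <- x[[j]]
}

# Step 3: average prediction on the masked data, minus the
# average prediction v_0 on the untouched background data
v0  <- mean(predict(fit, newdata = bg))
v_z <- mean(predict(fit, newdata = masked)) - v0
```

Repeating this for \(m\) sampled vectors yields the matrix \(Z\) and vector \(v\) of Step 4.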
A drawback of this strategy is that many (at least 75%) of the \(z\) vectors will have \(\sum z \in \{1, p-1\}\), producing many duplicates. Similarly, at least 92% of the mass will be used for the \(p(p+1)\) possible vectors with \(\sum z \in \{1, 2, p-2, p-1\}\). This inefficiency can be fixed by a hybrid strategy, combining exact calculations with sampling.
The hybrid algorithm has two steps:
Step 1 (exact part): There are \(2p\) different on-off vectors \(z\) with \(\sum z \in \{1, p-1\}\), covering a large proportion of the Kernel SHAP distribution. The degree 1 hybrid will list those vectors and use them according to their weights in the upcoming calculations. Depending on \(p\), we can also go a step further to a degree 2 hybrid by adding all \(p(p-1)\) vectors with \(\sum z \in \{2, p-2\}\) to the process etc. The necessary predictions are obtained along with other calculations similar to those described in CL21.
Step 2 (sampling part): The remaining weight is filled by sampling vectors \(z\) according to Kernel SHAP weights renormalized to the values not yet covered by Step 1. Together with the correctly weighted results from Step 1, this forms a complete iteration as in CL21. The difference is that most mass is covered by exact calculations. Afterwards, the algorithm iterates until convergence. The output of Step 1 is reused in every iteration, leading to an extremely efficient strategy.
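The exact part of a degree 1 hybrid can be sketched by listing its \(2p\) on-off vectors explicitly (illustrative only; the package's internal representation may differ, and kernel_weight() is our own name for the SHAP Kernel weight of Lundberg and Lee, 2017):

```r
p <- 6

# All 2p vectors z with sum(z) in {1, p - 1}:
# the rows of the identity matrix and their complements
Z <- rbind(diag(p), 1 - diag(p))
nrow(Z)            # 2 * p = 12
table(rowSums(Z))  # p vectors of sum 1, p vectors of sum p - 1

# SHAP Kernel weight of a vector with s ones (up to normalization)
kernel_weight <- function(p, s) (p - 1) / (choose(p, s) * s * (p - s))
kernel_weight(p, 1) == kernel_weight(p, p - 1)  # symmetric in s and p - s
```

A degree 2 hybrid would additionally enumerate the \(p(p-1)\) vectors with \(\sum z \in \{2, p-2\}\).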
If \(p\) is sufficiently small, all possible \(2^p-2\) on-off vectors \(z\) can be
evaluated. In this case, no sampling is required and the algorithm returns exact
Kernel SHAP values with respect to the given background data.
Since kernelshap()
calculates predictions on data with \(MN\) rows
(\(N\) is the background data size and \(M\) the number of \(z\) vectors), \(p\)
should not be much higher than 10 for exact calculations.
For similar reasons, degree 2 hybrids should not use \(p\) much larger than 40.
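As a rough budget check (our own arithmetic, not output of the package): exact calculations evaluate all \(M = 2^p - 2\) on-off vectors, each against \(N\) background rows.

```r
p <- 10   # number of features
N <- 200  # background data size

M <- 2^p - 2  # number of on-off vectors in the exact case
M * N         # rows scored by the prediction function: 204400
```

This is why exact mode is restricted to small \(p\) by default; the hybrid strategy keeps \(M\) fixed per iteration instead of exponential in \(p\).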
kernelshap(default): Default Kernel SHAP method.

kernelshap(ranger): Kernel SHAP method for "ranger" models, see README for an example.
Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model predictions. Proceedings of the 31st International Conference on Neural Information Processing Systems, 2017.
Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.
# MODEL ONE: Linear regression
fit <- lm(Sepal.Length ~ ., data = iris)
# Select rows to explain (only feature columns)
X_explain <- iris[-1]
# Calculate SHAP values
s <- kernelshap(fit, X_explain)
#> Exact Kernel SHAP values
s
#> SHAP values of first observations:
#> Sepal.Width Petal.Length Petal.Width Species
#> [1,] 0.21951350 -1.955357 0.3149451 0.5823533
#> [2,] -0.02843097 -1.955357 0.3149451 0.5823533
# MODEL TWO: Multi-response linear regression
fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
s <- kernelshap(fit, iris[3:5])
#> Exact Kernel SHAP values
s
#> SHAP values of first observations:
#> $Sepal.Length
#> Petal.Length Petal.Width Species
#> [1,] -2.13622 0.005991405 1.237003
#> [2,] -2.13622 0.005991405 1.237003
#>
#> $Sepal.Width
#> Petal.Length Petal.Width Species
#> [1,] -0.3647252 -0.62303 1.320153
#> [2,] -0.3647252 -0.62303 1.320153
#>
# Note 1: Feature columns can also be selected via 'feature_names'
# Note 2: Especially when X is small, pass a sufficiently large background data 'bg_X'
s <- kernelshap(
fit,
iris[1:4, ],
bg_X = iris,
feature_names = c("Petal.Length", "Petal.Width", "Species")
)
#> Exact Kernel SHAP values
s
#> SHAP values of first observations:
#> $Sepal.Length
#> Petal.Length Petal.Width Species
#> [1,] -2.13622 0.005991405 1.237003
#> [2,] -2.13622 0.005991405 1.237003
#>
#> $Sepal.Width
#> Petal.Length Petal.Width Species
#> [1,] -0.3647252 -0.62303 1.320153
#> [2,] -0.3647252 -0.62303 1.320153
#>