This is the main function of the package. It does the expensive calculations behind the following H-statistics:

  • Total interaction strength \(H^2\), a statistic measuring the proportion of prediction variability unexplained by the main effects of the features in v, see h2() for details.

  • Friedman and Popescu's statistic \(H^2_j\) of overall interaction strength per feature, see h2_overall() for details.

  • Friedman and Popescu's statistic \(H^2_{jk}\) of pairwise interaction strength, see h2_pairwise() for details.

  • Friedman and Popescu's statistic \(H^2_{jkl}\) of three-way interaction strength, see h2_threeway() for details. To save time, this statistic is not calculated by default. Set threeway_m to a value above 2 to get three-way statistics of the threeway_m variables with strongest overall interaction.
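All of these statistics are built from centered partial dependence (PD) functions evaluated at the rows of X. The following brute-force sketch illustrates the idea for Friedman and Popescu's pairwise \(H^2_{jk}\): the squared difference between the bivariate PD and the sum of the two univariate PDs, divided by the variability of the bivariate PD. The helpers pd_at_rows() and h2_jk_sketch() are hypothetical illustrations, not functions of this package, and this is not how hstats implements the calculation.

```r
# Hypothetical helper (not part of hstats): partial dependence of `pred`
# on the columns `vars`, evaluated at every row of X by brute force.
pd_at_rows <- function(pred, X, vars) {
  vapply(seq_len(nrow(X)), function(i) {
    Z <- X
    Z[, vars] <- X[rep(i, nrow(X)), vars]  # fix `vars` at the values of row i
    mean(pred(Z))                          # average over the background data
  }, numeric(1))
}

# Hypothetical helper: Friedman and Popescu's H^2_jk from centered PD values
h2_jk_sketch <- function(pred, X, j, k) {
  center <- function(z) z - mean(z)
  F_j  <- center(pd_at_rows(pred, X, j))
  F_k  <- center(pd_at_rows(pred, X, k))
  F_jk <- center(pd_at_rows(pred, X, c(j, k)))
  sum((F_jk - F_j - F_k)^2) / sum(F_jk^2)
}

X <- expand.grid(a = 1:5, b = 1:5)
h2_jk_sketch(function(Z) Z$a + Z$b, X, "a", "b")  # additive model: 0
h2_jk_sketch(function(Z) Z$a * Z$b, X, "a", "b")  # product: 0.1 on this grid
```

An additive prediction function has no pairwise interaction, so the numerator vanishes; the product term leaves a non-zero share of the bivariate PD variability unexplained.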

Furthermore, it can calculate an experimental partial dependence based measure of feature importance, \(\textrm{PDI}_j^2\). It equals the proportion of prediction variability unexplained by the other features, see pd_importance() for details. This statistic is not shown by summary() or plot().
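Assuming the hstats package is installed, the experimental importance measure can be extracted from a fitted "hstats" object with pd_importance(). The model below mirrors MODEL 1 from the Examples.

```r
library(hstats)

fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
s <- hstats(fit, X = iris[, -1], verbose = FALSE)

# Experimental PD-based importance; not shown by summary() or plot()
imp <- pd_importance(s)
imp
```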

Instead of using summary(), interaction statistics can also be obtained via the more flexible functions h2(), h2_overall(), h2_pairwise(), and h2_threeway().
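A short sketch of these functions on a fitted object, assuming hstats is installed. The arguments normalize, squared, and zero are the ones used with h2_pairwise() in the Examples.

```r
library(hstats)

fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
s <- hstats(fit, X = iris[, -1], verbose = FALSE)

h2(s)                         # total interaction strength H^2
h2_overall(s)                 # overall H^2_j per feature
h2_pairwise(s, zero = FALSE)  # pairwise H^2_jk, dropping zero values

# Unnormalized, unsquared pairwise statistics (on the prediction scale)
h2_pairwise(s, normalize = FALSE, squared = FALSE)
```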

hstats(object, ...)

# Default S3 method
hstats(
  object,
  X,
  v = NULL,
  pred_fun = stats::predict,
  pairwise_m = 5L,
  threeway_m = 0L,
  approx = FALSE,
  grid_size = 50L,
  n_max = 500L,
  eps = 1e-10,
  w = NULL,
  verbose = TRUE,
  ...
)

# S3 method for class 'ranger'
hstats(
  object,
  X,
  v = NULL,
  pred_fun = NULL,
  pairwise_m = 5L,
  threeway_m = 0L,
  approx = FALSE,
  grid_size = 50L,
  n_max = 500L,
  eps = 1e-10,
  w = NULL,
  verbose = TRUE,
  survival = c("chf", "prob"),
  ...
)

# S3 method for class 'explainer'
hstats(
  object,
  X = object[["data"]],
  v = NULL,
  pred_fun = object[["predict_function"]],
  pairwise_m = 5L,
  threeway_m = 0L,
  approx = FALSE,
  grid_size = 50L,
  n_max = 500L,
  eps = 1e-10,
  w = object[["weights"]],
  verbose = TRUE,
  ...
)

Arguments

object

Fitted model object.

...

Additional arguments passed to pred_fun(object, X, ...), for instance type = "response" in a glm() model, or reshape = TRUE in a multiclass XGBoost model.

X

A data.frame or matrix serving as background dataset.

v

Vector of feature names. The default (NULL) uses all column names of X except the column name of the optional case weight w (if w is specified as a column name).

pred_fun

Prediction function of the form function(object, X, ...), providing \(K \ge 1\) predictions per row. Its first argument represents the model object, its second argument a data structure like X. Additional arguments (such as type = "response" in a GLM, or reshape = TRUE in a multiclass XGBoost model) can be passed via .... The default, stats::predict(), will work in most cases.

pairwise_m

Number of features for which pairwise statistics are to be calculated. The features are selected based on Friedman and Popescu's overall interaction strength \(H^2_j\). Set to 0 to avoid pairwise calculations. For multivariate predictions, the union of the pairwise_m strongest features per prediction column is taken. This can lead to very long run-times.

threeway_m

Like pairwise_m, but controls the feature count for three-way interactions. Cannot be larger than pairwise_m. To save computation time, the default is 0.

approx

Should quantile approximation be applied to dense numeric features? The default is FALSE. Setting this option to TRUE brings a massive speed-up for one-way calculations. This is useful, for example, when the number of features is very large.

grid_size

Integer controlling the number of quantile midpoints used to approximate dense numerics. The quantile midpoints are calculated after subsampling via n_max. Only relevant if approx = TRUE.

n_max

If X has more than n_max rows, a random sample of n_max rows is selected from X. In this case, set a random seed for reproducibility.

eps

Threshold below which numerator values are set to 0. Default is 1e-10.

w

Optional vector of case weights. Can also be a column name of X.

verbose

Should a progress bar be shown? The default is TRUE.

survival

Should cumulative hazards ("chf", default) or survival probabilities ("prob") per time be predicted? Only relevant for ranger() survival models.
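The following sketch exercises several of these arguments with the Gamma GLM from MODEL 3, assuming hstats is installed. The custom prediction function pf, the weight column "wt", and the values grid_size = 20, n_max = 100, and the seed are arbitrary choices for illustration.

```r
library(hstats)

fit <- glm(Sepal.Length ~ ., data = iris, family = Gamma(link = log))

# Option 1: pass `type` through `...` to the default pred_fun (stats::predict)
s1 <- hstats(fit, X = iris[, -1], type = "response", verbose = FALSE)

# Option 2: the same predictions via a custom pred_fun(object, X, ...)
pf <- function(object, X, ...) predict(object, X, type = "response")
s2 <- hstats(fit, X = iris[, -1], pred_fun = pf, verbose = FALSE)

# Faster run: no pairwise statistics, quantile approximation of dense
# numerics, and a smaller background sample (seed set for reproducibility)
set.seed(1)
s3 <- hstats(
  fit, X = iris[, -1], type = "response",
  pairwise_m = 0, approx = TRUE, grid_size = 20, n_max = 100,
  verbose = FALSE
)

# Case weights given as a column name of X; the default v excludes it
X_w <- cbind(iris[, -1], wt = rep(1:3, 50))
s4 <- hstats(fit, X = X_w, w = "wt", verbose = FALSE)
s4$v
```

Options 1 and 2 should yield identical predictions; passing type via ... is usually the shorter route.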

Value

An object of class "hstats" containing these elements:

  • X: Input X (sampled to n_max rows, after optional quantile approximation).

  • w: Case weight vector w (sampled to n_max values), or NULL.

  • v: Vector of column names in X for which overall H statistics have been calculated.

  • f: Matrix with (centered) predictions \(F\).

  • mean_f2: (Weighted) column means of f^2, i.e., of the squared centered predictions. Used to normalize \(H^2\) and \(H^2_j\).

  • F_j: List of matrices, each representing (centered) partial dependence functions \(F_j\).

  • F_not_j: List of matrices with (centered) partial dependence functions \(F_{\setminus j}\) of other features.

  • K: Number of columns of prediction matrix.

  • pred_names: Column names of prediction matrix.

  • pairwise_m: Like input pairwise_m, but capped at length(v).

  • threeway_m: Like input threeway_m, but capped at the smaller of length(v) and pairwise_m.

  • eps: Like input eps.

  • pd_importance: List with numerator and denominator of \(\textrm{PDI}_j^2\).

  • h2: List with numerator and denominator of \(H^2\).

  • h2_overall: List with numerator and denominator of \(H^2_j\).

  • v_pairwise: Subset of v with largest \(H^2_j\) used for pairwise calculations. Only if pairwise calculations have been done.

  • combs2: Named list of variable pairs for which pairwise partial dependence functions are available. Only if pairwise calculations have been done.

  • F_jk: List of matrices, each representing (centered) bivariate partial dependence functions \(F_{jk}\). Only if pairwise calculations have been done.

  • h2_pairwise: List with numerator and denominator of \(H^2_{jk}\). Only if pairwise calculations have been done.

  • v_threeway: Subset of v with largest \(H^2_j\) used for three-way calculations. Only if three-way calculations have been done.

  • combs3: Named list of variable triples for which three-way partial dependence functions are available. Only if three-way calculations have been done.

  • F_jkl: List of matrices, each representing (centered) three-way partial dependence functions \(F_{jkl}\). Only if three-way calculations have been done.

  • h2_threeway: List with numerator and denominator of \(H^2_{jkl}\). Only if three-way calculations have been done.
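A few of these elements can be inspected directly, as in this sketch assuming hstats is installed. It only touches elements documented in the list above.

```r
library(hstats)

fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
s <- hstats(fit, X = iris[, -1], verbose = FALSE)

s$v           # features with overall statistics
s$K           # number of prediction columns (1 for a single-response lm)
s$pairwise_m  # input pairwise_m = 5, capped at length(v) = 4
s$v_pairwise  # features used for pairwise calculations
length(s$F_j) # one centered PD matrix per feature in v
```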

Methods (by class)

  • hstats(default): Default hstats method.

  • hstats(ranger): Method for "ranger" models.

  • hstats(explainer): Method for DALEX "explainer".
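For a "ranger" model, the dedicated method can be used without specifying pred_fun. The sketch below assumes both hstats and ranger are installed and only runs the ranger part if the package is available.

```r
library(hstats)

if (requireNamespace("ranger", quietly = TRUE)) {
  set.seed(1)
  fit <- ranger::ranger(Sepal.Length ~ ., data = iris)
  # The ranger method provides its own prediction logic (pred_fun = NULL)
  s <- hstats(fit, X = iris[, -1], verbose = FALSE)
  summary(s)
}
```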

References

Friedman, Jerome H., and Bogdan E. Popescu. "Predictive Learning via Rule Ensembles." The Annals of Applied Statistics 2, no. 3 (2008): 916-54.

See also

h2(), h2_overall(), h2_pairwise(), h2_threeway(), and pd_importance() for specific statistics calculated from the resulting object.

Examples

# MODEL 1: Linear regression
fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
s <- hstats(fit, X = iris[, -1])
#> 1-way calculations...
#> 2-way calculations...
s
#> 'hstats' object. Use plot() or summary() for details.
#> 
#> H^2 (normalized)
#> [1] 0.0502364
plot(s)

plot(s, zero = FALSE)  # Drop zero values

summary(s)
#> *H^2 (normalized)
#> [1] 0.0502364
#> 
#> *Largest Overall H^2 (normalized)
#>  Petal.Width      Species  Sepal.Width Petal.Length 
#>    0.0502364    0.0502364    0.0000000    0.0000000 
#> 
#> *Largest Pairwise H^2 (normalized)
#>                                [,1]
#> Petal.Width:Species      0.05546172
#> Sepal.Width:Petal.Length 0.00000000
#> Sepal.Width:Petal.Width  0.00000000
#> 
  
# Absolute pairwise interaction strengths
h2_pairwise(s, normalize = FALSE, squared = FALSE, zero = FALSE)
#> Pairwise H (unnormalized)
#> Petal.Width:Species 
#>           0.1726312 

# MODEL 2: Multi-response linear regression
fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
s <- hstats(fit, X = iris[, 3:5], verbose = FALSE)
plot(s)

summary(s)
#> *H^2 (normalized)
#> Sepal.Length  Sepal.Width 
#>   0.04758952   0.03963575 
#> 
#> *Largest Overall H^2 (normalized)
#>              Sepal.Length Sepal.Width
#> Species        0.04758952  0.03963575
#> Petal.Width    0.04758952  0.03963575
#> Petal.Length   0.00000000  0.00000000
#> 
#> *Largest Pairwise H^2 (normalized)
#>                          Sepal.Length Sepal.Width
#> Petal.Width:Species        0.02937378  0.01637166
#> Petal.Length:Petal.Width   0.00000000  0.00000000
#> Petal.Length:Species       0.00000000  0.00000000
#> 

# MODEL 3: Gamma GLM with log link
fit <- glm(Sepal.Length ~ ., data = iris, family = Gamma(link = log))

# No interactions for additive features, at least on link scale
s <- hstats(fit, X = iris[, -1], verbose = FALSE)
summary(s)
#> *H^2 (normalized)
#> [1] 0
#> 
#> *Largest Overall H^2 (normalized)
#>  Sepal.Width Petal.Length  Petal.Width      Species 
#>            0            0            0            0 
#> 
#> *Largest Pairwise H^2 (normalized)
#>                          [,1]
#> Sepal.Width:Petal.Length    0
#> Sepal.Width:Petal.Width     0
#> Sepal.Width:Species         0
#> 

# On the response scale, interactions appear everywhere.
# To see three-way interactions, we set threeway_m to a value above 2.
s <- hstats(fit, X = iris[, -1], type = "response", threeway_m = 5)
#> 1-way calculations...
#> 2-way calculations...
#> 3-way calculations...
plot(s, ncol = 1)  # All three types use different denominators


# All statistics on same scale (of predictions)
plot(s, squared = FALSE, normalize = FALSE, facet_scale = "free_y")