Scatterplot of two features, showing the sum of their SHAP values on the color scale. This makes it possible to visualize the combined effect of two features, including interactions. A typical application is a model with latitude and longitude as features (plus, possibly, other regional features that can be passed via add_vars).

If SHAP interaction values are available, setting interactions = TRUE restricts the plot to the pure interaction effect (multiplied by two). In this case, add_vars has no effect.

sv_dependence2D(object, ...)

# Default S3 method
sv_dependence2D(object, ...)

# S3 method for class 'shapviz'
sv_dependence2D(
  object,
  x,
  y,
  viridis_args = getOption("shapviz.viridis_args"),
  jitter_width = NULL,
  jitter_height = NULL,
  interactions = FALSE,
  add_vars = NULL,
  ...
)

# S3 method for class 'mshapviz'
sv_dependence2D(
  object,
  x,
  y,
  viridis_args = getOption("shapviz.viridis_args"),
  jitter_width = NULL,
  jitter_height = NULL,
  interactions = FALSE,
  add_vars = NULL,
  ...
)

Arguments

object

An object of class "(m)shapviz".

...

Arguments passed to ggplot2::geom_jitter().

x

Feature name for x axis. Can be a vector/list if object is of class "shapviz".

y

Feature name for y axis. Can be a vector/list if object is of class "shapviz".

viridis_args

List of viridis color scale arguments, see ?ggplot2::scale_color_viridis_c. The default points to the global option shapviz.viridis_args, which corresponds to list(begin = 0.25, end = 0.85, option = "inferno"). These values are passed to ggplot2::scale_color_viridis_*(). For example, to switch to a standard viridis scale, you can either change the default via options(shapviz.viridis_args = list()), or set viridis_args = list().

jitter_width

The amount of horizontal jitter. The default (NULL) uses a value of 0.2 if x is discrete, and no jitter otherwise. (Numeric variables are considered discrete if they have at most 7 unique values.) Can be a vector/list if x is a vector.

jitter_height

Like jitter_width, but for the amount of vertical jitter, which is determined by y instead of x.

interactions

Should SHAP interaction values be plotted? The default (FALSE) will show the rowwise sum of the SHAP values of x and y. If TRUE, will use twice the SHAP interaction value (requires SHAP interactions).

add_vars

Optional vector of feature names whose SHAP values should be added to the sum of the SHAP values of x and y (only if interactions = FALSE). A use case would be a model with geographic x and y coordinates, along with some additional locational features like distance to the nearest train station.
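
The following base R sketch illustrates which values end up on the color scale for the different settings of interactions and add_vars. It is not the package's implementation: the helper color_values(), the feature names, and all numbers are hypothetical. It also assumes the usual convention that SHAP interaction values are stored as an n x p x p array in which the interaction effect of a feature pair is split equally between the two symmetric off-diagonal entries, which is why the factor two appears.

```r
# Hypothetical SHAP matrix: 4 observations, 3 features (values invented)
S <- matrix(
  c( 0.5, -0.2,  0.1,
     0.3,  0.4, -0.1,
    -0.6,  0.2,  0.0,
     0.1, -0.1,  0.2),
  nrow = 4, byrow = TRUE,
  dimnames = list(NULL, c("lon", "lat", "dist"))
)

# Hypothetical SHAP interaction array (n x p x p), symmetric in the last two dims
S_inter <- array(
  0, dim = c(4, 3, 3), dimnames = list(NULL, colnames(S), colnames(S))
)
lon_lat <- c(0.05, -0.02, 0.10, 0.00)
S_inter[, "lon", "lat"] <- lon_lat
S_inter[, "lat", "lon"] <- lon_lat

# Hypothetical helper mirroring the documented behavior of the arguments
color_values <- function(S, x, y, add_vars = NULL,
                         S_inter = NULL, interactions = FALSE) {
  if (interactions) {
    # Pure interaction effect of x and y; add_vars is ignored in this case
    return(2 * S_inter[, x, y])
  }
  # Row-wise sum of the SHAP values of x, y, and any additional features
  rowSums(S[, c(x, y, add_vars), drop = FALSE])
}

color_values(S, "lon", "lat")
color_values(S, "lon", "lat", add_vars = "dist")
color_values(S, "lon", "lat", S_inter = S_inter, interactions = TRUE)
```

The second call differs from the first only by the SHAP values of "dist", while the third call depends on the interaction array alone.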

Value

An object of class "ggplot" (or "patchwork") representing a dependence plot.

Methods (by class)

  • sv_dependence2D(default): Default method.

  • sv_dependence2D(shapviz): 2D SHAP dependence plot for "shapviz" object.

  • sv_dependence2D(mshapviz): 2D SHAP dependence plot for "mshapviz" object.

Examples

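library(shapviz)

# Fit a small XGBoost model predicting Sepal.Length from the other iris columns
dtrain <- xgboost::xgb.DMatrix(
  data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
)
fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
sv <- shapviz(fit, X_pred = dtrain, X = iris)

# Color shows the summed SHAP values of the two features
sv_dependence2D(sv, x = "Petal.Length", y = "Species")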

# Multiple x features: one plot per combination of x and y
sv_dependence2D(sv, x = c("Petal.Length", "Species"), y = "Sepal.Width")


# SHAP interaction values
sv2 <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
sv_dependence2D(sv2, x = "Petal.Length", y = "Species", interactions = TRUE)

sv_dependence2D(
  sv2, x = "Petal.Length", y = c("Species", "Petal.Width"), interactions = TRUE
)


# mshapviz object
mx <- split(sv, f = iris$Species)
sv_dependence2D(mx, x = "Petal.Length", y = "Sepal.Width")
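
A few further variations may help to illustrate the remaining arguments. This is a sketch that reuses the object sv created in the examples above and only uses arguments listed in the Usage section; the plot title is made up for illustration.

```r
# Add the SHAP values of a third feature to the color scale
p <- sv_dependence2D(
  sv, x = "Petal.Length", y = "Species", add_vars = "Petal.Width"
)

# With a single x and y, the result is a ggplot and can be modified as usual
p + ggplot2::ggtitle("Petal.Length, Species, and Petal.Width combined")

# Standard viridis color scale and no horizontal jitter
sv_dependence2D(
  sv, x = "Petal.Length", y = "Species", viridis_args = list(), jitter_width = 0
)
```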