The safely_transform_categorical() function calculates a transformation function for a categorical variable using predictions obtained from a black-box model and hierarchical clustering. The gap statistic criterion is used to determine the optimal number of clusters, that is, the number of merged levels.
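The idea can be illustrated with base R alone. The sketch below is not the package's implementation: it assumes each level is summarized by its mean prediction (rSAFE may summarize the model response differently), and the names `toy`, `pred`, `level_means` and `groups` are made up for the illustration.

```r
# Toy stand-in for model predictions grouped by a categorical variable.
# All object names here are illustrative, not part of rSAFE.
set.seed(1)
toy <- data.frame(
  district = factor(rep(c("A", "B", "C", "D"), each = 25)),
  pred     = c(rnorm(25, 3000, 50), rnorm(25, 3050, 50),
               rnorm(25, 4500, 50), rnorm(25, 4550, 50))
)

# One summary value per level: the mean prediction (an assumption of this sketch).
level_means <- tapply(toy$pred, toy$district, mean)

# Hierarchical clustering of the levels based on those summaries.
dist_matrix <- dist(level_means)
tree <- hclust(dist_matrix, method = "complete")

# Cut the tree into a fixed number of clusters. The value 2 is chosen by hand
# here; the documented function selects it with the gap statistic instead.
groups <- cutree(tree, k = 2)
groups
```

Levels that end up with the same cluster id are the ones that would be merged into a single new level.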

safely_transform_categorical(
  explainer,
  variable,
  method = "complete",
  B = 500,
  collapse = "_"
)

Arguments

explainer

DALEX explainer created with the explain() function

variable

a feature for which the transformation function is to be computed

method

the agglomeration method to be used in hierarchical clustering, one of: "ward.D", "ward.D2", "single", "complete", "average", "mcquitty", "median", "centroid"

B

number of reference datasets used to calculate the gap statistic
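To see what the number of reference datasets controls, the following sketch computes a gap statistic with `cluster::clusGap()`, whose `B` argument has the same meaning. Whether rSAFE calls `clusGap()` internally is an assumption; the data `x` and the helper `hclust_k` are hypothetical.

```r
library(cluster)  # recommended package shipped with R

set.seed(42)
# Hypothetical per-level model responses forming two well-separated groups.
x <- matrix(c(rnorm(5, 3000, 20), rnorm(5, 4500, 20)), ncol = 1)

# clusGap() expects a function(x, k) that returns a list with a `cluster` component.
hclust_k <- function(x, k) {
  list(cluster = cutree(hclust(dist(x), method = "complete"), k = k))
}

# B is the number of simulated reference datasets; larger values give a more
# stable estimate of the gap at the cost of longer computation.
gap <- clusGap(x, FUNcluster = hclust_k, K.max = 4, B = 50, verbose = FALSE)

# Pick the number of clusters from the gap values and their standard errors.
k_opt <- maxSE(gap$Tab[, "gap"], gap$Tab[, "SE.sim"], method = "Tibs2001SEmax")
k_opt
```

A small `B` makes the result noisier between runs, which is why a fixed seed is useful when the output has to be reproducible.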

collapse

a character string used to separate the original levels when they are combined into a new level
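As a small illustration of how the separator shapes the new level names, merged levels can be joined with base R's `paste()`. The vector `merged` is a hypothetical group of levels; this mirrors the naming visible in the example output, not necessarily the package's internal code.

```r
# A hypothetical group of levels that clustering decided to merge.
merged <- c("Mokotow", "Ochota", "Zoliborz")

# Default separator used by the function: "_"
new_level <- paste(merged, collapse = "_")
new_level
#> [1] "Mokotow_Ochota_Zoliborz"

# A different separator only changes the label, not the grouping.
new_level_plus <- paste(merged, collapse = "+")
new_level_plus
#> [1] "Mokotow+Ochota+Zoliborz"
```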

Value

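a list with information on the transformation of the given variable; in the example below it contains the fitted hierarchical clustering (clustering) and a table mapping the original levels to the merged ones (new_levels)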

See also

Examples


library(DALEX)
library(randomForest)
library(rSAFE)

data <- apartments[1:500,]
set.seed(111)
model_rf <- randomForest(m2.price ~ construction.year + surface + floor +
                           no.rooms + district, data = data)
explainer_rf <- explain(model_rf, data = data[,2:6], y = data[,1])
#> Preparation of a new explainer is initiated
#>   -> model label       :  randomForest  (  default  )
#>   -> data              :  500  rows  5  cols 
#>   -> target variable   :  500  values 
#>   -> predict function  :  yhat.randomForest  will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package randomForest , ver. 4.7.1.1 , task regression (  default  ) 
#>   -> predicted values  :  numerical, min =  2010.939 , mean =  3502.345 , max =  5764.513  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  numerical, min =  -387.9388 , mean =  -0.6372461 , max =  749.0998  
#>   A new explainer has been created!  
safely_transform_categorical(explainer_rf, "district")
#> $clustering
#> 
#> Call:
#> hclust(d = dist_matrix, method = method)
#> 
#> Cluster method   : complete 
#> Distance         : euclidean 
#> Number of objects: 10 
#> 
#> 
#> $new_levels
#>       district                            district_new
#> 1       Bemowo Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 2      Bielany Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 3        Ursus Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 4      Ursynow Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 5        Praga Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 6         Wola Bemowo_Bielany_Praga_Ursus_Ursynow_Wola
#> 7     Zoliborz                 Mokotow_Ochota_Zoliborz
#> 8      Mokotow                 Mokotow_Ochota_Zoliborz
#> 9       Ochota                 Mokotow_Ochota_Zoliborz
#> 10 Srodmiescie                             Srodmiescie
#>
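The `new_levels` table can be used as a lookup to recode a column. The sketch below is one possible way to do that in base R; it copies three rows of the mapping printed above by hand, and the vector `obs` is made-up input data.

```r
# Mapping copied from the $new_levels output above (three rows kept for brevity).
new_levels <- data.frame(
  district     = c("Bemowo", "Mokotow", "Srodmiescie"),
  district_new = c("Bemowo_Bielany_Praga_Ursus_Ursynow_Wola",
                   "Mokotow_Ochota_Zoliborz",
                   "Srodmiescie"),
  stringsAsFactors = FALSE
)

# Hypothetical observations of the original variable.
obs <- c("Mokotow", "Bemowo", "Mokotow", "Srodmiescie")

# Look up the merged level for each observation.
recoded <- new_levels$district_new[match(obs, new_levels$district)]
recoded
#> [1] "Mokotow_Ochota_Zoliborz"
#> [2] "Bemowo_Bielany_Praga_Ursus_Ursynow_Wola"
#> [3] "Mokotow_Ochota_Zoliborz"
#> [4] "Srodmiescie"
```

Levels missing from the lookup table would come back as `NA`, so in real use the full mapping returned by the function should be applied.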