The `train()` function is the core function of this package. The only required arguments are `data` and `y`. The other arguments control the validation strategy, the tested model families, the hyperparameter tuning, and so on.

train(
  data,
  y,
  type = "auto",
  engine = c("ranger", "xgboost", "decision_tree", "lightgbm"),
  verbose = TRUE,
  train_test_split = c(0.6, 0.2, 0.2),
  bayes_iter = 10,
  random_evals = 10,
  advanced_preprocessing = FALSE,
  metrics = "auto",
  sort_by = "auto",
  metric_function = NULL,
  metric_function_name = NULL,
  metric_function_decreasing = TRUE,
  best_model_number = 5
)

Arguments

data

A `data.frame` or `matrix` with the data used to build the models. By default, models are trained on all columns of `data`.

y

The target variable. It can be either (1) a vector with the same number of observations as `data`, or (2) a character string with the name of the column in `data` that contains the target.

type

A character string, one of `binary_clf`/`regression`/`auto`, that sets the type of the task. If `auto` (the default), forester guesses `type` from the number of unique values in `y`.
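The exact guessing rule is not documented here. The following R sketch only illustrates the idea of inferring the task type from the number of unique target values. The function name `guess_task_type()` and the "exactly two unique values means binary" rule are assumptions, not forester's code.

```r
# Hypothetical sketch: infer the task type from the number of unique target
# values. forester's real rule may use other thresholds or extra checks.
guess_task_type <- function(y) {
  n_unique <- length(unique(y[!is.na(y)]))  # ignore missing target values
  if (n_unique == 2) "binary_clf" else "regression"
}

guess_task_type(c("yes", "no", "yes"))      # two unique values
guess_task_type(c(250000, 410000, 185000))  # many unique values
```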

engine

A vector of tree-based model engines to test. Possible values are: `ranger`, `xgboost`, `decision_tree`, `lightgbm`, `catboost`. Models are trained for every engine in this vector, and the best ones are returned.

verbose

A logical value. If `TRUE`, information about the training process is printed. If `FALSE`, nothing is printed.

train_test_split

A numeric vector of three values giving the proportions of the training, test, and validation subsets relative to the original dataset. Default: `c(0.6, 0.2, 0.2)`.
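To make the proportions concrete, here is an illustrative base-R sketch of a random three-way split. It is not forester's implementation: the package also balances the data, which this sketch does not do. The function name `split_three_way()` is hypothetical.

```r
# Illustrative three-way random split by proportions (no balancing).
# Call set.seed() beforehand for a reproducible split.
split_three_way <- function(df, props = c(0.6, 0.2, 0.2)) {
  stopifnot(length(props) == 3, isTRUE(all.equal(sum(props), 1)))
  n <- nrow(df)
  n_train <- floor(props[1] * n)
  n_test  <- floor(props[2] * n)
  idx <- sample(n)  # random permutation of the row indices
  list(
    train = df[idx[seq_len(n_train)], , drop = FALSE],
    test  = df[idx[n_train + seq_len(n_test)], , drop = FALSE],
    valid = df[idx[-seq_len(n_train + n_test)], , drop = FALSE]  # remainder
  )
}
```

With 246 observations, as in the `lisbon` example, the default proportions give 147 training, 49 test, and 50 validation rows. The validation subset absorbs the rows lost to rounding down.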

bayes_iter

An integer: the number of optimization rounds used by Bayesian optimization.

random_evals

An integer: the number of models, each with a different parameter set, trained by random search.
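Random search draws a fixed number of distinct parameter sets and trains one model per set. The sketch below shows only the sampling step. The grid, its parameter names and values, and the function name `random_search_grid()` are illustrative assumptions, not forester's actual search space.

```r
# Hypothetical random search: sample `random_evals` distinct parameter sets
# from a grid. Parameter names and values are assumptions for illustration.
random_search_grid <- function(random_evals = 10) {
  grid <- expand.grid(
    num_trees     = c(100, 250, 500, 1000),
    max_depth     = c(2, 4, 6, 8),
    min_node_size = c(1, 5, 10)
  )
  # Sampling row indices without replacement keeps the parameter sets distinct.
  grid[sample(nrow(grid), min(random_evals, nrow(grid))), , drop = FALSE]
}

set.seed(42)
random_search_grid(random_evals = 10)
```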

advanced_preprocessing

A logical value indicating whether to use advanced preprocessing methods (e.g., deleting correlated columns).
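The data check report in the Examples section flags strongly correlated numeric pairs by Spearman rank correlation. The sketch below shows one simple way such pairs could be pruned. The 0.7 threshold, the greedy rule of keeping the earlier column, and the function name `drop_correlated()` are assumptions. forester's advanced preprocessing may work differently.

```r
# Hypothetical sketch: drop one column from each strongly correlated pair of
# numeric columns, using Spearman rank correlation and a greedy rule.
drop_correlated <- function(df, threshold = 0.7) {
  num <- df[vapply(df, is.numeric, logical(1))]
  rho <- cor(num, method = "spearman", use = "pairwise.complete.obs")
  rho[upper.tri(rho, diag = TRUE)] <- 0  # keep each pair once (i > j)
  to_drop <- character(0)
  for (j in colnames(rho)) {
    if (j %in% to_drop) next  # already removed, so it cannot remove others
    # Later columns strongly correlated with column j are dropped; j is kept.
    hits <- rownames(rho)[which(abs(rho[, j]) > threshold)]
    to_drop <- union(to_drop, hits)
  }
  df[setdiff(names(df), to_drop)]
}

names(drop_correlated(mtcars))
```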

metrics

A vector of metric names. If `auto` (the default), the most important metrics are returned. If `all`, all metrics are returned. If `NULL`, no metrics are returned, but the models are still sorted by `sort_by`.

sort_by

A string with the name of the metric to sort by. If `auto`, models are sorted by `mse` for regression and by `f1` for classification.
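The sort direction depends on the metric. Error metrics such as `mse` and `mae` rank lower values first, while score metrics such as `r2` and `f1` rank higher values first. The `score_valid` table in the Examples is indeed ordered by ascending `mse`. The base-R sketch below uses three rows copied from that output to show how the ranking changes with the metric. It assumes the score table is an ordinary `data.frame`, as its printed form suggests, and is not forester's internal sorting code.

```r
# Three rows copied from the score_valid output in the Examples section.
scores <- data.frame(
  name = c("ranger_RS_7", "ranger_RS_8", "ranger_model"),
  mse  = c(172355600654, 175206764634, 176861131326),
  r2   = c(0.52786576, 0.52005556, 0.51552375),
  mae  = c(146038.0, 140453.4, 140467.5),
  stringsAsFactors = FALSE
)

# Error metric: lower is better, so sort ascending.
by_mae <- scores[order(scores$mae), ]
# Score metric: higher is better, so sort descending.
by_r2 <- scores[order(scores$r2, decreasing = TRUE), ]

by_mae$name  # ranger_RS_8, ranger_model, ranger_RS_7
by_r2$name   # ranger_RS_7, ranger_RS_8, ranger_model
```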

metric_function

A user-defined metric function. It should have the form `name(predictions, observed)` and return a single numeric value. If `metrics` is set to a value other than `auto` or `all`, include `metric_function` in `metrics` to see this metric in the report. If `sort_by` is `auto`, models are sorted by `metric_function`.
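As an illustration of the required `name(predictions, observed)` form, here is a mean absolute percentage error written in plain R. The metric choice and the name `mape` are illustrative.

```r
# A user-defined metric in the documented form name(predictions, observed).
# Mean absolute percentage error: lower values indicate better models.
# Assumes no observed value is zero (that would divide by zero).
mape <- function(predictions, observed) {
  mean(abs((observed - predictions) / observed))
}

mape(c(110, 90), c(100, 100))  # 0.1
```

Using only the arguments documented on this page, it could be passed as `train(lisbon, 'Price', metric_function = mape, metric_function_name = 'mape', metric_function_decreasing = FALSE)`. Setting `metric_function_decreasing = FALSE` assumes that `TRUE` ranks larger values first, which suits score-like metrics but not an error metric such as MAPE.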

metric_function_name

The name of the column that holds the values of `metric_function`. If not set, the column is named `metric_function`.

metric_function_decreasing

A logical value indicating whether models are sorted by `metric_function` in decreasing order. `TRUE` by default.

best_model_number

The number of best models to include in the returned list. All trained models are returned in a separate element of the list.

Value

A list of all objects needed by the other functions of the package. It contains:

type

The type of the ML task. If the user did not specify a type, the guessed type is used and returned. It is either 'regression' or 'binary_clf'.

deleted_columns

Names of the columns of the original data frame that were removed during preprocessing, e.g. due to too high correlation with other columns.

preprocessed_data

The data frame after preprocessing, which removes columns with the same value in all rows, binarizes the target column, and handles missing values. With advanced preprocessing, it also deletes correlated columns, deletes ID-like columns, and selects the most important features with the Boruta algorithm.

bin_labels

Labels of the binarized target: values 1 and 2 for binary classification, NULL for regression.

train_data

The training dataset: the part of the source dataset obtained after preprocessing, balancing, and splitting into the training, test, and validation datasets.

test_data

The test dataset, obtained in the same way as `train_data`.

valid_data

The validation dataset, obtained in the same way as `train_data`.

predictions

The list of predictions of all trained models on the training dataset.

score_test

The metrics of all trained models calculated on the test dataset. For regression these are mse, r2, and mae. For classification these are f1, auc, recall, precision, and accuracy.

score_train

The same metrics as in `score_test`, calculated on the training dataset.

score_valid

The same metrics as in `score_test`, calculated on the validation dataset.

models_list

The list of all trained models.

data

The original data.

y

The original target column name.

test_observed

Values of the y column in the test dataset.

train_observed

Values of the y column in the training dataset.

valid_observed

Values of the y column in the validation dataset.

test_observed_labels

Values of the y column in the test dataset as text labels (classification only).

train_observed_labels

Values of the y column in the training dataset as text labels (classification only).

valid_observed_labels

Values of the y column in the validation dataset as text labels (classification only).

best_models

The ranking of the top `best_model_number` trained models, selected among models with default parameters, models tuned with Bayesian optimization, and models tuned with random search.

engine

The names of all types of trained models. Possible values: 'ranger', 'xgboost', 'decision_tree', 'lightgbm', 'catboost'.

predictions_all

Predictions of all trained models on the test dataset.

predictions_best

Predictions on the test dataset of the models in `best_models`.

predictions_all_labels

Predictions of all trained models on the test dataset as text labels (classification only).

predictions_best_labels

Predictions on the test dataset of the models in `best_models` as text labels (classification only).

predictions_train

Predictions of all trained models on the training dataset.

raw_train

Another form of the training dataset, useful for creating the VS plot and for predicting on the training dataset with catboost and lightgbm models.

check_report

The data check report stored as a list of strings. It is used by the `report()` function.

outliers

The vector of possible outliers detected by `check_data()`.

Examples

library(forester)
data('lisbon')
train_output <- train(lisbon, 'Price')
#>  Type guessed as:  regression 
#> 
#>  -------------------- CHECK DATA REPORT -------------------- 
#>  
#> The dataset has 246 observations and 17 columns, which names are: 
#> Id; Condition; PropertyType; PropertySubType; Bedrooms; Bathrooms; AreaNet; AreaGross; Parking; Latitude; Longitude; Country; District; Municipality; Parish; Price.M2; Price; 
#> 
#> With the target value described by a column Price.
#> 
#>  Static columns are: 
#>  Country; District; Municipality; 
#> 
#>  With dominating values: 
#>  Portugal; Lisboa; Lisboa; 
#>  
#>  These column pairs are duplicate:
#>  District - Municipality; 
#> 
#>  No target values are missing. 
#> 
#>  No predictor values are missing. 
#> 
#>  No issues with dimensionality. 
#> 
#>  Strongly correlated, by Spearman rank, pairs of numerical values are: 
#>  
#>  Bedrooms - AreaNet: 0.77;
#>  Bedrooms - AreaGross: 0.77;
#>  Bathrooms - AreaNet: 0.78;
#>  Bathrooms - AreaGross: 0.78;
#>  AreaNet - AreaGross: 1;
#> 
#>  Strongly correlated, by Crammer's V rank, pairs of categorical values are: 
#>  PropertyType - PropertySubType: 1;
#> 
#>  These obserwation migth be outliers due to their numerical columns values: 
#>  145 146 196 44 5 51 57 58 59 60 61 62 63 64 69 75 76 77 78 ;
#> 
#>  Target data is not evenly distributed with quantile bins: 0.25 0.35 0.14 0.26 
#> 
#>  Columns names suggest that some of them are IDs, removing them can improve the model.
#>  Suspicious columns are: Id .
#> 
#>  Columns data suggest that some of them are IDs, removing them can improve the model.
#>  Suspicious columns are: Id .
#> 
#>  -------------------- CHECK DATA REPORT END -------------------- 
#>  
#>  Data preprocessed. 
#>  Data split and balanced. 
#>  Correct formats prepared. 
#>  Models successfully trained. 
#>  Predicted successfully. 
train_output$score_valid
#>    no.                name        engine        tuning          mse          r2      mae
#> 1    2       xgboost_model       xgboost         basic 143685453859  0.60640199 135871.5
#> 2   46       xgboost_bayes       xgboost     bayes_opt 144796997342  0.60335713 138757.1
#> 3   47 decision_tree_bayes decision_tree     bayes_opt 167134585019  0.54216771 151383.7
#> 4   11         ranger_RS_7        ranger random_search 172355600654  0.52786576 146038.0
#> 5   12         ranger_RS_8        ranger random_search 175206764634  0.52005556 140453.4
#> 6    1        ranger_model        ranger         basic 176861131326  0.51552375 140467.5
#> 7   45        ranger_bayes        ranger     bayes_opt 177100208139  0.51486884 142764.0
#> 8   10         ranger_RS_6        ranger random_search 177696042483  0.51323667 150902.5
#> 9    8         ranger_RS_4        ranger random_search 178173236730  0.51192949 147778.7
#> 10   6         ranger_RS_2        ranger random_search 180738901681  0.50490136 143104.2
#> 11   4      lightgbm_model      lightgbm         basic 198600114778  0.45597408 163623.3
#> 12  43       lightgbm_RS_9      lightgbm random_search 198600114778  0.45597408 163623.3
#> 13  38       lightgbm_RS_4      lightgbm random_search 201251798549  0.44871032 169723.7
#> 14  39       lightgbm_RS_5      lightgbm random_search 201251798549  0.44871032 169723.7
#> 15   3 decision_tree_model decision_tree         basic 201497538451  0.44803716 176132.1
#> 16  25  decision_tree_RS_1 decision_tree random_search 201497538451  0.44803716 176132.1
#> 17  26  decision_tree_RS_2 decision_tree random_search 201497538451  0.44803716 176132.1
#> 18  27  decision_tree_RS_3 decision_tree random_search 201497538451  0.44803716 176132.1
#> 19  28  decision_tree_RS_4 decision_tree random_search 201497538451  0.44803716 176132.1
#> 20  29  decision_tree_RS_5 decision_tree random_search 201497538451  0.44803716 176132.1
#> 21  30  decision_tree_RS_6 decision_tree random_search 201497538451  0.44803716 176132.1
#> 22  31  decision_tree_RS_7 decision_tree random_search 201497538451  0.44803716 176132.1
#> 23  32  decision_tree_RS_8 decision_tree random_search 201497538451  0.44803716 176132.1
#> 24  33  decision_tree_RS_9 decision_tree random_search 201497538451  0.44803716 176132.1
#> 25  34 decision_tree_RS_10 decision_tree random_search 201497538451  0.44803716 176132.1
#> 26  48      lightgbm_bayes      lightgbm     bayes_opt 201568659750  0.44784234 162111.1
#> 27  35       lightgbm_RS_1      lightgbm random_search 202121875717  0.44632691 170545.3
#> 28  14        ranger_RS_10        ranger random_search 213760186314  0.41444605 187644.4
#> 29  42       lightgbm_RS_8      lightgbm random_search 214805128850  0.41158364 163852.3
#> 30  44      lightgbm_RS_10      lightgbm random_search 214805128850  0.41158364 163852.3
#> 31  37       lightgbm_RS_3      lightgbm random_search 215251555308  0.41036074 165520.9
#> 32  41       lightgbm_RS_7      lightgbm random_search 215251555308  0.41036074 165520.9
#> 33   5         ranger_RS_1        ranger random_search 216042730457  0.40819347 190711.9
#> 34   9         ranger_RS_5        ranger random_search 221397579341  0.39352492 193152.2
#> 35   7         ranger_RS_3        ranger random_search 226399456894  0.37982326 187923.1
#> 36  40       lightgbm_RS_6      lightgbm random_search 228517310575  0.37402181 175919.8
#> 37  13         ranger_RS_9        ranger random_search 229673753119  0.37085397 195580.2
#> 38  15        xgboost_RS_1       xgboost random_search 251159464876  0.31199809 200641.5
#> 39  22        xgboost_RS_8       xgboost random_search 251159464876  0.31199809 200641.5
#> 40  16        xgboost_RS_2       xgboost random_search 251529223453  0.31098521 202108.0
#> 41  24       xgboost_RS_10       xgboost random_search 251529223453  0.31098521 202108.0
#> 42  20        xgboost_RS_6       xgboost random_search 256059009039  0.29857675 202524.9
#> 43  36       lightgbm_RS_2      lightgbm random_search 257857650556  0.29364973 211581.6
#> 44  17        xgboost_RS_3       xgboost random_search 346948222547  0.04960364 315084.6
#> 45  21        xgboost_RS_7       xgboost random_search 346948222547  0.04960364 315084.6
#> 46  19        xgboost_RS_5       xgboost random_search 349598730125  0.04234310 313834.7
#> 47  23        xgboost_RS_9       xgboost random_search 349598730125  0.04234310 313834.7
#> 48  18        xgboost_RS_4       xgboost random_search 556458657778 -0.52430895 486418.6