The `train()` function is the core function of this package. The only required arguments are `data` and `y`. Changing the other arguments affects the model validation strategy, the model families that are tested, and so on.
train(
data,
y,
type = "auto",
engine = c("ranger", "xgboost", "decision_tree", "lightgbm"),
verbose = TRUE,
train_test_split = c(0.6, 0.2, 0.2),
bayes_iter = 10,
random_evals = 10,
advanced_preprocessing = FALSE,
metrics = "auto",
sort_by = "auto",
metric_function = NULL,
metric_function_name = NULL,
metric_function_decreasing = TRUE,
best_model_number = 5
)
`data`: A `data.frame` or `matrix` with the data used to build the models. By default, models are trained on all columns of `data`.
`y`: The target variable. It can be either (1) a vector with as many observations as `data` or (2) the name, as a character string, of the column in `data` that contains the target variable.
`type`: A character string, one of `binary_clf`/`regression`/`auto`, that sets the type of the task. With `auto` (the default option), forester infers `type` from the number of unique values in the `y` variable.
`engine`: A vector of tree-based models to be tested. Possible values are: `ranger`, `xgboost`, `decision_tree`, `lightgbm`, `catboost`. All models from this vector are trained and the best one is returned.
`verbose`: A logical value. If `TRUE`, all information about the training process is printed; if `FALSE`, nothing is printed.
`train_test_split`: A 3-value vector giving the proportions of the train, test, and validation subsets relative to the original data set. The default is `c(0.6, 0.2, 0.2)`.
`bayes_iter`: An integer giving the number of optimization rounds used by the Bayesian optimization.
`random_evals`: An integer giving the number of models trained with different parameters by random search.
`advanced_preprocessing`: A logical value indicating whether to use advanced preprocessing methods (e.g. deleting correlated values).
`metrics`: A vector of metric names. With `auto` (the default), the most important metrics are returned. With `all`, all metrics are returned. With `NULL`, no metrics are returned, but the models are still sorted by `sort_by`.
`sort_by`: A string with the name of the metric to sort by. With `auto`, models are sorted by `mse` for regression and by `f1` for classification.
`metric_function`: A user-defined metric function. It should have the form `name(predictions, observed)` and return a numeric value. If `metrics` is set to a value other than `auto` or `all`, it must include the value `metric_function` for the custom metric to appear in the report. If a `metric_function` is supplied and `sort_by` is `auto`, models are sorted by `metric_function`.
`metric_function_name`: The name of the column holding the values of `metric_function`. By default, `metric_function_name` is `metric_function`.
`metric_function_decreasing`: A logical value indicating how the `metric_function` values should be sorted. `TRUE` by default.
`best_model_number`: The number of best models to be included as an element of the returned object. All trained models are returned as a separate element.
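The `metric_function` contract described above (two arguments, predictions first, one numeric value returned) can be sketched as follows. The function `max_abs_error` and the column label `"max_error"` are hypothetical names chosen for illustration. The guarded call at the end assumes the forester package and its `lisbon` dataset are installed, and it only runs in an interactive session because training takes a while.

```r
# Hypothetical custom metric: the largest absolute prediction error.
# It follows the documented shape name(predictions, observed) and
# returns a single numeric value.
max_abs_error <- function(predictions, observed) {
  max(abs(predictions - observed))
}

# Quick check on toy vectors before handing the function to train().
toy_predictions <- c(1, 2, 3)
toy_observed    <- c(1, 4, 2)
toy_score <- max_abs_error(toy_predictions, toy_observed)
print(toy_score)  # 2

# Assumption: forester is installed. The call is skipped otherwise,
# and also skipped in non-interactive runs.
if (interactive() && requireNamespace("forester", quietly = TRUE)) {
  data("lisbon", package = "forester")
  custom_output <- forester::train(
    lisbon, "Price",
    metric_function      = max_abs_error,
    metric_function_name = "max_error"
  )
}
```

Testing the metric on a few hand-made values first is a cheap way to catch argument-order mistakes before a full training run.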
A list of all the objects needed by the other functions. It contains:
`type`: The type of the ML task. If the user did not specify a type in the input parameters, the algorithm recognizes, uses and returns the same type. It could be 'regression' or 'classification'.
`deleted_columns`: Column names from the original data frame that have been removed in the data preprocessing process, e.g. due to too high correlation with other columns.
`preprocessed_data`: The data frame after the preprocessing process, that is: removing columns with one value for all rows, binarizing the target column, managing missing values and, in advanced preprocessing, deleting correlated values, deleting ID-like columns and performing the Boruta algorithm to select the most important features.
`bin_labels`: Labels of the binarized target value: 1, 2 values for binary classification and NULL for regression.
`train_data`: The training dataset, i.e. the part of the source dataset after preprocessing, balancing and splitting into the training, test and validation datasets.
`test_data`: The test dataset, i.e. the part of the source dataset after preprocessing, balancing and splitting into the training, test and validation datasets.
`valid_data`: The validation dataset, i.e. the part of the source dataset after preprocessing, balancing and splitting into the training, test and validation datasets.
`predictions`: Prediction list for all trained models based on the training dataset.
`score_test`: The list of metrics for all trained models calculated on the test dataset. For the regression task there are: mse, r2 and mad metrics. For the classification task there are: f1, auc, recall, precision and accuracy.
`score_train`: The list of metrics for all trained models calculated on the train dataset. For the regression task there are: mse, r2 and mad metrics.
`score_valid`: The list of metrics for all trained models calculated on the validation dataset. For the regression task there are: mse, r2 and mad metrics.
`models_list`: The list of all trained models.
`data`: The original data.
`y`: The original target column name.
`test_observed`: Values of the y column from the test dataset.
`train_observed`: Values of the y column from the training dataset.
`valid_observed`: Values of the y column from the validation dataset.
`test_observed_labels`: Values of the y column from the test dataset as text labels (for classification task only).
`train_observed_labels`: Values of the y column from the training dataset as text labels (for classification task only).
`valid_observed_labels`: Values of the y column from the validation dataset as text labels (for classification task only).
`best_models`: Ranking list of the top 10 trained models: with default parameters, with parameters optimized with the Bayesian optimization algorithm and with parameters optimized with the random search algorithm.
`engine`: The list of names of all types of trained models. Possible values: 'ranger', 'xgboost', 'decision_tree', 'lightgbm', 'catboost'.
`predictions_all`: Predictions for all trained models on the test dataset.
`predictions_best`: Predictions on the test dataset for the models from the best_models list.
`predictions_all_labels`: Predictions for all trained models on the test dataset as text labels (for classification task only).
`predictions_best_labels`: Predictions on the test dataset for the models from the best_models list as text labels (for classification task only).
`predictions_train`: Predictions for all trained models on the train dataset.
`raw_train`: Another form of the training dataset (useful for creating the VS plot and for predicting on the training dataset for catboost and lightgbm models).
`check_report`: Data check report held as a list of strings. It is used by the `report()` function.
`outliers`: The vector of possible outliers detected by `check_data()`.
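To make the roles of `train_data`, `test_data` and `valid_data` more concrete, here is a minimal base-R sketch of a random 60/20/20 split that mirrors the default `train_test_split`. It is an illustration only and not forester's implementation, which also balances the data. The helper name `split_three_way` and its `seed` argument are hypothetical, and the built-in `iris` data stands in for a real dataset.

```r
# Hypothetical helper: randomly split a data frame into three parts using
# proportions in the documented order (train, test, validation).
# NOT forester's implementation, which additionally balances the data.
split_three_way <- function(data, proportions = c(0.6, 0.2, 0.2), seed = 123) {
  stopifnot(length(proportions) == 3, abs(sum(proportions) - 1) < 1e-8)
  set.seed(seed)
  n         <- nrow(data)
  shuffled  <- sample(n)                      # random permutation of row indices
  n_train   <- round(proportions[1] * n)
  n_test    <- round(proportions[2] * n)
  train_idx <- shuffled[seq_len(n_train)]
  test_idx  <- shuffled[n_train + seq_len(n_test)]
  valid_idx <- setdiff(shuffled, c(train_idx, test_idx))  # whatever is left
  list(
    train = data[train_idx, , drop = FALSE],
    test  = data[test_idx, , drop = FALSE],
    valid = data[valid_idx, , drop = FALSE]
  )
}

parts <- split_three_way(iris)
part_sizes <- sapply(parts, nrow)
print(part_sizes)
```

Assigning the remainder to the validation part keeps every row in exactly one subset even when the proportions do not divide the row count evenly.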
library(forester)
data('lisbon')
train_output <- train(lisbon, 'Price')
#> ✔ Type guessed as: regression
#>
#> -------------------- CHECK DATA REPORT --------------------
#>
#> The dataset has 246 observations and 17 columns, which names are:
#> Id; Condition; PropertyType; PropertySubType; Bedrooms; Bathrooms; AreaNet; AreaGross; Parking; Latitude; Longitude; Country; District; Municipality; Parish; Price.M2; Price;
#>
#> With the target value described by a column Price.
#>
#> ✖ Static columns are:
#> Country; District; Municipality;
#>
#> ✖ With dominating values:
#> Portugal; Lisboa; Lisboa;
#>
#> ✖ These column pairs are duplicate:
#> District - Municipality;
#>
#> ✔ No target values are missing.
#>
#> ✔ No predictor values are missing.
#>
#> ✔ No issues with dimensionality.
#>
#> ✖ Strongly correlated, by Spearman rank, pairs of numerical values are:
#>
#> Bedrooms - AreaNet: 0.77;
#> Bedrooms - AreaGross: 0.77;
#> Bathrooms - AreaNet: 0.78;
#> Bathrooms - AreaGross: 0.78;
#> AreaNet - AreaGross: 1;
#>
#> ✖ Strongly correlated, by Crammer's V rank, pairs of categorical values are:
#> PropertyType - PropertySubType: 1;
#>
#> ✖ These obserwation migth be outliers due to their numerical columns values:
#> 145 146 196 44 5 51 57 58 59 60 61 62 63 64 69 75 76 77 78 ;
#>
#> ✖ Target data is not evenly distributed with quantile bins: 0.25 0.35 0.14 0.26
#>
#> ✖ Columns names suggest that some of them are IDs, removing them can improve the model.
#> Suspicious columns are: Id .
#>
#> ✖ Columns data suggest that some of them are IDs, removing them can improve the model.
#> Suspicious columns are: Id .
#>
#> -------------------- CHECK DATA REPORT END --------------------
#>
#> ✔ Data preprocessed.
#> ✔ Data split and balanced.
#> ✔ Correct formats prepared.
#> ✔ Models successfully trained.
#> ✔ Predicted successfully.
train_output$score_valid
#> no. name engine tuning mse r2 mae
#> 1 2 xgboost_model xgboost basic 143685453859 0.60640199 135871.5
#> 2 46 xgboost_bayes xgboost bayes_opt 144796997342 0.60335713 138757.1
#> 3 47 decision_tree_bayes decision_tree bayes_opt 167134585019 0.54216771 151383.7
#> 4 11 ranger_RS_7 ranger random_search 172355600654 0.52786576 146038.0
#> 5 12 ranger_RS_8 ranger random_search 175206764634 0.52005556 140453.4
#> 6 1 ranger_model ranger basic 176861131326 0.51552375 140467.5
#> 7 45 ranger_bayes ranger bayes_opt 177100208139 0.51486884 142764.0
#> 8 10 ranger_RS_6 ranger random_search 177696042483 0.51323667 150902.5
#> 9 8 ranger_RS_4 ranger random_search 178173236730 0.51192949 147778.7
#> 10 6 ranger_RS_2 ranger random_search 180738901681 0.50490136 143104.2
#> 11 4 lightgbm_model lightgbm basic 198600114778 0.45597408 163623.3
#> 12 43 lightgbm_RS_9 lightgbm random_search 198600114778 0.45597408 163623.3
#> 13 38 lightgbm_RS_4 lightgbm random_search 201251798549 0.44871032 169723.7
#> 14 39 lightgbm_RS_5 lightgbm random_search 201251798549 0.44871032 169723.7
#> 15 3 decision_tree_model decision_tree basic 201497538451 0.44803716 176132.1
#> 16 25 decision_tree_RS_1 decision_tree random_search 201497538451 0.44803716 176132.1
#> 17 26 decision_tree_RS_2 decision_tree random_search 201497538451 0.44803716 176132.1
#> 18 27 decision_tree_RS_3 decision_tree random_search 201497538451 0.44803716 176132.1
#> 19 28 decision_tree_RS_4 decision_tree random_search 201497538451 0.44803716 176132.1
#> 20 29 decision_tree_RS_5 decision_tree random_search 201497538451 0.44803716 176132.1
#> 21 30 decision_tree_RS_6 decision_tree random_search 201497538451 0.44803716 176132.1
#> 22 31 decision_tree_RS_7 decision_tree random_search 201497538451 0.44803716 176132.1
#> 23 32 decision_tree_RS_8 decision_tree random_search 201497538451 0.44803716 176132.1
#> 24 33 decision_tree_RS_9 decision_tree random_search 201497538451 0.44803716 176132.1
#> 25 34 decision_tree_RS_10 decision_tree random_search 201497538451 0.44803716 176132.1
#> 26 48 lightgbm_bayes lightgbm bayes_opt 201568659750 0.44784234 162111.1
#> 27 35 lightgbm_RS_1 lightgbm random_search 202121875717 0.44632691 170545.3
#> 28 14 ranger_RS_10 ranger random_search 213760186314 0.41444605 187644.4
#> 29 42 lightgbm_RS_8 lightgbm random_search 214805128850 0.41158364 163852.3
#> 30 44 lightgbm_RS_10 lightgbm random_search 214805128850 0.41158364 163852.3
#> 31 37 lightgbm_RS_3 lightgbm random_search 215251555308 0.41036074 165520.9
#> 32 41 lightgbm_RS_7 lightgbm random_search 215251555308 0.41036074 165520.9
#> 33 5 ranger_RS_1 ranger random_search 216042730457 0.40819347 190711.9
#> 34 9 ranger_RS_5 ranger random_search 221397579341 0.39352492 193152.2
#> 35 7 ranger_RS_3 ranger random_search 226399456894 0.37982326 187923.1
#> 36 40 lightgbm_RS_6 lightgbm random_search 228517310575 0.37402181 175919.8
#> 37 13 ranger_RS_9 ranger random_search 229673753119 0.37085397 195580.2
#> 38 15 xgboost_RS_1 xgboost random_search 251159464876 0.31199809 200641.5
#> 39 22 xgboost_RS_8 xgboost random_search 251159464876 0.31199809 200641.5
#> 40 16 xgboost_RS_2 xgboost random_search 251529223453 0.31098521 202108.0
#> 41 24 xgboost_RS_10 xgboost random_search 251529223453 0.31098521 202108.0
#> 42 20 xgboost_RS_6 xgboost random_search 256059009039 0.29857675 202524.9
#> 43 36 lightgbm_RS_2 lightgbm random_search 257857650556 0.29364973 211581.6
#> 44 17 xgboost_RS_3 xgboost random_search 346948222547 0.04960364 315084.6
#> 45 21 xgboost_RS_7 xgboost random_search 346948222547 0.04960364 315084.6
#> 46 19 xgboost_RS_5 xgboost random_search 349598730125 0.04234310 313834.7
#> 47 23 xgboost_RS_9 xgboost random_search 349598730125 0.04234310 313834.7
#> 48 18 xgboost_RS_4 xgboost random_search 556458657778 -0.52430895 486418.6
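Because `score_valid` prints like an ordinary data frame, the ranking can be inspected with base R. The sketch below retypes the first three rows of the table above by hand so that it runs without training anything. Treating the real `train_output$score_valid` as a data frame with these column names is an assumption based on the printed output, and the added `rmse` column is not part of that output.

```r
# Stand-in for train_output$score_valid: the first three rows printed above,
# typed in by hand. Assumption: the real object is a data frame with these columns.
score_valid <- data.frame(
  name   = c("xgboost_model", "xgboost_bayes", "decision_tree_bayes"),
  engine = c("xgboost", "xgboost", "decision_tree"),
  tuning = c("basic", "bayes_opt", "bayes_opt"),
  mse    = c(143685453859, 144796997342, 167134585019),
  r2     = c(0.60640199, 0.60335713, 0.54216771),
  mae    = c(135871.5, 138757.1, 151383.7),
  stringsAsFactors = FALSE
)

# Lower mse is better, so the best model is the row with the smallest mse.
best <- score_valid[which.min(score_valid$mse), ]
print(best$name)  # "xgboost_model"

# The square root of mse is in the units of the target, which makes
# the size of the error easier to judge.
score_valid$rmse <- sqrt(score_valid$mse)
print(score_valid[order(score_valid$mse), c("name", "tuning", "rmse", "r2")])
```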