Makes binary predictions with a neural network trained for a classification task.
make_preds(model, test_ds, dev)
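| Argument | Description |
|---|---|
| model | net, nn_module, neural network classification model |
| test_ds | test dataset, e.g. the `test_ds` element returned by `dataset_loader()` |
| dev | device used for calculations (cpu or gpu) |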
integer (binary) vector of predictions
```r
if (FALSE) {
  dev <- "cpu"

  # presaved torch model
  model1 <- torch_load(system.file("extdata", "preclf", package = "fairpan"))

  # presaved output of preprocess function
  processed <- torch_load(system.file("extdata", "processed", package = "fairpan"))

  dsl <- dataset_loader(processed$train_x, processed$train_y,
                        processed$test_x, processed$test_y,
                        batch_size = 5, dev = dev)

  preds1 <- make_preds(model1, dsl$test_ds, dev)
}
```