Makes probabilistic predictions with a neural network trained for a classification task.

make_preds_prob(model, test_ds, dev)

Arguments

model

a torch neural network classification model

test_ds

a torch dataset object holding the test data on which predictions are made

dev

device used for calculations, e.g. "cpu", or "cuda" for a GPU

Value

a numeric vector of predicted probabilities (values between 0 and 1)
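To illustrate what a vector of probabilistic predictions looks like, here is a minimal base-R sketch. It assumes (this page does not confirm it) that the classifier yields two scores per observation, that these are converted to class probabilities with a softmax, and that the returned vector holds the probability of the second class. The `softmax_rows` helper and the sample scores are hypothetical and are not part of fairpan.

```r
# Hypothetical helper (not part of fairpan): row-wise softmax over class scores
softmax_rows <- function(scores) {
  # Subtract each row's maximum for numerical stability before exponentiating
  shifted <- scores - apply(scores, 1, max)
  e <- exp(shifted)
  # Normalise so that every row sums to 1
  e / rowSums(e)
}

# Hypothetical raw scores (logits) for 3 observations and 2 classes
logits <- matrix(c( 2.0, 0.5,
                    0.0, 0.0,
                   -1.0, 3.0),
                 ncol = 2, byrow = TRUE)

probs <- softmax_rows(logits)

# Assumption: the prediction vector is the probability of the second class
pred_prob <- probs[, 2]
print(round(pred_prob, 3))
```

Each element of `pred_prob` lies between 0 and 1; thresholding it (for example at 0.5) would turn it into hard class labels.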

Examples

if (FALSE) {
  library(torch)
  library(fairpan)

  dev <- "cpu"

  # presaved torch model
  model1 <- torch_load(system.file("extdata", "preclf", package = "fairpan"))

  # presaved output of the preprocess function
  processed <- torch_load(system.file("extdata", "processed", package = "fairpan"))

  dsl <- dataset_loader(processed$train_x, processed$train_y,
                        processed$test_x, processed$test_y,
                        batch_size = 5, dev = dev)

  preds1 <- make_preds_prob(model1, dsl$test_ds, dev)
}