# Rendering parameters written out by knitr::purl(); presumably mirrors the
# `params` field of the source .Rmd YAML header — confirm there.
params <- list(eval = TRUE)

## ----include = FALSE----------------------------------------------------------
# Show code and its output in one block, with output lines prefixed by "#>".
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(LBBNN)

# Later chunks use `eval = has_torch`, so they only run when the torch
# package is installed AND its backend reports itself as installed.
# `&&` short-circuits, so torch:: is never touched if the package is absent.
has_torch <- requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

## ----eval = has_torch---------------------------------------------------------
# Simulate a nonlinear binary-classification data set:
# `i` observations of `j` covariates drawn uniformly from (0, 0.5), where
# only the first six covariates enter the latent response.
i <- 1000
j <- 15
set.seed(42)
torch::torch_manual_seed(42)
X_nl <- matrix(runif(i * j, 0, 0.5), ncol = j)
y_nl <- -3 + 0.1 * log(abs(X_nl[, 1])) + 3 * cos(X_nl[, 2]) +
  2 * X_nl[, 3] * X_nl[, 4] + X_nl[, 5] -
  X_nl[, 6]^2 + rnorm(i, sd = 0.1)

# Dichotomize the latent response at its median, so the two classes are
# (essentially) balanced. The labels are numeric 0/1, built in one
# vectorized step instead of growing an empty vector by subset assignment.
y <- as.numeric(y_nl > median(y_nl))

# Covariates become columns V1..Vj; the label is appended as column `y`.
# NOTE(review): assumes get_dataloaders() takes the response from the last
# column — confirm against ?get_dataloaders.
sim_data_nl <- cbind(as.data.frame(X_nl), y)

# Presumably a 90/10 train/test split (about 900/100 rows); the covariates
# are already on a small fixed range, so they are not standardized.
loaders_nl <- get_dataloaders(
  sim_data_nl,
  train_proportion = 0.9,
  train_batch_size = 450,
  test_batch_size = 100,
  standardize = FALSE
)
train_loader_nl <- loaders_nl$train_loader
test_loader_nl <- loaders_nl$test_loader

## ----eval = has_torch---------------------------------------------------------
# Network configuration.
# NOTE(review): `sizes` presumably gives the layer widths (j inputs, two
# hidden layers of 5 units, 1 output), and the length-3 `incl_priors` and
# `stds` vectors one value per weight layer — confirm against ?lbbnn_net.
problem <- "binary classification"
sizes <- c(j, 5, 5, 1)
incl_priors <- rep(0.5, 3)
stds <- rep(1, 3)
incl_inits <- "polarized"
device <- "cpu"

model_nl <- lbbnn_net(
  problem_type = problem,
  sizes = sizes,
  prior = incl_priors,
  inclusion_inits = incl_inits,
  input_skip = TRUE,
  std = stds,
  flow = TRUE,
  dims = c(10, 10, 10),
  device = device,
  bias_inclusion_prob = FALSE
)

## ----eval = has_torch---------------------------------------------------------
# Fit the model for 20 epochs with learning rate 0.2 and no progress output.
train_lbbnn(
  LBBNN = model_nl,
  train_dl = train_loader_nl,
  epochs = 20,
  lr = 0.2,
  device = device,
  verbose = FALSE
)

# Evaluate on the held-out loader; `num_samples` presumably sets how many
# stochastic forward passes are averaged — confirm in ?validate_lbbnn.
validate_lbbnn(
  LBBNN = model_nl,
  test_dl = test_loader_nl,
  num_samples = 2,
  device = device
)

## ----fig.width=6, fig.height=6, eval = has_torch------------------------------
# Plot the fitted network using LBBNN's plot method with the "global" view;
# vertex size, edge width and label size are cosmetic tuning values.
plot(
  model_nl,
  type = "global",
  vertex_size = 7,
  edge_width = 0.4,
  label_size = 0.4
)

