
Generalized Neural Network Trainer
Source: R/generalized-nn-fit.R, R/generalized-nn-fitds.R
train_nn() is a generic function for training neural networks with a
user-defined architecture via nn_arch(). Dispatch is based on the class
of x.
Recommended workflow:
- Define architecture with nn_arch() (optional).
- Train with train_nn().
- Predict with predict.nn_fit().
All methods delegate to a shared implementation core after preprocessing.
When architecture = NULL, the model falls back to a plain feed-forward neural network
(nn_linear) architecture.
Usage
train_nn(x, ...)
# S3 method for class 'matrix'
train_nn(
  x,
  y,
  hidden_neurons = NULL,
  activations = NULL,
  output_activation = NULL,
  bias = TRUE,
  arch = NULL,
  architecture = NULL,
  early_stopping = NULL,
  epochs = 100,
  batch_size = 32,
  penalty = 0,
  mixture = 0,
  learn_rate = 0.001,
  optimizer = "adam",
  optimizer_args = list(),
  loss = "mse",
  validation_split = 0,
  device = NULL,
  verbose = FALSE,
  cache_weights = FALSE,
  ...
)
# S3 method for class 'data.frame'
train_nn(
  x,
  y,
  hidden_neurons = NULL,
  activations = NULL,
  output_activation = NULL,
  bias = TRUE,
  arch = NULL,
  architecture = NULL,
  early_stopping = NULL,
  epochs = 100,
  batch_size = 32,
  penalty = 0,
  mixture = 0,
  learn_rate = 0.001,
  optimizer = "adam",
  optimizer_args = list(),
  loss = "mse",
  validation_split = 0,
  device = NULL,
  verbose = FALSE,
  cache_weights = FALSE,
  ...
)
# S3 method for class 'formula'
train_nn(
  x,
  data,
  hidden_neurons = NULL,
  activations = NULL,
  output_activation = NULL,
  bias = TRUE,
  arch = NULL,
  architecture = NULL,
  early_stopping = NULL,
  epochs = 100,
  batch_size = 32,
  penalty = 0,
  mixture = 0,
  learn_rate = 0.001,
  optimizer = "adam",
  optimizer_args = list(),
  loss = "mse",
  validation_split = 0,
  device = NULL,
  verbose = FALSE,
  cache_weights = FALSE,
  ...
)
# Default S3 method
train_nn(x, ...)
# S3 method for class 'dataset'
train_nn(
  x,
  y = NULL,
  hidden_neurons = NULL,
  activations = NULL,
  output_activation = NULL,
  bias = TRUE,
  arch = NULL,
  architecture = NULL,
  flatten_input = NULL,
  epochs = 100,
  batch_size = 32,
  penalty = 0,
  mixture = 0,
  learn_rate = 0.001,
  optimizer = "adam",
  optimizer_args = list(),
  loss = "mse",
  validation_split = 0,
  device = NULL,
  verbose = FALSE,
  cache_weights = FALSE,
  n_classes = NULL,
  ...
)
Arguments
- x
Dispatch is based on its current class:
- matrix: used directly; no preprocessing applied.
- data.frame: preprocessed via hardhat::mold(). y may be a vector / factor / matrix of outcomes, or a formula describing the outcome–predictor relationship within x.
- formula: combined with data and preprocessed via hardhat::mold().
- dataset: a torch dataset object; batched via torch::dataloader(). This is the recommended interface for sequence/time-series and image data.
- ...
Additional arguments passed to specific methods.
- y
Outcome data. Interpretation depends on the method:
- For the matrix and data.frame methods: a numeric vector, factor, or matrix of outcomes.
- For the data.frame method only: alternatively a formula of the form outcome ~ predictors, evaluated against x.
- Ignored when x is a formula (the outcome is taken from the formula) or a dataset (labels come from the dataset itself).
- hidden_neurons
Integer vector specifying the number of neurons in each hidden layer, e.g. c(128, 64) for two hidden layers. When NULL or missing, no hidden layers are used and the model reduces to a single linear mapping from inputs to outputs.
- activations
Activation function specification(s) for the hidden layers. See act_funs() for supported values. Recycled if a single value is given.
- output_activation
Optional activation function for the output layer. Defaults to NULL (no activation / linear output).
- bias
Logical. Whether to include bias terms in each layer. Default TRUE.
- arch
Backward-compatible alias for architecture. If both are supplied, they must be identical.
- architecture
An nn_arch() object specifying a custom architecture. Default NULL, which falls back to a standard feed-forward network.
- early_stopping
An early_stop() object specifying early stopping behaviour, or NULL (default) to disable early stopping. When supplied, training halts if the monitored metric does not improve by at least min_delta for patience consecutive epochs. Example: early_stopping = early_stop(patience = 10).
- epochs
Positive integer. Number of full passes over the training data. Default 100.
- batch_size
Positive integer. Number of samples per mini-batch. Default 32.
- penalty
Non-negative numeric. L1/L2 regularization strength (lambda). Default 0 (no regularization).
- mixture
Numeric in [0, 1]. Elastic net mixing parameter: 0 = pure ridge (L2), 1 = pure lasso (L1). Default 0.
- learn_rate
Positive numeric. Step size for the optimizer. Default 0.001.
- optimizer
Character. Optimizer algorithm. One of "adam" (default), "sgd", or "rmsprop".
- optimizer_args
Named list of additional arguments forwarded to the optimizer constructor (e.g. list(momentum = 0.9) for SGD). Default list().
- loss
Character or function. Loss function used during training. Built-in options: "mse" (default), "mae", "cross_entropy", or "bce". For classification tasks with a scalar label, "cross_entropy" is set automatically. Alternatively, supply a custom function or formula with signature function(input, target) returning a scalar torch_tensor.
- validation_split
Numeric in [0, 1). Proportion of training data held out for validation. Default 0 (no validation set).
- device
Character. Compute device: "cpu", "cuda", or "mps". Default NULL, which auto-detects the best available device.
- verbose
Logical. If TRUE, prints loss (and validation loss) at regular intervals during training. Default FALSE.
- cache_weights
Logical. If TRUE, stores a copy of the trained weight matrices in the returned object under $cached_weights. Default FALSE.
- data
A data frame. Required when x is a formula.
- flatten_input
Logical or NULL (dataset method only). Controls whether each batch/sample is flattened to 2D before entering the model. NULL (default) auto-selects: TRUE when architecture = NULL, otherwise FALSE.
- n_classes
Positive integer. Number of output classes. Required when x is a dataset with scalar (classification) labels; ignored otherwise.
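As an illustrative sketch of how several of these arguments compose (it uses only the arguments documented above, but the exact numeric results depend on the installed package; torch::nnf_smooth_l1_loss() is one real torch function with the required function(input, target) signature):

```r
# Sketch: elastic-net regularization plus a custom loss and optimizer
# arguments, combined in a single call. Assumes torch is installed and
# the package providing train_nn() is attached.
huber_ish = function(input, target) {
  # A custom loss must return a scalar torch_tensor.
  torch::nnf_smooth_l1_loss(input, target)
}

model = train_nn(
  x = Sepal.Length ~ .,
  data = iris[, 1:4],
  hidden_neurons = c(32, 16),
  activations = "relu",
  penalty = 0.01,                          # lambda
  mixture = 0.5,                           # 50/50 ridge-lasso blend
  loss = huber_ish,                        # custom function(input, target)
  optimizer = "sgd",
  optimizer_args = list(momentum = 0.9),   # forwarded to the constructor
  epochs = 50
)
```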
Value
An object of class "nn_fit", or one of its subclasses:
- c("nn_fit_tab", "nn_fit") — returned by the data.frame and formula methods
- c("nn_fit_ds", "nn_fit") — returned by the dataset method
All subclasses share a common structure. See Details for the list of components.
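A quick way to see which subclass a fit carries is to check its class vector; a minimal sketch (assuming torch and the package providing train_nn() are available):

```r
# Sketch: the data.frame and formula methods return the tabular subclass.
fit_tab = train_nn(
  x = Sepal.Length ~ .,
  data = iris[, 1:4],
  epochs = 5
)
class(fit_tab)   # expected per the Value section: c("nn_fit_tab", "nn_fit")
```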
Details
The returned "nn_fit" object is a named list with the following components:
- model — the trained torch::nn_module object
- fitted — fitted values on the training data (or NULL for dataset fits)
- loss_history — numeric vector of per-epoch training loss, trimmed to the epochs actually run (relevant when early stopping is active)
- val_loss_history — per-epoch validation loss, or NULL if validation_split = 0
- n_epochs — number of epochs actually trained
- stopped_epoch — epoch at which early stopping triggered, or NA if training ran to completion
- hidden_neurons, activations, output_activation — architecture spec
- penalty, mixture — regularization settings
- feature_names, response_name — variable names (tabular methods only)
- no_x, no_y — number of input features and output nodes
- is_classification — logical flag
- y_levels, n_classes — class labels and count (classification only)
- device — device the model is on
- cached_weights — list of weight matrices, or NULL
- arch — the nn_arch object used, or NULL
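Since the fit is a plain named list, these components can be read with $ after training; a sketch (assuming a successful fit as in the Examples):

```r
# Sketch: inspecting the components listed above after training.
fit = train_nn(
  x = Sepal.Length ~ .,
  data = iris[, 1:4],
  hidden_neurons = c(16),
  epochs = 30,
  validation_split = 0.2
)
fit$n_epochs                 # epochs actually trained
tail(fit$loss_history, 1)    # final training loss
plot(fit$loss_history, type = "l",
     xlab = "Epoch", ylab = "Training loss")
```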
Supported tasks and input formats
train_nn() is task-agnostic by design (no explicit task argument).
Task behavior is determined by your input interface and architecture:
- Tabular data: use the matrix, data.frame, or formula methods.
- Time series: use the dataset method with per-item tensors shaped as [time, features] (or your preferred convention) and a recurrent architecture via nn_arch().
- Image classification: use the dataset method with per-item tensors shaped for your first layer (commonly [channels, height, width] for torch::nn_conv2d). If your source arrays are channel-last, reorder in the dataset or via input_transform.
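The channel reordering mentioned above can be done with torch's standard tensor API; this sketch uses only torch_tensor() and $permute() (1-indexed dimensions in torch for R), independent of train_nn():

```r
# Sketch: convert a channel-last array [H, W, C] to the [C, H, W]
# layout expected by torch::nn_conv2d, e.g. inside a dataset's .getitem().
img = array(runif(28 * 28 * 3), dim = c(28, 28, 3))      # [H, W, C]
x = torch::torch_tensor(img, dtype = torch::torch_float32())
x_chw = x$permute(c(3, 1, 2))                             # -> [C, H, W]
x_chw$shape                                               # 3 28 28
```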
Matrix method
When x is supplied as a raw numeric matrix, no preprocessing is applied.
Data is passed directly to the shared train_nn_impl core.
Data frame method
When x is a data frame, y can be either a vector / factor / matrix of
outcomes, or a formula of the form outcome ~ predictors evaluated against
x. Preprocessing is handled by hardhat::mold().
Formula method
When x is a formula, data must be supplied as the data frame against
which the formula is evaluated. Preprocessing is handled by hardhat::mold().
Dataset method (train_nn.dataset())
Trains a neural network directly on a torch dataset object. Batching and
lazy loading are handled by torch::dataloader(), making this method
well-suited for large datasets that do not fit entirely in memory.
Architecture configuration follows the same contract as other train_nn()
methods via architecture = nn_arch(...) (or legacy arch = ...).
For non-tabular inputs (time series, images), set flatten_input = FALSE to
preserve tensor dimensions expected by recurrent or convolutional layers.
Labels are taken from the second element of each dataset item (i.e.
dataset[[i]][[2]]), so y is ignored. When the label is a scalar tensor,
a classification task is assumed and n_classes must be supplied. The loss
is automatically switched to "cross_entropy" in that case.
Fitted values are not cached in the returned object. Use
predict.nn_fit_ds() with newdata to obtain predictions after training.
Examples
# \donttest{
if (torch::torch_is_installed()) {
  # Matrix method — no preprocessing
  model = train_nn(
    x = as.matrix(iris[, 2:4]),
    y = iris$Sepal.Length,
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 50
  )

  # Data frame method — y as a vector
  model = train_nn(
    x = iris[, 2:4],
    y = iris$Sepal.Length,
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 50
  )

  # Data frame method — y as a formula evaluated against x
  model = train_nn(
    x = iris,
    y = Sepal.Length ~ . - Species,
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 50
  )

  # Formula method — outcome derived from formula
  model = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 50
  )

  # No hidden layers — linear model
  model = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    epochs = 50
  )

  # Architecture object (nn_arch -> train_nn)
  mlp_arch = nn_arch(nn_name = "mlp_model")
  model = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(64, 32),
    activations = "relu",
    architecture = mlp_arch,
    epochs = 50
  )

  # Custom layer architecture
  custom_linear = torch::nn_module(
    "CustomLinear",
    initialize = function(in_features, out_features, bias = TRUE) {
      self$layer = torch::nn_linear(in_features, out_features, bias = bias)
    },
    forward = function(x) self$layer(x)
  )
  custom_arch = nn_arch(
    nn_name = "custom_linear_mlp",
    nn_layer = ~ custom_linear
  )
  model = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(16, 8),
    activations = "relu",
    architecture = custom_arch,
    epochs = 50
  )

  # With early stopping
  model = train_nn(
    x = Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 200,
    validation_split = 0.2,
    early_stopping = early_stop(patience = 10)
  )
}
# }
# \donttest{
if (torch::torch_is_installed()) {
  # torch dataset method — labels come from the dataset itself
  iris_cls_dataset = torch::dataset(
    name = "iris_cls_dataset",
    initialize = function(data = iris) {
      self$x = torch::torch_tensor(
        as.matrix(data[, 1:4]),
        dtype = torch::torch_float32()
      )
      # Species is a factor; convert to 1-indexed integers, as cross_entropy expects
      self$y = torch::torch_tensor(
        as.integer(data$Species),
        dtype = torch::torch_long()
      )
    },
    .getitem = function(i) {
      list(self$x[i, ], self$y[i])
    },
    .length = function() {
      self$x$size(1)
    }
  )()

  model_nn_ds = train_nn(
    x = iris_cls_dataset,
    hidden_neurons = c(32, 10),
    activations = "relu",
    epochs = 80,
    batch_size = 16,
    learn_rate = 0.01,
    n_classes = 3,  # iris has three species
    validation_split = 0.2,
    verbose = TRUE
  )

  pred_nn = predict(model_nn_ds, iris_cls_dataset)
  class_preds = c("Setosa", "Versicolor", "Virginica")[pred_nn]

  # Confusion matrix
  table(actual = iris$Species, pred = class_preds)
}
#> → Auto-detected classification task. Using cross_entropy loss.
#> ℹ Using device: cpu
#> Epoch 8/80 - Loss: 0.1559 - Val Loss: 0.2030
#> Epoch 16/80 - Loss: 0.0720 - Val Loss: 0.0962
#> Epoch 24/80 - Loss: 0.0572 - Val Loss: 0.1113
#> Epoch 32/80 - Loss: 0.0717 - Val Loss: 0.1640
#> Epoch 40/80 - Loss: 0.0465 - Val Loss: 0.1896
#> Epoch 48/80 - Loss: 0.0502 - Val Loss: 0.1241
#> Epoch 56/80 - Loss: 0.0690 - Val Loss: 0.0907
#> Epoch 64/80 - Loss: 0.0582 - Val Loss: 0.1276
#> Epoch 72/80 - Loss: 0.1211 - Val Loss: 0.1251
#> Epoch 80/80 - Loss: 0.0884 - Val Loss: 0.0856
#> pred
#> actual Setosa Versicolor Virginica
#> setosa 50 0 0
#> versicolor 0 49 1
#> virginica 0 2 48
# }