
Predict from a trained neural network
Source: R/generalized-nn-fit.R, R/generalized-nn-fitds.R
gen-nn-predict.Rd
Generate predictions from an "nn_fit" object produced by train_nn().
Three S3 methods are registered:
- predict.nn_fit(): base method for matrix-trained models.
- predict.nn_fit_tab(): extends the base method for tabular fits; runs new data through hardhat::forge() before predicting.
- predict.nn_fit_ds(): extends the base method for fits trained on a torch dataset.
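The "extends the base method" relationship can be pictured as ordinary S3 inheritance: a subclass method reshapes its input and then hands off to the parent method with NextMethod(). The sketch below is only an illustration of that dispatch pattern. The classes toy_fit and toy_fit_tab, their fields, and the linear "model" are hypothetical stand-ins, and the column selection merely imitates what hardhat::forge() does; none of this is the package's actual code.

```r
# Hypothetical toy classes that mimic the dispatch chain; not the package's code.

# Base method: expects a numeric matrix and applies a linear layer.
predict.toy_fit <- function(object, newdata = NULL, ...) {
  x <- as.matrix(newdata)
  drop(x %*% object$weights)
}

# Tabular method: reorders/selects the training columns (a crude stand-in
# for hardhat::forge()), then defers to the base method.
predict.toy_fit_tab <- function(object, newdata = NULL, ...) {
  newdata <- as.matrix(newdata[, object$predictors, drop = FALSE])
  NextMethod()  # the modified `newdata` is forwarded to predict.toy_fit()
}

fit <- structure(
  list(weights = c(2, -1), predictors = c("a", "b")),
  class = c("toy_fit_tab", "toy_fit")
)

# Columns arrive in a different order and with an extra column.
df <- data.frame(b = c(1, 0), a = c(3, 5), extra = c(9, 9))
predict(fit, newdata = df)  # 2 * a - 1 * b per row: 5 and 10
```

Because the preprocessing lives in the subclass method, the base method never needs to know whether its input started as a data frame.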
Usage
# S3 method for class 'nn_fit'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
# S3 method for class 'nn_fit_tab'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
# S3 method for class 'nn_fit_ds'
predict(object, newdata = NULL, new_data = NULL, type = "response", ...)
Arguments
- object
A fitted model object returned by train_nn().
- newdata
New predictor data. Accepted forms depend on the method:
  - predict.nn_fit(): a numeric matrix or an object coercible to one.
  - predict.nn_fit_tab(): a data.frame with the same columns used during training; preprocessing is applied automatically via hardhat::forge().
  - predict.nn_fit_ds(): a torch dataset, numeric array, matrix, or data.frame. If NULL, the cached fitted values from training are returned (not available for type = "prob").
- new_data
Legacy alias for newdata, retained for backward compatibility.
- type
Character. Output type:
  - "response" (default): predicted class labels (a factor) for classification, or a numeric vector or matrix for regression.
  - "prob": a numeric matrix of class probabilities (classification only).
- ...
Currently unused; reserved for future extensions.
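How newdata, its legacy alias new_data, and type interact can be sketched as a small argument-resolution helper. The helper name resolve_predict_args(), the fitted_values field, and the choice to raise an error for type = "prob" without new data are all assumptions made for illustration. The page only says cached fitted values are "not available" for probabilities and does not describe the package's internals.

```r
# Hypothetical helper: one way a predict method might reconcile the
# `newdata` / `new_data` alias and validate `type`. Not the package's code.
resolve_predict_args <- function(object, newdata = NULL, new_data = NULL,
                                 type = c("response", "prob")) {
  type <- match.arg(type)  # rejects anything other than the two documented types
  # Prefer `newdata`; fall back to the legacy alias `new_data`.
  if (is.null(newdata)) newdata <- new_data
  if (is.null(newdata)) {
    # No new data: reuse cached fitted values, which hold no probabilities.
    if (type == "prob") {
      stop("`type = \"prob\"` requires new data.", call. = FALSE)
    }
    return(list(use_cache = TRUE, data = object$fitted_values, type = type))
  }
  list(use_cache = FALSE, data = newdata, type = type)
}

# Usage with a hypothetical cached fit:
fit <- list(fitted_values = c(1.5, 2.5))
resolve_predict_args(fit)$use_cache                                  # TRUE
resolve_predict_args(fit, new_data = matrix(1:4, nrow = 2))$use_cache  # FALSE
```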
Value
Regression: a numeric vector (single output) or matrix (multiple outputs).
Classification, type = "response": a factor with levels matching those seen during training.
Classification, type = "prob": a numeric matrix with one column per class, columns named by class label.
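The two classification outputs are plausibly linked by taking the most probable class per row. The sketch below assumes that argmax rule and assumes the probability columns are ordered like the training levels; the page documents neither. The class names and probabilities are made up for illustration.

```r
# Made-up probability matrix in the documented shape: one row per
# observation, one column per class, columns named by class label.
prob <- matrix(
  c(0.7, 0.2, 0.1,
    0.1, 0.3, 0.6),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("low", "mid", "high"))
)

# Assumption: the "response" label is the column with the highest probability.
# Passing all column names as levels keeps every training class as a level,
# even classes that are never predicted (here, "mid").
labels <- factor(colnames(prob)[max.col(prob, ties.method = "first")],
                 levels = colnames(prob))
labels
```

Keeping the full level set matters downstream: confusion matrices and metrics stay aligned with the training classes even when a class is absent from the predictions.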