This class defines a custom base learner factory by
passing R functions for instantiation, fitting, predicting, and parameter extraction.
Format
S4 object.
Arguments
- data_source (InMemoryData)
  Uninitialized data object used to store the meta data. Note: At the moment, only in-memory storage is supported; see ?InMemoryData for details.
- instantiate_fun (function)
  R function to transform the source data.
- train_fun (function)
  R function to train the base learner on the target data.
- predict_fun (function)
  R function to predict on the object returned by train_fun.
- param_fun (function)
  R function to extract the parameters of the object returned by train_fun.
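The four functions are passed to the constructor together as a single list (see Usage). The sketch below is a hypothetical helper, not part of the package, that checks such a list before it is handed to BaselearnerCustom$new(). It assumes the list elements are named instantiate_fun, train_fun, predict_fun, and param_fun, as in the Examples section.

```r
# Hypothetical helper (not part of the package): check that the function
# list contains the four named elements and that each one is a function.
check_custom_funs = function (funs) {
  required = c("instantiate_fun", "train_fun", "predict_fun", "param_fun")
  missing_names = setdiff(required, names(funs))
  if (length(missing_names) > 0) {
    stop("Missing element(s): ", paste(missing_names, collapse = ", "))
  }
  is_fun = vapply(funs[required], is.function, logical(1))
  if (!all(is_fun)) {
    stop("Not a function: ", paste(required[!is_fun], collapse = ", "))
  }
  invisible(TRUE)
}

# A complete list of simple OLS functions passes the check
funs = list(
  instantiate_fun = function (X) X,
  train_fun = function (y, X) solve(crossprod(X), crossprod(X, y)),
  predict_fun = function (model, newdata) newdata %*% model,
  param_fun = function (model) as.matrix(model)
)
check_custom_funs(funs)
```

Dropping any element, for example funs[1:3], makes the helper stop with an error that names the missing function.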
Usage
BaselearnerCustom$new(data_source, list(instantiate_fun,
  train_fun, predict_fun, param_fun))
Details
The functions must have the following structure:
- instantiateData(X) { ... return (X_trafo) }
  Takes a matrix argument X and returns a matrix.
- train(y, X) { ... return (SEXP) }
  Takes a vector argument y and a matrix argument X. X holds the target data, while y contains the response. The function can return any R object, which is stored within a SEXP.
- predict(model, newdata) { ... return (prediction) }
  The object returned by the train function is passed to the model argument, while newdata contains a new matrix used for predicting.
- extractParameter(model) { ... return (parameters) }
  Again, model contains the object returned by train. The returned object must be a matrix containing the estimated parameters. If no parameters are estimated, the function can return NA.
For an example, see the Examples section.
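The structure above also covers learners without parameters. The following sketch, written in base R only, implements a hypothetical k-nearest-neighbour learner (k = 3 is an arbitrary choice) as four functions following that structure. train returns a plain list, which is allowed because any R object may be returned, and the parameter function returns NA because nothing is estimated. Whether this particular learner behaves well inside a boosting loop is not claimed here; it only illustrates the function contracts.

```r
# Hypothetical k-nearest-neighbour learner following the four-function structure.
knnInstantiate = function (X) {
  # Keep only the last column as the feature, as a one-column matrix
  X[, ncol(X), drop = FALSE]
}

knnTrain = function (y, X) {
  # "Training" just stores the data; any R object may be returned
  list(x = X[, 1], y = y, k = 3)
}

knnPredict = function (model, newdata) {
  # Average the responses of the k closest training points for each new value
  preds = vapply(newdata[, 1], function (x0) {
    nearest = order(abs(model$x - x0))[seq_len(model$k)]
    mean(model$y[nearest])
  }, numeric(1))
  as.matrix(preds)
}

knnParam = function (model) {
  # Nonparametric learner: there are no parameters to report
  NA
}

# Usage on toy data
X = cbind(1, 1:10)
y = (1:10)^2
X_trafo = knnInstantiate(X)
model = knnTrain(y, X_trafo)
pred = knnPredict(model, X_trafo)
knnParam(model)
```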
Methods
- $summarizeFactory(): () -> ()
- $transfromData(newdata): list(InMemoryData) -> matrix()
- $getMeta(): () -> list()
Inherited methods from Baselearner
- $getData(): () -> matrix()
- $getDF(): () -> integer()
- $getPenalty(): () -> numeric()
- $getPenaltyMat(): () -> matrix()
- $getFeatureName(): () -> character()
- $getModelName(): () -> character()
- $getBaselearnerId(): () -> character()
Examples
# Sample data:
data_mat = cbind(1, 1:10)
y = 2 + 3 * 1:10
# Create new data object:
data_source = InMemoryData$new(data_mat, "my_data_name")
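# Identity transformation: use the data matrix as is:
instantiateDataFun = function (X) {
  return(X)
}
# Ordinary least squares estimator:
trainFun = function (y, X) {
  return(solve(t(X) %*% X) %*% t(X) %*% y)
}
# Predict by multiplying the new data with the estimated coefficients:
predictFun = function (model, newdata) {
  return(as.matrix(newdata %*% model))
}
# Return the estimated coefficients as a matrix:
extractParameter = function (model) {
  return(as.matrix(model))
}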
# Create new custom linear base learner factory:
custom_lin_factory = BaselearnerCustom$new(data_source,
list(instantiate_fun = instantiateDataFun, train_fun = trainFun,
predict_fun = predictFun, param_fun = extractParameter))
# Get the transformed data:
custom_lin_factory$getData()
#> [,1] [,2]
#> [1,] 1 1
#> [2,] 1 2
#> [3,] 1 3
#> [4,] 1 4
#> [5,] 1 5
#> [6,] 1 6
#> [7,] 1 7
#> [8,] 1 8
#> [9,] 1 9
#> [10,] 1 10
# Summarize factory:
custom_lin_factory$summarizeFactory()
#> Custom base learner Factory:
#> - Name of the used data: my_data_name
#> - Factory creates the following base learner: custom
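The custom functions can also be checked in plain R, without building a factory, by calling them by hand on the sample data. The sketch below repeats the definitions from the example and compares the OLS estimate with base R's lm.fit(). It does not call BaselearnerCustom at all, so it only verifies the functions themselves, not how the factory uses them.

```r
# Sample data and OLS functions, as in the example above
data_mat = cbind(1, 1:10)
y = 2 + 3 * 1:10

instantiateDataFun = function (X) {
  return(X)
}
trainFun = function (y, X) {
  return(solve(t(X) %*% X) %*% t(X) %*% y)
}
predictFun = function (model, newdata) {
  return(as.matrix(newdata %*% model))
}
extractParameter = function (model) {
  return(as.matrix(model))
}

# Call the four functions by hand
X_trafo = instantiateDataFun(data_mat)
model = trainFun(y, X_trafo)
params = extractParameter(model)
pred = predictFun(model, X_trafo)

# Reference fit: data_mat already contains a column of ones,
# so lm.fit() estimates the intercept through that column
ls_coef = lm.fit(data_mat, y)$coefficients
print(params)
```

Because y was generated without noise, the estimated parameters should equal the true values 2 and 3 up to floating-point error, and the predictions should reproduce y.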
