This wrapper function automatically initializes the model by adding every numerical feature as a spline base learner. Categorical features are dummy encoded and added as linear base learners without an intercept. boostSplines also trains the model.
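
A minimal base R sketch of how the features of a data frame could be split into the two groups described above. The helper planned_baselearners() is hypothetical. The "_spline" and "_ridge" suffixes are taken from the base-learner names printed in the example at the end of this page.

```r
# Hypothetical helper: which base learners would the wrapper create?
# Numeric features -> spline learners, factors/characters -> categorical learners.
planned_baselearners = function(data, target) {
  features = setdiff(names(data), target)
  is_num = vapply(data[features], is.numeric, logical(1))
  is_cat = vapply(data[features], function(x) is.factor(x) || is.character(x), logical(1))
  c(paste0(features[is_num], "_spline"), paste0(features[is_cat], "_ridge"))
}

planned_baselearners(iris, "Sepal.Length")
# "Sepal.Width_spline" "Petal.Length_spline" "Petal.Width_spline" "Species_ridge"

# Dummy encoding without intercept: one indicator column per factor level.
head(model.matrix(~ Species - 1, data = iris), 3)
```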

The function returns an object of the Compboost class, which can be used for further analyses (see ?Compboost for details).

Usage

boostSplines(
  data,
  target,
  optimizer = NULL,
  loss = NULL,
  learning_rate = 0.05,
  iterations = 100,
  trace = -1,
  degree = 3,
  n_knots = 20,
  penalty = 2,
  df = 0,
  differences = 2,
  data_source = InMemoryData,
  oob_fraction = NULL,
  bin_root = 0,
  cache_type = "inverse",
  stop_args = NULL,
  df_cat = 1,
  stop_time = "microseconds",
  additional_risk_logs = list()
)

Arguments

data

(data.frame())
A data frame containing the data.

target

(character(1) | ResponseRegr | ResponseBinaryClassif)
Name of the target variable, or a response object. Note that the loss must match the data type of the target.

optimizer

(OptimizerCoordinateDescent | OptimizerCoordinateDescentLineSearch | OptimizerAGBM | OptimizerCosineAnnealing)
An initialized S4 optimizer object (requires calling Optimizer*$new(...)). See the respective help page for further information.

loss

(LossQuadratic | LossBinomial | LossHuber | LossAbsolute | LossQuantile)
An initialized S4 loss object (requires calling Loss*$new(...)). See the respective help page for further information.

learning_rate

(numeric(1))
Learning rate used to shrink the parameter update in each iteration.
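
To illustrate what the learning rate does, the following sketch implements a stripped-down componentwise boosting loop with quadratic loss. It is not compboost's implementation: boost_linear() is a hypothetical helper, and for brevity every base learner is a simple linear model on one centered feature instead of a spline.

```r
# Simplified componentwise L2 boosting, for illustration only.
boost_linear = function(X, y, learning_rate = 0.05, iterations = 100) {
  X = scale(X, center = TRUE, scale = FALSE)   # centered features, no intercept
  offset = mean(y)                             # loss-optimal constant for L2 loss
  pred = rep(offset, length(y))
  coef = setNames(numeric(ncol(X)), colnames(X))
  risk = numeric(iterations)
  selected = character(iterations)
  for (m in seq_len(iterations)) {
    r = y - pred                               # negative gradient of 0.5 * (y - f)^2
    betas = colSums(X * r) / colSums(X^2)      # least-squares fit of each learner to r
    sse = colSums((r - sweep(X, 2, betas, `*`))^2)
    best = which.min(sse)
    # Only the best learner is updated, and its step is shrunk by learning_rate.
    coef[best] = coef[best] + learning_rate * betas[best]
    pred = pred + learning_rate * betas[best] * X[, best]
    risk[m] = mean(0.5 * (y - pred)^2)
    selected[m] = colnames(X)[best]
  }
  list(offset = offset, coef = coef, risk = risk, selected = selected)
}

fit = boost_linear(as.matrix(iris[, c("Sepal.Width", "Petal.Length", "Petal.Width")]),
                   iris$Sepal.Length)
table(fit$selected)
```

A smaller learning rate means each iteration moves the selected coefficient only part of the way toward its least-squares value, so more iterations are needed, but the risk never increases from one iteration to the next.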

iterations

(integer(1))
Number of boosting iterations to train. If iterations == 0, the untrained object is returned. This is useful for adding further base learners (e.g. an interaction via a tensor base learner) before training.

trace

(integer(1))
Integer indicating how often a trace should be printed. With trace = 10, every 10th iteration is printed. To print no trace, set trace = 0. The default, trace = -1, prints roughly 40 iterations in total.
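
A hypothetical helper that lists the iterations a given trace value would print. The rule used for trace = -1 (an interval of floor(iterations / 40), plus the first iteration) is an assumption; it reproduces the iterations 1, 2, 4, ..., 100 shown in the example trace below, but it is not taken from compboost's source.

```r
# Hypothetical: iterations printed for a given trace setting (assumed rule).
printed_iterations = function(iterations, trace = -1) {
  if (trace == 0 || iterations == 0) return(integer(0))
  if (trace == -1) trace = max(1, floor(iterations / 40))
  its = seq_len(iterations)
  sort(unique(c(1L, its[its %% trace == 0])))
}

printed_iterations(100)       # 1, 2, 4, ..., 100 (51 lines)
printed_iterations(100, 10)   # 1, 10, 20, ..., 100
```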

degree

(integer(1))
Polynomial degree of the splines.

n_knots

(integer(1))
Number of equidistant "inner knots". The actual number of knots used also depends on the polynomial degree.
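
A sketch of an equidistant B-spline basis built with splines::splineDesign(). Placing degree additional knots on each side of the data range is the usual P-spline construction and an assumption here; bspline_basis() is a hypothetical helper. With this construction, n_knots = 20 and degree = 3 yield n_knots + degree + 1 = 24 basis functions, which matches the 24 coefficients per spline printed in the example below.

```r
library(splines)

# Hypothetical helper: equidistant B-spline basis with n_knots inner knots.
bspline_basis = function(x, n_knots = 20, degree = 3) {
  inner = seq(min(x), max(x), length.out = n_knots + 2)  # 2 boundary + n_knots inner knots
  h = inner[2] - inner[1]
  # Extend by `degree` knots on each side, keeping the same spacing.
  knots = c(min(x) - h * (degree:1), inner, max(x) + h * seq_len(degree))
  splineDesign(knots = knots, x = x, ord = degree + 1)
}

B = bspline_basis(iris$Sepal.Width)
ncol(B)  # 24 = n_knots + degree + 1
```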

penalty

(numeric(1))
Penalty term for P-splines. If penalty = 0, ordinary B-splines are fitted. The higher the penalty, the smoother the fitted curve.
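
A minimal P-spline sketch in base R, assuming an equidistant cubic B-spline basis and a squared second-order difference penalty on the coefficients. The simulated data and the helpers fit_pspline() and roughness() are hypothetical. The roughness of the coefficients, measured by the penalty itself, shrinks as the penalty grows.

```r
library(splines)

set.seed(1)
x = sort(runif(200))
y = sin(2 * pi * x) + rnorm(200, sd = 0.3)

# Equidistant cubic basis on [0, 1] with 20 inner knots (24 basis functions).
inner = seq(0, 1, length.out = 22)
h = inner[2] - inner[1]
B = splineDesign(c(-h * 3:1, inner, 1 + h * 1:3), x, ord = 4)
D = diff(diag(ncol(B)), differences = 2)     # 22 x 24 second-order difference matrix

# Penalized least squares: (B'B + penalty * D'D) beta = B'y.
fit_pspline = function(penalty) {
  solve(crossprod(B) + penalty * crossprod(D), crossprod(B, y))
}
roughness = function(beta) sum((D %*% beta)^2)

beta_bspline = fit_pspline(0)     # penalty = 0: ordinary B-spline least squares
beta_smooth = fit_pspline(100)    # larger penalty: smoother coefficient sequence
c(roughness(beta_bspline), roughness(beta_smooth))
```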

df

(numeric(1))
Degrees of freedom of the base learner(s).
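
One common definition of the degrees of freedom of a penalized base learner is the trace of its hat matrix, \(\mathrm{tr}(B (B^\top B + \lambda D^\top D)^{-1} B^\top)\). Assuming df follows this definition, the sketch below computes it for a P-spline basis and searches for the penalty that gives df = 5. The helper edf() and the simulated data are hypothetical, and the uniroot() search is not compboost's routine.

```r
library(splines)

set.seed(1)
x = sort(runif(200))
inner = seq(0, 1, length.out = 22)
h = inner[2] - inner[1]
B = splineDesign(c(-h * 3:1, inner, 1 + h * 1:3), x, ord = 4)
D = diff(diag(ncol(B)), differences = 2)

# tr(B A^{-1} B') = tr(A^{-1} B'B) with A = B'B + penalty * D'D.
edf = function(penalty) {
  XtX = crossprod(B)
  sum(diag(solve(XtX + penalty * crossprod(D), XtX)))
}

edf(0)    # 24 (up to rounding): one degree of freedom per basis function
edf(10)   # fewer effective parameters

# Search on the log scale for the penalty that gives df = 5.
penalty_df5 = exp(uniroot(function(lp) edf(exp(lp)) - 5, c(-10, 20), tol = 1e-10)$root)
edf(penalty_df5)
```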

differences

(integer(1))
Order of the differences of adjacent coefficients used in the penalty. The higher the order, the smoother the fit; for very large penalties, the fit approaches a polynomial of degree differences - 1.
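
A small base R illustration of difference penalties: diff() applied to an identity matrix yields the difference matrix of a given order. Coefficient sequences that follow a polynomial of degree differences - 1 are not penalized at all, which is why a large penalty pulls the fit toward such a polynomial.

```r
k = 6
D1 = diff(diag(k), differences = 1)   # rows like (-1, 1, 0, ...)
D2 = diff(diag(k), differences = 2)   # rows like ( 1, -2, 1, 0, ...)
D2[1, ]                               # 1 -2 1 0 0 0

beta_linear = 1:k
beta_quadratic = (1:k)^2
sum((D2 %*% beta_linear)^2)                                  # 0: a line is not penalized
sum((D2 %*% beta_quadratic)^2)                               # 16: a parabola is
sum((diff(diag(k), differences = 3) %*% beta_quadratic)^2)   # 0 with third-order differences
```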

data_source

(Data*)
Uninitialized Data* object used to store the data. Currently, only in-memory training is supported.

oob_fraction

(numeric(1))
Fraction of the data held out to track the out-of-bag risk.
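
A sketch of a random out-of-bag split in base R; the rounding and sampling details are assumptions. For iris (150 rows) and oob_fraction = 0.3, 45 rows are held out and 105 remain for training, which is consistent with the 105 fitted values returned by mod$predict() in the example below.

```r
set.seed(31415)
n = nrow(iris)
oob_fraction = 0.3
idx_oob = sample(n, size = round(oob_fraction * n))   # rows used for the oob risk
idx_train = setdiff(seq_len(n), idx_oob)               # rows used for fitting
c(train = length(idx_train), oob = length(idx_oob))    # 105 and 45
```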

bin_root

(integer(1))
The binning root used to reduce the data to \(n^{1/\text{bin\_root}}\) data points. The default, bin_root = 0, means that no binning is applied. A value of bin_root = 2 is suggested for the best approximation error (cf. Wood et al. (2017) Generalized additive models for gigadata: modeling the UK black smoke network daily data).
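
A rough sketch of the binning idea, assuming an equidistant grid of ceiling(n^(1/bin_root)) values and a nearest-grid-point assignment. The helper bin_feature() is hypothetical and compboost's exact grid may differ. The basis then only has to be evaluated at the grid values, and each observation looks up its row through index.

```r
# Hypothetical binning helper (assumed grid and assignment rule).
bin_feature = function(x, bin_root = 2) {
  if (bin_root == 0) return(list(values = x, index = seq_along(x)))  # no binning
  n_bins = ceiling(length(x)^(1 / bin_root))
  grid = seq(min(x), max(x), length.out = n_bins)
  # Map every observation to its nearest grid value.
  index = vapply(x, function(xi) which.min(abs(grid - xi)), integer(1))
  list(values = grid, index = index)
}

b = bin_feature(iris$Petal.Length, bin_root = 2)
length(b$values)   # 13 = ceiling(sqrt(150)) grid values instead of 150
```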

cache_type

(character(1))
String indicating which method is used to estimate the parameters in each iteration. The default, cache_type = "inverse", computes the inverse, caches it, and reuses it in every iteration. The alternative, cache_type = "cholesky", does the same with the Cholesky decomposition.
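
Both cache types solve the same penalized normal equations \((X^\top X + P)\beta = X^\top y\); they only differ in what is stored. The base R sketch below uses a random design and an arbitrary ridge penalty P (both illustrative, not taken from compboost) to show that the two routes give the same estimate.

```r
set.seed(2)
X = matrix(rnorm(100 * 5), 100, 5)
P = diag(0.5, 5)
XtX_P = crossprod(X) + P

# "inverse": cache the inverse once, then one matrix product per iteration.
cache_inv = solve(XtX_P)
# "cholesky": cache the upper factor R with R'R = X'X + P,
# then two triangular solves per iteration.
cache_chol = chol(XtX_P)

y = rnorm(100)   # stands in for the pseudo-residuals of one iteration
beta_inv = cache_inv %*% crossprod(X, y)
beta_chol = backsolve(cache_chol, forwardsolve(t(cache_chol), crossprod(X, y)))
all.equal(beta_inv, beta_chol)   # TRUE
```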

stop_args

(list(2))
List with two elements, patience and eps_for_break, used for early stopping based on the data held out via oob_fraction. If !is.null(stop_args), early stopping is enabled.
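
A hypothetical early-stopping rule in the spirit of stop_args: stop once the relative improvement of the out-of-bag risk stays below eps_for_break for patience consecutive iterations. The exact criterion compboost uses may differ; early_stop_iteration() and the risk values are made up for illustration.

```r
# Hypothetical rule: returns the iteration at which training would stop.
early_stop_iteration = function(oob_risk, patience = 5, eps_for_break = 0) {
  count = 0
  for (m in seq_along(oob_risk)[-1]) {
    rel_improvement = (oob_risk[m - 1] - oob_risk[m]) / oob_risk[m - 1]
    # Count consecutive iterations without sufficient improvement.
    count = if (rel_improvement < eps_for_break) count + 1 else 0
    if (count >= patience) return(m)
  }
  length(oob_risk)   # never triggered: all iterations are used
}

risk = c(1, 0.8, 0.7, 0.65, 0.66, 0.67, 0.665, 0.68, 0.69)
early_stop_iteration(risk, patience = 2)   # 6
early_stop_iteration(risk, patience = 3)   # 9
```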

df_cat

(numeric(1))
Degrees of freedom of the categorical base-learner.
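
For a dummy-encoded factor without intercept, \(X^\top X\) is diagonal with the level counts \(n_k\), so a ridge penalty \(\lambda\) gives \(\sum_k n_k / (n_k + \lambda)\) degrees of freedom. Assuming df_cat is translated into a penalty this way, the sketch solves for the penalty that yields df_cat = 1 for iris$Species; df_ridge() is a hypothetical helper.

```r
X = model.matrix(~ Species - 1, data = iris)
n_k = colSums(X)                          # 50 observations per species
df_ridge = function(lambda) sum(n_k / (n_k + lambda))

df_ridge(0)                               # 3: one parameter per level
lambda_1 = uniroot(function(l) df_ridge(l) - 1, c(0, 1e6), tol = 1e-10)$root
lambda_1                                  # approximately 100, since 3 * 50 / (50 + 100) = 1
```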

stop_time

(character(1))
Unit of the time measured during training.

additional_risk_logs

(list(Logger))
Additional loggers passed to the Compboost object.

Value

A model of the Compboost class. This model is an R6 object which can be used for retraining, predicting, plotting, and anything described in ?Compboost.

Examples

mod = boostSplines(data = iris, target = "Sepal.Length", loss = LossQuadratic$new(),
  oob_fraction = 0.3)
#>   1/100   risk = 0.31  oob_risk = 0.32   time = 0   
#>   2/100   risk = 0.28  oob_risk = 0.29   time = 101   
#>   4/100   risk = 0.24  oob_risk = 0.26   time = 245   
#>   6/100   risk = 0.21  oob_risk = 0.23   time = 374   
#>   8/100   risk = 0.18  oob_risk = 0.2   time = 512   
#>  10/100   risk = 0.15  oob_risk = 0.18   time = 634   
#>  12/100   risk = 0.13  oob_risk = 0.17   time = 754   
#>  14/100   risk = 0.12  oob_risk = 0.15   time = 872   
#>  16/100   risk = 0.11  oob_risk = 0.14   time = 993   
#>  18/100   risk = 0.095  oob_risk = 0.13   time = 1114   
#>  20/100   risk = 0.087  oob_risk = 0.13   time = 1234   
#>  22/100   risk = 0.08  oob_risk = 0.12   time = 1355   
#>  24/100   risk = 0.074  oob_risk = 0.11   time = 1501   
#>  26/100   risk = 0.069  oob_risk = 0.11   time = 1639   
#>  28/100   risk = 0.066  oob_risk = 0.11   time = 1758   
#>  30/100   risk = 0.062  oob_risk = 0.11   time = 1932   
#>  32/100   risk = 0.06  oob_risk = 0.11   time = 2059   
#>  34/100   risk = 0.058  oob_risk = 0.1   time = 2198   
#>  36/100   risk = 0.056  oob_risk = 0.1   time = 2319   
#>  38/100   risk = 0.055  oob_risk = 0.1   time = 2440   
#>  40/100   risk = 0.053  oob_risk = 0.1   time = 2569   
#>  42/100   risk = 0.052  oob_risk = 0.1   time = 2705   
#>  44/100   risk = 0.051  oob_risk = 0.1   time = 2816   
#>  46/100   risk = 0.05  oob_risk = 0.1   time = 2925   
#>  48/100   risk = 0.049  oob_risk = 0.1   time = 3038   
#>  50/100   risk = 0.048  oob_risk = 0.1   time = 3150   
#>  52/100   risk = 0.048  oob_risk = 0.1   time = 3261   
#>  54/100   risk = 0.047  oob_risk = 0.1   time = 3370   
#>  56/100   risk = 0.046  oob_risk = 0.1   time = 3480   
#>  58/100   risk = 0.046  oob_risk = 0.1   time = 3591   
#>  60/100   risk = 0.045  oob_risk = 0.099   time = 3702   
#>  62/100   risk = 0.044  oob_risk = 0.099   time = 3814   
#>  64/100   risk = 0.044  oob_risk = 0.099   time = 3925   
#>  66/100   risk = 0.043  oob_risk = 0.099   time = 4044   
#>  68/100   risk = 0.043  oob_risk = 0.099   time = 4156   
#>  70/100   risk = 0.043  oob_risk = 0.099   time = 4267   
#>  72/100   risk = 0.042  oob_risk = 0.099   time = 4379   
#>  74/100   risk = 0.042  oob_risk = 0.098   time = 4490   
#>  76/100   risk = 0.041  oob_risk = 0.098   time = 4603   
#>  78/100   risk = 0.041  oob_risk = 0.098   time = 4718   
#>  80/100   risk = 0.041  oob_risk = 0.098   time = 4942   
#>  82/100   risk = 0.04  oob_risk = 0.098   time = 5077   
#>  84/100   risk = 0.04  oob_risk = 0.098   time = 5210   
#>  86/100   risk = 0.04  oob_risk = 0.098   time = 5352   
#>  88/100   risk = 0.039  oob_risk = 0.098   time = 5484   
#>  90/100   risk = 0.039  oob_risk = 0.097   time = 5629   
#>  92/100   risk = 0.039  oob_risk = 0.097   time = 5762   
#>  94/100   risk = 0.039  oob_risk = 0.097   time = 5893   
#>  96/100   risk = 0.038  oob_risk = 0.097   time = 6039   
#>  98/100   risk = 0.038  oob_risk = 0.097   time = 6153   
#> 100/100   risk = 0.038  oob_risk = 0.097   time = 6266   
#> 
#> 
#> Train 100 iterations in 0 Seconds.
#> Final risk based on the train set: 0.038
#> 
mod$getBaselearnerNames()
#> [1] "Sepal.Width_spline"  "Petal.Length_spline" "Petal.Width_spline" 
#> [4] "Species_ridge"      
mod$getEstimatedCoef()
#> Depricated, use `$getCoef()` instead.
#> $Petal.Length_spline
#>               [,1]
#>  [1,] -1.020038063
#>  [2,] -0.842700185
#>  [3,] -0.708210074
#>  [4,] -0.616528258
#>  [5,] -0.629992257
#>  [6,] -0.674377781
#>  [7,] -0.700732067
#>  [8,] -0.697215039
#>  [9,] -0.651967365
#> [10,] -0.537788304
#> [11,] -0.341686833
#> [12,] -0.177164903
#> [13,]  0.007505898
#> [14,]  0.290089474
#> [15,]  0.451293220
#> [16,]  0.400021292
#> [17,]  0.366344295
#> [18,]  0.483695287
#> [19,]  0.784027942
#> [20,]  1.123146019
#> [21,]  1.463578011
#> [22,]  1.656349917
#> [23,]  1.738770154
#> [24,]  1.806016019
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#> 
#> $Petal.Width_spline
#>               [,1]
#>  [1,] -0.368363815
#>  [2,] -0.230727829
#>  [3,] -0.097569154
#>  [4,] -0.016578666
#>  [5,] -0.028837656
#>  [6,] -0.057857155
#>  [7,] -0.057668639
#>  [8,] -0.028703497
#>  [9,] -0.002385605
#> [10,] -0.010779152
#> [11,] -0.054068445
#> [12,] -0.040466735
#> [13,]  0.064127383
#> [14,]  0.136240247
#> [15,]  0.133221263
#> [16,]  0.102792184
#> [17,]  0.067773564
#> [18,]  0.110728014
#> [19,]  0.161049540
#> [20,]  0.160193965
#> [21,]  0.130363799
#> [22,]  0.046188478
#> [23,] -0.067301221
#> [24,] -0.181037323
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#> 
#> $Sepal.Width_spline
#>               [,1]
#>  [1,] -0.391634041
#>  [2,] -0.234533599
#>  [3,] -0.086240363
#>  [4,]  0.009214706
#>  [5,]  0.007799400
#>  [6,] -0.065038926
#>  [7,] -0.104271259
#>  [8,] -0.096040595
#>  [9,] -0.054096040
#> [10,] -0.017548114
#> [11,] -0.012499394
#> [12,]  0.013963397
#> [13,] -0.016707110
#> [14,] -0.064197029
#> [15,] -0.059409731
#> [16,]  0.010539003
#> [17,]  0.075101623
#> [18,]  0.111945966
#> [19,]  0.157947620
#> [20,]  0.253521014
#> [21,]  0.357967518
#> [22,]  0.374894173
#> [23,]  0.371380710
#> [24,]  0.369954604
#> attr(,"blclass")
#> [1] "Rcpp_BaselearnerPSpline"
#> 
#> $offset
#> [1] 5.79619
#> 
table(mod$getSelectedBaselearner())
#> 
#> Petal.Length_spline  Petal.Width_spline  Sepal.Width_spline 
#>                  45                  26                  29 
mod$predict()
#>            [,1]
#>   [1,] 5.001037
#>   [2,] 4.971113
#>   [3,] 4.890050
#>   [4,] 5.003350
#>   [5,] 5.410319
#>   [6,] 4.975044
#>   [7,] 4.956201
#>   [8,] 4.891916
#>   [9,] 5.129464
#>  [10,] 5.001642
#>  [11,] 4.859679
#>  [12,] 4.715600
#>  [13,] 5.228497
#>  [14,] 5.077558
#>  [15,] 5.303721
#>  [16,] 5.262642
#>  [17,] 5.223762
#>  [18,] 5.012599
#>  [19,] 5.036910
#>  [20,] 4.939170
#>  [21,] 4.994930
#>  [22,] 5.029948
#>  [23,] 5.266600
#>  [24,] 5.342117
#>  [25,] 5.003350
#>  [26,] 4.842022
#>  [27,] 4.958629
#>  [28,] 4.942359
#>  [29,] 4.928705
#>  [30,] 4.975044
#>  [31,] 5.035150
#>  [32,] 4.890050
#>  [33,] 5.127204
#>  [34,] 5.317974
#>  [35,] 5.212719
#>  [36,] 4.932458
#>  [37,] 5.129464
#>  [38,] 4.910691
#>  [39,] 6.220816
#>  [40,] 6.336589
#>  [41,] 6.109429
#>  [42,] 6.280500
#>  [43,] 5.204318
#>  [44,] 6.182060
#>  [45,] 5.668333
#>  [46,] 5.142687
#>  [47,] 6.005284
#>  [48,] 5.688583
#>  [49,] 5.457387
#>  [50,] 6.140691
#>  [51,] 5.707504
#>  [52,] 6.247146
#>  [53,] 6.262478
#>  [54,] 5.694635
#>  [55,] 6.247862
#>  [56,] 6.144328
#>  [57,] 6.276506
#>  [58,] 6.296937
#>  [59,] 5.282725
#>  [60,] 5.495523
#>  [61,] 5.459322
#>  [62,] 5.537781
#>  [63,] 6.242001
#>  [64,] 6.223752
#>  [65,] 6.067748
#>  [66,] 5.802462
#>  [67,] 5.635231
#>  [68,] 5.917087
#>  [69,] 6.273111
#>  [70,] 5.819931
#>  [71,] 5.833788
#>  [72,] 5.873289
#>  [73,] 5.015609
#>  [74,] 5.769503
#>  [75,] 6.704795
#>  [76,] 6.870084
#>  [77,] 6.445517
#>  [78,] 6.740277
#>  [79,] 7.253793
#>  [80,] 6.581417
#>  [81,] 6.285423
#>  [82,] 6.232806
#>  [83,] 6.242707
#>  [84,] 6.179686
#>  [85,] 6.267134
#>  [86,] 7.795450
#>  [87,] 7.541891
#>  [88,] 6.312717
#>  [89,] 7.576167
#>  [90,] 6.578978
#>  [91,] 6.874760
#>  [92,] 6.719281
#>  [93,] 7.603128
#>  [94,] 6.493966
#>  [95,] 6.281089
#>  [96,] 6.372074
#>  [97,] 6.376361
#>  [98,] 6.374356
#>  [99,] 6.404666
#> [100,] 6.777900
#> [101,] 6.352684
#> [102,] 6.289097
#> [103,] 6.203257
#> [104,] 6.311141
#> [105,] 6.264772