Skip to content

Commit 0ff1c11

Browse files
committed
Set less param values for hyperparam to avoid overcharging
1 parent 3e56666 commit 0ff1c11

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

R/machine_learning.R

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,7 +3637,7 @@ aggregate_results <- function(all_loaded) {
36373637
return(results)
36383638
}
36393639

3640-
### Helper function: get tuneGrid for models
3640+
### Helper function: get tuneGrid for models --->TO DO: might need to re-think how to choose how many values are gonna be evaluated
36413641
get_tune_grid = function(method, train_data){
36423642
set.seed(123)
36433643

@@ -3652,35 +3652,45 @@ get_tune_grid = function(method, train_data){
36523652
}
36533653
if(method == "rf"){
36543654
n_features <- ncol(train_data) - 1
3655-
return(data.frame(mtry = unique(round(seq(n_features * 0.2, n_features * 0.9, length.out = 5)))))
3655+
return(data.frame(mtry = unique(round(seq(n_features * 0.2, n_features * 0.9, length.out = 3)))))
36563656
}
36573657
if(method == "svmRadial"){
3658-
return(expand.grid(sigma = 0.01, C = c(0.25, 0.5, 1, 2, 4)))
3658+
# Typical small-to-moderate RBF widths + modest C range
3659+
return(expand.grid(
3660+
sigma = c(0.01, 0.05, 0.1),
3661+
C = c(0.5, 1, 2)
3662+
))
36593663
}
36603664
if(method == "treebag"){
36613665
return(data.frame(parameter = "none"))
36623666
}
36633667
if(method == "C5.0"){
3664-
return(expand.grid(trials = c(1, 5, 10), model = "tree", winnow = c(TRUE, FALSE)))
3668+
return(expand.grid(
3669+
trials = c(1, 5, 10),
3670+
model = "tree",
3671+
winnow = c(TRUE, FALSE)
3672+
))
36653673
}
36663674
if(method == "knn"){
3667-
return(expand.grid(k = c(3, 5, 7, 9, 11)))
3675+
# Odd ks to avoid ties; small-to-moderate neighborhood sizes
3676+
return(expand.grid(k = c(5, 7, 9)))
36683677
}
36693678
if(method == "rpart"){
3670-
return(expand.grid(cp = seq(0.001, 0.1, length = 10)))
3679+
# Coarse cp sweep across low/med/high regularization
3680+
return(expand.grid(cp = c(0.001, 0.01, 0.1)))
36713681
}
36723682
if(method == "svmLinear"){
3673-
return(expand.grid(C = c(0.25, 0.5, 1, 2, 4)))
3683+
return(expand.grid(C = c(0.5, 1, 2)))
36743684
}
36753685
if(method == "xgbTree"){
36763686
return(expand.grid(
3677-
nrounds = 100,
3678-
max_depth = c(3, 6, 9),
3679-
eta = c(0.01, 0.1, 0.3),
3680-
gamma = 0,
3681-
colsample_bytree = 0.8,
3682-
min_child_weight = 1,
3683-
subsample = 0.8
3687+
nrounds = c(100, 300, 500),
3688+
max_depth = c(3, 6, 9),
3689+
eta = c(0.01, 0.1, 0.3),
3690+
gamma = 0, # fixed default
3691+
colsample_bytree = 0.8, # fixed default
3692+
min_child_weight = 1, # fixed default
3693+
subsample = 0.8 # fixed default
36843694
))
36853695
}
36863696

0 commit comments

Comments
 (0)