diff --git a/R/data.table.R b/R/data.table.R index a989538b1..ec8847f06 100644 --- a/R/data.table.R +++ b/R/data.table.R @@ -521,6 +521,50 @@ replace_dot_alias = function(e) { list(GForce=GForce, jsub=jsub, jvnames=jvnames) } +# Helper function to process SDcols +.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) { + names_x = names(x) + bysub = substitute(by) + allbyvars = intersect(all.vars(bysub), names_x) + usesSD = ".SD" %chin% all.vars(jsub) + if (!usesSD) { + return(NULL) + } + if (SDcols_missing) { + ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars)) + ansvals = match(ansvars, names_x) + return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)) + } + sub.result = SDcols_sub + if (sub.result %iscall% "patterns") { + .SDcols = eval_with_cols(sub.result, names_x) + } else { + .SDcols = eval(sub.result, enclos) + } + if (anyNA(.SDcols)) + stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) + if (is.character(.SDcols)) { + idx = .SDcols %chin% names_x + if (!all(idx)) + stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx])) + ansvars = sdvars = .SDcols + ansvals = match(ansvars, names_x) + } else if (is.numeric(.SDcols)) { + ansvals = as.integer(.SDcols) + if (any(ansvals < 1L | ansvals > length(names_x))) + stopf(".SDcols contains indices out of bounds") + ansvars = sdvars = names_x[ansvals] + } else if (is.logical(.SDcols)) { + if (length(.SDcols) != length(names_x)) + stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x)) + ansvals = which(.SDcols) + ansvars = sdvars = names_x[ansvals] + } else { + stopf(".SDcols must be character, numeric, or logical") + } + list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals) +} + "[.data.table" = function(x, i, j, by, keyby, with=TRUE, nomatch=NA, mult="all", roll=FALSE, rollends=if (roll=="nearest") c(TRUE,TRUE) else if (roll>=0.0) c(FALSE,TRUE) else c(TRUE,FALSE), which=FALSE, .SDcols, verbose=getOption("datatable.verbose"), allow.cartesian=getOption("datatable.allow.cartesian"), drop=NULL, on=NULL, env=NULL, showProgress=getOption("datatable.showProgress", interactive())) { # ..selfcount <<- ..selfcount+1 # in dev, we check no self calls, each of which doubles overhead, or could diff --git a/R/groupingsets.R b/R/groupingsets.R index 885a64830..e31284831 100644 --- a/R/groupingsets.R +++ b/R/groupingsets.R @@ -29,6 +29,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { stopf("Argument 'id' must be a logical scalar.") if (missing(j)) stopf("Argument 'j' is required") + # Implementing NSE in cube using the helper, .processSDcols + jj = substitute(j) + sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame()) + if (is.null(sdcols_result)) { + .SDcols = NULL + } else { + ansvars = sdcols_result$ansvars + sdvars = sdcols_result$sdvars + ansvals = sdcols_result$ansvals + .SDcols = sdvars + } # generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497 n = length(by) keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k)))) diff --git a/inst/tests/tests.Rraw b/inst/tests/tests.Rraw index 508bf6aa0..a8c7df371 100644 --- a/inst/tests/tests.Rraw +++ b/inst/tests/tests.Rraw @@ -11102,6 +11102,43 @@ test(1750.34, character(0)), id = TRUE) ) +test(1750.35, + cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")), + groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value", + sets = list(c("color","year","status"), + c("color","year"), + c("color","status"), + "color", + c("year","status"), + "year", + "status", + character(0)), + id = TRUE) +) +test(1750.36, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")), + error = "Some items of \\.SDcols are not column names" +) + +test(1750.37, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)), + error = "\\.SDcols is a logical vector of length" +) + +test(1750.38, +cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE), + groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount", + sets = list("color", character(0)), + id = TRUE) +) +test(1750.39, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")), + error = ".SDcols must be character, numeric, or logical" +) +test(1750.40, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)), + error = "out of bounds" +) # grouping sets with integer64 if (test_bit64) { set.seed(26) @@ -11147,6 +11184,16 @@ if (test_bit64) { } # end Grouping Sets +# extra cube tests +test(1750.49, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(NA_character_, "amount")), + error = "\\.SDcols missing at the following indices: \\[1\\]" +) +test(1750.50, + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(4L, 5L)), + cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("amount", "value")) +) + # for completeness, added test for NA problem to close #1837. DT = data.table(x=NA) test(1751.1, capture.output(fwrite(DT, verbose=FALSE)), c("x",""))