-
Notifications
You must be signed in to change notification settings - Fork 1k
Implementing NSE in cube #7543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing NSE in cube #7543
Changes from all commits
388ef8c
3e96dfc
e1eb87a
b6adef9
47bb2c3
33572b5
e9876cb
2a15cb9
0ae97fa
7215a4c
2b7b7ff
431ca81
3bec96c
c16f64d
4119cee
f792b15
97b4536
1595595
25fbd53
32d078d
b2e6171
1e1225b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,50 @@ rollup.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { | |
| groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame()) | ||
| } | ||
|
|
||
| # Helper function to process SDcols | ||
| .processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This helper should be in data.table.R |
||
| names_x = names(x) | ||
| bysub = substitute(by) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is probably not intended.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for the feedback. I have changed |
||
| allbyvars = intersect(all.vars(bysub), names_x) | ||
| usesSD = ".SD" %chin% all.vars(jsub) | ||
| if (!usesSD) { | ||
| return(NULL) | ||
| } | ||
| if (SDcols_missing) { | ||
| ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars)) | ||
| ansvals = match(ansvars, names_x) | ||
| return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)) | ||
| } | ||
| sub.result = SDcols_sub | ||
| if (sub.result %iscall% "patterns") { | ||
| .SDcols = eval_with_cols(sub.result, names_x) | ||
| } else { | ||
| .SDcols = eval(sub.result, enclos) | ||
| } | ||
| if (anyNA(.SDcols)) | ||
| stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols)))) | ||
| if (is.character(.SDcols)) { | ||
| idx = .SDcols %chin% names_x | ||
| if (!all(idx)) | ||
| stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx])) | ||
| ansvars = sdvars = .SDcols | ||
| ansvals = match(ansvars, names_x) | ||
| } else if (is.numeric(.SDcols)) { | ||
| ansvals = as.integer(.SDcols) | ||
| if (any(ansvals < 1L | ansvals > length(names_x))) | ||
| stopf(".SDcols contains indices out of bounds") | ||
| ansvars = sdvars = names_x[ansvals] | ||
| } else if (is.logical(.SDcols)) { | ||
| if (length(.SDcols) != length(names_x)) | ||
| stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x)) | ||
| ansvals = which(.SDcols) | ||
| ansvars = sdvars = names_x[ansvals] | ||
| } else { | ||
| stopf(".SDcols must be character, numeric, or logical") | ||
| } | ||
| list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals) | ||
| } | ||
|
|
||
| cube = function(x, ...) { | ||
| UseMethod("cube") | ||
| } | ||
|
|
@@ -29,6 +73,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) { | |
| stopf("Argument 'id' must be a logical scalar.") | ||
| if (missing(j)) | ||
| stopf("Argument 'j' is required") | ||
| # Implementing NSE in cube using the helper, .processSDcols | ||
| jj = substitute(j) | ||
| sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame()) | ||
| if (is.null(sdcols_result)) { | ||
| .SDcols = NULL | ||
| } else { | ||
| ansvars = sdcols_result$ansvars | ||
| sdvars = sdcols_result$sdvars | ||
| ansvals = sdcols_result$ansvals | ||
| .SDcols = sdvars | ||
| } | ||
| # generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497 | ||
| n = length(by) | ||
| keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k)))) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11468,6 +11468,11 @@ sets = local({ | |
| by=c("color","year","status") | ||
| lapply(length(by):0, function(i) by[0:i]) | ||
| }) | ||
| test(1750.25, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the purpose of this test?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test covers line 49, which handles valid numeric .SDcols indices. Test 1750.40 validates my bounds-checking modification but doesn't reach line 49 since it tests the error path. From a functional perspective, this is somewhat redundant with other existing tests, I added it to address the coverage gap. But I'm happy to remove it if you feel it doesn't add sufficient value |
||
| cube(copy(dt), j = lapply(.SD, mean), by = "color", .SDcols = 4, id=TRUE), | ||
| groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount", | ||
| sets = list("color", character(0)), id = TRUE) | ||
| ) | ||
| test(1750.31, | ||
| rollup(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), id=TRUE), | ||
| groupingsets(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), sets=sets, id=TRUE) | ||
|
|
@@ -11503,6 +11508,41 @@ test(1750.34, | |
| character(0)), | ||
| id = TRUE) | ||
| ) | ||
| test(1750.35, | ||
| cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")), | ||
| groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value", | ||
| sets = list(c("color","year","status"), | ||
| c("color","year"), | ||
| c("color","status"), | ||
| "color", | ||
| c("year","status"), | ||
| "year", | ||
| "status", | ||
| character(0)), | ||
| id = TRUE) | ||
| ) | ||
| test(1750.36, | ||
| cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")), | ||
| error = "Some items of \\.SDcols are not column names" | ||
| ) | ||
| test(1750.37, | ||
| cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)), | ||
| error = "\\.SDcols is a logical vector of length" | ||
| ) | ||
| test(1750.38, | ||
| cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE), | ||
| groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount", | ||
| sets = list("color", character(0)), | ||
| id = TRUE) | ||
| ) | ||
| test(1750.39, | ||
| cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")), | ||
| error = ".SDcols must be character, numeric, or logical" | ||
| ) | ||
| test(1750.40, | ||
| cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)), | ||
| error = "out of bounds" | ||
| ) | ||
| # grouping sets with integer64 | ||
| if (test_bit64) { | ||
| set.seed(26) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer to avoid try catch internally, better to validate input and raise error early if possible