Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 69 additions & 45 deletions R/data.table.R
Original file line number Diff line number Diff line change
Expand Up @@ -1036,56 +1036,80 @@ replace_dot_alias = function(e) {
while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]]
# fix for R-Forge #5190. colsub[[1L]] gave error when it's a symbol.
# NB: _unary_ '-', not _binary_ '-' (#5826). Test for '!' length-2 should be redundant but low-cost & keeps code concise.
if (colsub %iscall% c("!", "-") && length(colsub) == 2L) {
negate_sdcols = TRUE
colsub = colsub[[2L]]
} else negate_sdcols = FALSE
# fix for #1216, make sure the parentheses are peeled from expr of the form (((1:4)))
while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]]
if (colsub %iscall% ':' && length(colsub)==3L && !is.call(colsub[[2L]]) && !is.call(colsub[[3L]])) {
# .SDcols is of the format a:b, ensure none of : arguments is a call data.table(V1=-1L, V2=-2L, V3=-3L)[,.SD,.SDcols=-V2:-V1] #4231
.SDcols = eval(colsub, setattr(as.list(seq_along(x)), 'names', names_x), parent.frame())
} else {
if (colsub %iscall% 'patterns') {
patterns_list_or_vector = eval_with_cols(colsub, names_x)
.SDcols = if (is.list(patterns_list_or_vector)) {
# each pattern gives a new filter condition, intersect the end result
Reduce(intersect, patterns_list_or_vector)
try_processSDcols = !(colsub %iscall% c("!", "-") && length(colsub) == 2L) && !(colsub %iscall% ':') && !(colsub %iscall% 'patterns')
if (try_processSDcols) {
sdcols_result = tryCatch({
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to avoid try catch internally, better to validate input and raise error early if possible

.processSDcols(
SDcols_sub = colsub,
SDcols_missing = FALSE,
x = x,
jsub = jsub,
by = substitute(by),
enclos = parent.frame()
)
}, error = function(e) {
NULL
})
if (!is.null(sdcols_result)) {
ansvars = sdvars = sdcols_result$ansvars
ansvals = sdcols_result$ansvals
try_processSDcols = TRUE
} else {
patterns_list_or_vector
try_processSDcols = FALSE
}
}
if (!try_processSDcols) {
if (colsub %iscall% c("!", "-") && length(colsub) == 2L) {
negate_sdcols = TRUE
colsub = colsub[[2L]]
} else negate_sdcols = FALSE
# fix for #1216, make sure the parentheses are peeled from expr of the form (((1:4)))
while(colsub %iscall% "(") colsub = as.list(colsub)[[-1L]]
if (colsub %iscall% ':' && length(colsub)==3L && !is.call(colsub[[2L]]) && !is.call(colsub[[3L]])) {
# .SDcols is of the format a:b, ensure none of : arguments is a call data.table(V1=-1L, V2=-2L, V3=-3L)[,.SD,.SDcols=-V2:-V1] #4231
.SDcols = eval(colsub, setattr(as.list(seq_along(x)), 'names', names_x), parent.frame())
} else {
.SDcols = eval(colsub, parent.frame(), parent.frame())
# allow filtering via function in .SDcols, #3950
if (is.function(.SDcols)) {
.SDcols = lapply(x, .SDcols)
if (any(idx <- lengths(.SDcols) > 1L | vapply_1c(.SDcols, typeof) != 'logical' | vapply_1b(.SDcols, anyNA)))
stopf("When .SDcols is a function, it is applied to each column; the output of this function must be a non-missing boolean scalar signalling inclusion/exclusion of the column. However, these conditions were not met for: %s", brackify(names(x)[idx]))
.SDcols = unlist(.SDcols, use.names = FALSE)
if (colsub %iscall% 'patterns') {
patterns_list_or_vector = eval_with_cols(colsub, names_x)
.SDcols = if (is.list(patterns_list_or_vector)) {
# each pattern gives a new filter condition, intersect the end result
Reduce(intersect, patterns_list_or_vector)
} else {
patterns_list_or_vector
}
} else {
.SDcols = eval(colsub, parent.frame(), parent.frame())
# allow filtering via function in .SDcols, #3950
if (is.function(.SDcols)) {
.SDcols = lapply(x, .SDcols)
if (any(idx <- lengths(.SDcols) > 1L | vapply_1c(.SDcols, typeof) != 'logical' | vapply_1b(.SDcols, anyNA)))
stopf("When .SDcols is a function, it is applied to each column; the output of this function must be a non-missing boolean scalar signalling inclusion/exclusion of the column. However, these conditions were not met for: %s", brackify(names(x)[idx]))
.SDcols = unlist(.SDcols, use.names = FALSE)
}
}
}
}
if (anyNA(.SDcols))
stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols))))
if (is.logical(.SDcols)) {
if (length(.SDcols)!=length(x)) stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(x))
ansvals = which_(.SDcols, !negate_sdcols)
ansvars = sdvars = names_x[ansvals]
} else if (is.numeric(.SDcols)) {
.SDcols = as.integer(.SDcols)
# if .SDcols is numeric, use 'dupdiff' instead of 'setdiff'
if (length(unique(sign(.SDcols))) > 1L) stopf(".SDcols is numeric but has both +ve and -ve indices")
if (any(idx <- abs(.SDcols)>ncol(x) | abs(.SDcols)<1L))
stopf(".SDcols is numeric but out of bounds [1, %d] at: %s", ncol(x), brackify(which(idx)))
ansvars = sdvars = if (negate_sdcols) dupdiff(names_x[-.SDcols], bynames) else names_x[.SDcols]
ansvals = if (negate_sdcols) setdiff(seq_along(names(x)), c(.SDcols, which(names(x) %chin% bynames))) else .SDcols
} else {
if (!is.character(.SDcols)) stopf(".SDcols should be column numbers or names")
if (!all(idx <- .SDcols %chin% names_x))
stopf("Some items of .SDcols are not column names: %s", brackify(.SDcols[!idx]))
ansvars = sdvars = if (negate_sdcols) setdiff(names_x, c(.SDcols, bynames)) else .SDcols
# dups = FALSE here. DT[, .SD, .SDcols=c("x", "x")] again doesn't really help with which 'x' to keep (and if '-' which x to remove)
ansvals = chmatch(ansvars, names_x)
if (anyNA(.SDcols))
stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols))))
if (is.logical(.SDcols)) {
if (length(.SDcols)!=length(x)) stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(x))
ansvals = which_(.SDcols, !negate_sdcols)
ansvars = sdvars = names_x[ansvals]
} else if (is.numeric(.SDcols)) {
.SDcols = as.integer(.SDcols)
# if .SDcols is numeric, use 'dupdiff' instead of 'setdiff'
if (length(unique(sign(.SDcols))) > 1L) stopf(".SDcols is numeric but has both +ve and -ve indices")
if (any(idx <- abs(.SDcols)>ncol(x) | abs(.SDcols)<1L))
stopf(".SDcols is numeric but out of bounds [1, %d] at: %s", ncol(x), brackify(which(idx)))
ansvars = sdvars = if (negate_sdcols) dupdiff(names_x[-.SDcols], bynames) else names_x[.SDcols]
ansvals = if (negate_sdcols) setdiff(seq_along(names(x)), c(.SDcols, which(names(x) %chin% bynames))) else .SDcols
} else {
if (!is.character(.SDcols)) stopf(".SDcols should be column numbers or names")
if (!all(idx <- .SDcols %chin% names_x))
stopf("Some items of .SDcols are not column names: %s", brackify(.SDcols[!idx]))
ansvars = sdvars = if (negate_sdcols) setdiff(names_x, c(.SDcols, bynames)) else .SDcols
# dups = FALSE here. DT[, .SD, .SDcols=c("x", "x")] again doesn't really help with which 'x' to keep (and if '-' which x to remove)
ansvals = chmatch(ansvars, names_x)
}
}
}
# fix for long standing FR/bug, #495 and #484
Expand Down
55 changes: 55 additions & 0 deletions R/groupingsets.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,50 @@ rollup.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame())
}

# Helper function to process SDcols
.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper should be in data.table.R

names_x = names(x)
bysub = substitute(by)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bysub will contain the argument as passed to .processSDcols, not the argument as passed to its caller. Here, it's always either by or union(bynames, allbyvars). As a result, allbyvars is always empty:

library(data.table)
trace(data.table:::.processSDcols, 5, quote(stopifnot(!length(allbyvars))), print = FALSE)
example(groupingsets)
test.data.table()

This is probably not intended.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback. I have changedbysub = substitute(by) in the calling function, where by is still unevaluated to fix this error

allbyvars = intersect(all.vars(bysub), names_x)
usesSD = ".SD" %chin% all.vars(jsub)
if (!usesSD) {
return(NULL)
}
if (SDcols_missing) {
ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars))
ansvals = match(ansvars, names_x)
return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals))
}
sub.result = SDcols_sub
if (sub.result %iscall% "patterns") {
.SDcols = eval_with_cols(sub.result, names_x)
} else {
.SDcols = eval(sub.result, enclos)
}
if (anyNA(.SDcols))
stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols))))
if (is.character(.SDcols)) {
idx = .SDcols %chin% names_x
if (!all(idx))
stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx]))
ansvars = sdvars = .SDcols
ansvals = match(ansvars, names_x)
} else if (is.numeric(.SDcols)) {
ansvals = as.integer(.SDcols)
if (any(ansvals < 1L | ansvals > length(names_x)))
stopf(".SDcols contains indices out of bounds")
ansvars = sdvars = names_x[ansvals]
} else if (is.logical(.SDcols)) {
if (length(.SDcols) != length(names_x))
stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x))
ansvals = which(.SDcols)
ansvars = sdvars = names_x[ansvals]
} else {
stopf(".SDcols must be character, numeric, or logical")
}
list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)
}

cube = function(x, ...) {
UseMethod("cube")
}
Expand All @@ -29,6 +73,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
stopf("Argument 'id' must be a logical scalar.")
if (missing(j))
stopf("Argument 'j' is required")
# Implementing NSE in cube using the helper, .processSDcols
jj = substitute(j)
sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame())
if (is.null(sdcols_result)) {
.SDcols = NULL
} else {
ansvars = sdcols_result$ansvars
sdvars = sdcols_result$sdvars
ansvals = sdcols_result$ansvals
.SDcols = sdvars
}
# generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497
n = length(by)
keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k))))
Expand Down
40 changes: 40 additions & 0 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -11468,6 +11468,11 @@ sets = local({
by=c("color","year","status")
lapply(length(by):0, function(i) by[0:i])
})
test(1750.25,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this test?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test covers line 49, which handles valid numeric .SDcols indices. Test 1750.40 validates my bounds-checking modification but doesn't reach line 49 since it tests the error path. From a functional perspective, this is somewhat redundant with other existing tests, I added it to address the coverage gap. But I'm happy to remove it if you feel it doesn't add sufficient value

cube(copy(dt), j = lapply(.SD, mean), by = "color", .SDcols = 4, id=TRUE),
groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount",
sets = list("color", character(0)), id = TRUE)
)
test(1750.31,
rollup(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), id=TRUE),
groupingsets(dt, j = c(list(cnt=.N), lapply(.SD, sum)), by = c("color","year","status"), sets=sets, id=TRUE)
Expand Down Expand Up @@ -11503,6 +11508,41 @@ test(1750.34,
character(0)),
id = TRUE)
)
test(1750.35,
cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")),
groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value",
sets = list(c("color","year","status"),
c("color","year"),
c("color","status"),
"color",
c("year","status"),
"year",
"status",
character(0)),
id = TRUE)
)
test(1750.36,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")),
error = "Some items of \\.SDcols are not column names"
)
test(1750.37,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)),
error = "\\.SDcols is a logical vector of length"
)
test(1750.38,
cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE),
groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount",
sets = list("color", character(0)),
id = TRUE)
)
test(1750.39,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")),
error = ".SDcols must be character, numeric, or logical"
)
test(1750.40,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)),
error = "out of bounds"
)
# grouping sets with integer64
if (test_bit64) {
set.seed(26)
Expand Down
Loading