Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions src/parse_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#include "duckdb/parser/tableref/joinref.hpp"
#include "duckdb/parser/tableref/subqueryref.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
#include "duckdb/parser/expression/subquery_expression.hpp"
#include "duckdb/parser/parsed_expression_iterator.hpp"
#include "duckdb/parser/result_modifier.hpp"

namespace duckdb {

Expand Down Expand Up @@ -73,6 +76,12 @@ static unique_ptr<GlobalTableFunctionState> ParseTablesInit(ClientContext &conte
return make_uniq<ParseTablesState>();
}

static void ExtractTablesFromExpression(
const duckdb::ParsedExpression &expr,
std::vector<TableRefResult> &results,
const duckdb::CommonTableExpressionMap *cte_map = nullptr
);

static void ExtractTablesFromRef(
const duckdb::TableRef &ref,
std::vector<TableRefResult> &results,
Expand All @@ -89,7 +98,7 @@ static void ExtractTablesFromRef(

if (cte_map && cte_map->map.find(base.table_name) != cte_map->map.end()) {
context_label = TableContext::FromCTE;
} else if (is_top_level) {
} else if (is_top_level && context != TableContext::Subquery) {
context_label = TableContext::From;
}

Expand All @@ -104,12 +113,15 @@ static void ExtractTablesFromRef(
auto &join = (JoinRef &)ref;
ExtractTablesFromRef(*join.left, results, TableContext::JoinLeft, is_top_level, cte_map);
ExtractTablesFromRef(*join.right, results, TableContext::JoinRight, false, cte_map);
if (join.condition) {
ExtractTablesFromExpression(*join.condition, results, cte_map);
}
break;
}
case TableReferenceType::SUBQUERY: {
auto &subquery = (SubqueryRef &)ref;
if (subquery.subquery && subquery.subquery->node) {
ExtractTablesFromQueryNode(*subquery.subquery->node, results, TableContext::Subquery, cte_map);
ExtractTablesFromQueryNode(*subquery.subquery->node, results, TableContext::From, cte_map);
}
break;
}
Expand All @@ -118,6 +130,24 @@ static void ExtractTablesFromRef(
}
}

static void ExtractTablesFromExpression(
const duckdb::ParsedExpression &expr,
std::vector<TableRefResult> &results,
const duckdb::CommonTableExpressionMap *cte_map
) {
using namespace duckdb;

if (expr.expression_class == ExpressionClass::SUBQUERY) {
auto &subquery_expr = (SubqueryExpression &)expr;
if (subquery_expr.subquery && subquery_expr.subquery->node) {
ExtractTablesFromQueryNode(*subquery_expr.subquery->node, results, TableContext::Subquery, cte_map);
}
}

ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) {
ExtractTablesFromExpression(child, results, cte_map);
});
}

static void ExtractTablesFromQueryNode(
const duckdb::QueryNode &node,
Expand All @@ -144,7 +174,42 @@ static void ExtractTablesFromQueryNode(
if (select_node.from_table) {
ExtractTablesFromRef(*select_node.from_table, results, context, true, &select_node.cte_map);
}
}

for (const auto &expr : select_node.select_list) {
if (expr) {
ExtractTablesFromExpression(*expr, results, &select_node.cte_map);
}
}

if (select_node.where_clause) {
ExtractTablesFromExpression(*select_node.where_clause, results, &select_node.cte_map);
}

for (const auto &expr : select_node.groups.group_expressions) {
if (expr) {
ExtractTablesFromExpression(*expr, results, &select_node.cte_map);
}
}

if (select_node.having) {
ExtractTablesFromExpression(*select_node.having, results, &select_node.cte_map);
}

if (select_node.qualify) {
ExtractTablesFromExpression(*select_node.qualify, results, &select_node.cte_map);
}

for (const auto &modifier : select_node.modifiers) {
if (modifier->type == ResultModifierType::ORDER_MODIFIER) {
auto &order_modifier = (OrderModifier &)*modifier;
for (const auto &order : order_modifier.orders) {
if (order.expression) {
ExtractTablesFromExpression(*order.expression, results, &select_node.cte_map);
}
}
}
}
}
// additional step necessary for duckdb v1.4.0: unwrap CTE node
else if (node.type == QueryNodeType::CTE_NODE) {
auto &cte_node = (CTENode &)node;
Expand Down
20 changes: 19 additions & 1 deletion test/sql/parse_tools/scalar_functions/parse_tables.test
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,22 @@ select parse_tables('SELECT 1;');
query I
SELECT parse_tables('SELECT * FROM WHERE');
----
[]
[]

# WHERE clause subquery
query I
SELECT parse_tables('SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)');
----
[{'schema': main, 'table': users, 'context': from}, {'schema': main, 'table': orders, 'context': subquery}]

# nested subqueries
query I
SELECT parse_tables('SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE z IN (SELECT w FROM c))');
----
[{'schema': main, 'table': a, 'context': from}, {'schema': main, 'table': b, 'context': subquery}, {'schema': main, 'table': c, 'context': subquery}]

# multiple contexts with subqueries
query I
SELECT parse_tables('SELECT (SELECT max(x) FROM b) FROM a WHERE y IN (SELECT z FROM c)');
----
[{'schema': main, 'table': a, 'context': from}, {'schema': main, 'table': b, 'context': subquery}, {'schema': main, 'table': c, 'context': subquery}]
53 changes: 52 additions & 1 deletion test/sql/parse_tools/table_functions/parse_tables.test
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,55 @@ SELECT * FROM parse_tables('SELECT 1;');
# malformed SQL should not error
query III
SELECT * FROM parse_tables('SELECT * FROM WHERE');
----
----

# WHERE clause subquery
query III
SELECT * FROM parse_tables('SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)');
----
main users from
main orders subquery

# WHERE clause with schema-qualified subquery
query III
SELECT * FROM parse_tables('SELECT * FROM schema1.users WHERE id IN (SELECT user_id FROM schema2.orders)');
----
schema1 users from
schema2 orders subquery

# nested subqueries in WHERE
query III
SELECT * FROM parse_tables('SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE z IN (SELECT w FROM c))');
----
main a from
main b subquery
main c subquery

# SELECT list subquery (scalar subquery)
query III
SELECT * FROM parse_tables('SELECT id, (SELECT max(price) FROM products) AS max_price FROM orders');
----
main orders from
main products subquery

# HAVING clause subquery
query III
SELECT * FROM parse_tables('SELECT dept FROM employees GROUP BY dept HAVING count(*) > (SELECT avg(cnt) FROM dept_sizes)');
----
main employees from
main dept_sizes subquery

# JOIN condition subquery
query III
SELECT * FROM parse_tables('SELECT * FROM a JOIN b ON a.id = b.id AND b.status IN (SELECT code FROM statuses)');
----
main a from
main b join_right
main statuses subquery

# ORDER BY clause subquery
query III
SELECT * FROM parse_tables('SELECT * FROM users ORDER BY (SELECT max(score) FROM rankings WHERE user_id = users.id)');
----
main users from
main rankings subquery