diff --git a/src/parse_tables.cpp b/src/parse_tables.cpp index d902364..f7d90c3 100644 --- a/src/parse_tables.cpp +++ b/src/parse_tables.cpp @@ -11,6 +11,9 @@ #include "duckdb/parser/tableref/joinref.hpp" #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/parser/expression/subquery_expression.hpp" +#include "duckdb/parser/parsed_expression_iterator.hpp" +#include "duckdb/parser/result_modifier.hpp" namespace duckdb { @@ -73,6 +76,12 @@ static unique_ptr ParseTablesInit(ClientContext &conte return make_uniq(); } +static void ExtractTablesFromExpression( + const duckdb::ParsedExpression &expr, + std::vector &results, + const duckdb::CommonTableExpressionMap *cte_map = nullptr +); + static void ExtractTablesFromRef( const duckdb::TableRef &ref, std::vector &results, @@ -89,7 +98,7 @@ static void ExtractTablesFromRef( if (cte_map && cte_map->map.find(base.table_name) != cte_map->map.end()) { context_label = TableContext::FromCTE; - } else if (is_top_level) { + } else if (is_top_level && context != TableContext::Subquery) { context_label = TableContext::From; } @@ -104,12 +113,15 @@ static void ExtractTablesFromRef( auto &join = (JoinRef &)ref; ExtractTablesFromRef(*join.left, results, TableContext::JoinLeft, is_top_level, cte_map); ExtractTablesFromRef(*join.right, results, TableContext::JoinRight, false, cte_map); + if (join.condition) { + ExtractTablesFromExpression(*join.condition, results, cte_map); + } break; } case TableReferenceType::SUBQUERY: { auto &subquery = (SubqueryRef &)ref; if (subquery.subquery && subquery.subquery->node) { - ExtractTablesFromQueryNode(*subquery.subquery->node, results, TableContext::Subquery, cte_map); + ExtractTablesFromQueryNode(*subquery.subquery->node, results, TableContext::From, cte_map); } break; } @@ -118,6 +130,24 @@ static void ExtractTablesFromRef( } } +static void ExtractTablesFromExpression( + const duckdb::ParsedExpression &expr, + std::vector &results, + const duckdb::CommonTableExpressionMap *cte_map +) { + using namespace duckdb; + + if (expr.expression_class == ExpressionClass::SUBQUERY) { + auto &subquery_expr = (SubqueryExpression &)expr; + if (subquery_expr.subquery && subquery_expr.subquery->node) { + ExtractTablesFromQueryNode(*subquery_expr.subquery->node, results, TableContext::Subquery, cte_map); + } + } + + ParsedExpressionIterator::EnumerateChildren(expr, [&](const ParsedExpression &child) { + ExtractTablesFromExpression(child, results, cte_map); + }); +} static void ExtractTablesFromQueryNode( const duckdb::QueryNode &node, @@ -144,7 +174,42 @@ static void ExtractTablesFromQueryNode( if (select_node.from_table) { ExtractTablesFromRef(*select_node.from_table, results, context, true, &select_node.cte_map); } - } + + for (const auto &expr : select_node.select_list) { + if (expr) { + ExtractTablesFromExpression(*expr, results, &select_node.cte_map); + } + } + + if (select_node.where_clause) { + ExtractTablesFromExpression(*select_node.where_clause, results, &select_node.cte_map); + } + + for (const auto &expr : select_node.groups.group_expressions) { + if (expr) { + ExtractTablesFromExpression(*expr, results, &select_node.cte_map); + } + } + + if (select_node.having) { + ExtractTablesFromExpression(*select_node.having, results, &select_node.cte_map); + } + + if (select_node.qualify) { + ExtractTablesFromExpression(*select_node.qualify, results, &select_node.cte_map); + } + + for (const auto &modifier : select_node.modifiers) { + if (modifier->type == ResultModifierType::ORDER_MODIFIER) { + auto &order_modifier = (OrderModifier &)*modifier; + for (const auto &order : order_modifier.orders) { + if (order.expression) { + ExtractTablesFromExpression(*order.expression, results, &select_node.cte_map); + } + } + } + } + } // additional step necessary for duckdb v1.4.0: unwrap CTE node else if (node.type == QueryNodeType::CTE_NODE) { auto &cte_node = (CTENode &)node; diff --git a/test/sql/parse_tools/scalar_functions/parse_tables.test b/test/sql/parse_tools/scalar_functions/parse_tables.test index e825611..f09ee08 100644 --- a/test/sql/parse_tools/scalar_functions/parse_tables.test +++ b/test/sql/parse_tools/scalar_functions/parse_tables.test @@ -73,4 +73,22 @@ select parse_tables('SELECT 1;'); query I SELECT parse_tables('SELECT * FROM WHERE'); ---- -[] \ No newline at end of file +[] + +# WHERE clause subquery +query I +SELECT parse_tables('SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)'); +---- +[{'schema': main, 'table': users, 'context': from}, {'schema': main, 'table': orders, 'context': subquery}] + +# nested subqueries +query I +SELECT parse_tables('SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE z IN (SELECT w FROM c))'); +---- +[{'schema': main, 'table': a, 'context': from}, {'schema': main, 'table': b, 'context': subquery}, {'schema': main, 'table': c, 'context': subquery}] + +# multiple contexts with subqueries +query I +SELECT parse_tables('SELECT (SELECT max(x) FROM b) FROM a WHERE y IN (SELECT z FROM c)'); +---- +[{'schema': main, 'table': a, 'context': from}, {'schema': main, 'table': b, 'context': subquery}, {'schema': main, 'table': c, 'context': subquery}] \ No newline at end of file diff --git a/test/sql/parse_tools/table_functions/parse_tables.test b/test/sql/parse_tools/table_functions/parse_tables.test index 338520b..4f6ba97 100644 --- a/test/sql/parse_tools/table_functions/parse_tables.test +++ b/test/sql/parse_tools/table_functions/parse_tables.test @@ -113,4 +113,55 @@ SELECT * FROM parse_tables('SELECT 1;'); # malformed SQL should not error query III SELECT * FROM parse_tables('SELECT * FROM WHERE'); ----- \ No newline at end of file +---- + +# WHERE clause subquery +query III +SELECT * FROM parse_tables('SELECT * FROM users WHERE id IN (SELECT user_id FROM orders)'); +---- +main users from +main orders subquery + +# WHERE clause with schema-qualified subquery +query III +SELECT * FROM parse_tables('SELECT * FROM schema1.users WHERE id IN (SELECT user_id FROM schema2.orders)'); +---- +schema1 users from +schema2 orders subquery + +# nested subqueries in WHERE +query III +SELECT * FROM parse_tables('SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE z IN (SELECT w FROM c))'); +---- +main a from +main b subquery +main c subquery + +# SELECT list subquery (scalar subquery) +query III +SELECT * FROM parse_tables('SELECT id, (SELECT max(price) FROM products) AS max_price FROM orders'); +---- +main orders from +main products subquery + +# HAVING clause subquery +query III +SELECT * FROM parse_tables('SELECT dept FROM employees GROUP BY dept HAVING count(*) > (SELECT avg(cnt) FROM dept_sizes)'); +---- +main employees from +main dept_sizes subquery + +# JOIN condition subquery +query III +SELECT * FROM parse_tables('SELECT * FROM a JOIN b ON a.id = b.id AND b.status IN (SELECT code FROM statuses)'); +---- +main a from +main b join_right +main statuses subquery + +# ORDER BY clause subquery +query III +SELECT * FROM parse_tables('SELECT * FROM users ORDER BY (SELECT max(score) FROM rankings WHERE user_id = users.id)'); +---- +main users from +main rankings subquery \ No newline at end of file