From 8c103bfb73193ec39ba4c7a95d540e63ed5801e8 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 15 Dec 2025 13:18:16 +0100 Subject: [PATCH] Limit string nodes in Polars expressions to constant expressions --- _duckdb-stubs/__init__.pyi | 9 +++------ duckdb/polars_io.py | 5 +++-- tests/fast/arrow/test_polars.py | 17 +++++++++++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi index 75389791..20e0eca0 100644 --- a/_duckdb-stubs/__init__.pyi +++ b/_duckdb-stubs/__init__.pyi @@ -1036,15 +1036,12 @@ class token_type: def CaseExpression(condition: Expression, value: Expression) -> Expression: ... def CoalesceOperator(*args: Expression) -> Expression: ... def ColumnExpression(*args: str) -> Expression: ... -def ConstantExpression(value: Expression | str) -> Expression: ... +def ConstantExpression(value: pytyping.Any) -> Expression: ... def DefaultExpression() -> Expression: ... def FunctionExpression(function_name: str, *args: Expression) -> Expression: ... -def LambdaExpression(lhs: Expression | str | tuple[str], rhs: Expression) -> Expression: ... +def LambdaExpression(lhs: pytyping.Any, rhs: Expression) -> Expression: ... def SQLExpression(expression: str) -> Expression: ... -@pytyping.overload -def StarExpression(*, exclude: Expression | str | tuple[str]) -> Expression: ... -@pytyping.overload -def StarExpression() -> Expression: ... +def StarExpression(*, exclude: pytyping.Any = None) -> Expression: ... def aggregate( df: pandas.DataFrame, aggr_expr: Expression | list[Expression] | str | list[str], diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 2c075baf..a7ed84ff 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -236,8 +236,9 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: # String type if dtype == "String" or dtype == "StringOwned": # Some new formats may store directly under StringOwned - string_val: object | None = value.get("StringOwned", value.get("String", None)) - return f"'{string_val}'" + string_val = value.get("StringOwned", value.get("String", None)) + # the string must be a string constant + return str(duckdb.ConstantExpression(string_val)) msg = f"Unsupported scalar type {dtype!s}, with value {value}" raise NotImplementedError(msg) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 822022a3..0eba5eeb 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -639,6 +639,23 @@ def test_polars_lazy_many_batches(self, duckdb_cursor): assert res == correct + @pytest.mark.parametrize( + "input_str", ["A'dam", 'answer = "42"', "'; DROP TABLE users; --", "line1\nline2\ttab", "", None] + ) + def test_expr_with_sql_in_string_node(self, input_str): + """SQL in a String node in an expression is treated as a constant expression.""" + expected = str(duckdb.ConstantExpression(input_str)) + + # Regular string + tree = {"Scalar": {"String": input_str}} + result = _pl_tree_to_sql(tree) + assert result == expected + + # StringOwned + tree = {"Scalar": {"StringOwned": input_str}} + result = _pl_tree_to_sql(tree) + assert result == expected + def test_invalid_expr_json(self): bad_key_expr = """ {