Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions tests/fast/arrow/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

import pytest
from packaging.version import parse as parse_version

import duckdb

Expand All @@ -11,6 +12,8 @@

from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402

# True when the installed Polars is older than 1.36.0. The skipif markers in
# this file use it to select the pre- or post-1.36.0 variant of a test.
pl_pre_1_36_0 = parse_version(pl.__version__) < parse_version("1.36.0")


def valid_filter(filter):
sql_expression = _predicate_to_expression(filter)
Expand Down Expand Up @@ -86,10 +89,18 @@ def test_polars_from_json(self, duckdb_cursor):
res = duckdb_cursor.read_json(string).pl()
assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}"

@pytest.mark.skipif(
not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version"
)
def test_polars_from_json_error(self, duckdb_cursor):
@pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 doesn't support arrow extensions")
def test_polars_from_json_post_pl_1_36_0(self, duckdb_cursor):
    """Read nested JSON into Polars with lossless Arrow conversion enabled.

    Skipped on Polars older than 1.36.0. The ``arrow.json`` extension type is
    registered with Polars before the DuckDB relation is converted.
    """
    from io import StringIO

    expected = "{'content': {'ManagedSystem': {'test': None}}}"
    payload = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""")

    duckdb_cursor.sql("set arrow_lossless_conversion=true")
    pl.register_extension_type("arrow.json", pl.Extension)

    frame = duckdb_cursor.read_json(payload).pl()
    first_entry = frame["entry"][0][0]
    assert str(first_entry) == expected

@pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 supports arrow extensions")
def test_polars_from_json_pre_pl_1_36_0(self, duckdb_cursor):
from io import StringIO

duckdb_cursor.sql("set arrow_lossless_conversion=true")
Expand Down Expand Up @@ -426,13 +437,34 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor):
lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2
)

# Validate Filter
@pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates")
def test_polars_predicate_to_expression_post_1_36_0(self):
    """Run ``valid_filter`` over timestamp predicates on Polars >= 1.36.0.

    Covers comparisons, null checks, and AND/OR combinations over the
    columns ``a``, ``b`` and ``c``.
    """
    early = datetime.datetime(2008, 1, 1, 0, 0, 1)
    middle = datetime.datetime(2010, 1, 1, 10, 0, 1)
    late = datetime.datetime(2020, 3, 1, 10, 0, 1)

    # None of these produce casts in polars >= 1.36.0
    predicates = (
        pl.col("a") == early,
        pl.col("a") > early,
        pl.col("a") >= middle,
        pl.col("a") < middle,
        pl.col("a") <= middle,
        pl.col("a").is_null(),
        pl.col("a").is_not_null(),
        (pl.col("a") == middle) & (pl.col("b") == early),
        (pl.col("a") == late) & (pl.col("b") == middle) & (pl.col("c") == late),
        (pl.col("a") == late) | (pl.col("b") == early),
    )
    for predicate in predicates:
        valid_filter(predicate)

@pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts")
def test_polars_predicate_to_expression_pre_1_36_0(self):
ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1)
ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1)
ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1)
# Validate filters
invalid_filter(pl.col("a") == ts_2008)
invalid_filter(pl.col("a") > ts_2008)
invalid_filter(pl.col("a") >= ts_2010)
invalid_filter(pl.col("a") < ts_2010)
invalid_filter(pl.col("a") <= ts_2010)
# These two are actually valid because they don't produce a cast
valid_filter(pl.col("a").is_null())
valid_filter(pl.col("a").is_not_null())
invalid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008))
Expand Down
Loading