Skip to content

Commit 9c65059

Browse files
committed
Add tests and fix models
1 parent 3e35bfd commit 9c65059

File tree

3 files changed

+168
-9
lines changed

3 files changed

+168
-9
lines changed

pyiceberg/catalog/rest/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def _plan_table_scan(self, identifier: str | Identifier, request: PlanTableScanR
424424
self._check_endpoint(Capability.V1_SUBMIT_TABLE_SCAN_PLAN)
425425
response = self._session.post(
426426
self.url(Endpoints.plan_table_scan, prefixed=True, **self._split_identifier_for_path(identifier)),
427-
json=request.model_dump(by_alias=True, exclude_none=True),
427+
data=request.model_dump_json(by_alias=True, exclude_none=True).encode(UTF8),
428428
)
429429
try:
430430
response.raise_for_status()
@@ -451,7 +451,7 @@ def _fetch_scan_tasks(self, identifier: str | Identifier, plan_task: str) -> Sca
451451
request = FetchScanTasksRequest(plan_task=plan_task)
452452
response = self._session.post(
453453
self.url(Endpoints.fetch_scan_tasks, prefixed=True, **self._split_identifier_for_path(identifier)),
454-
json=request.model_dump(by_alias=True),
454+
data=request.model_dump_json(by_alias=True).encode(UTF8),
455455
)
456456
try:
457457
response.raise_for_status()

pyiceberg/catalog/rest/scan_planning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pydantic import Field, model_validator
2525

2626
from pyiceberg.catalog.rest.response import ErrorResponseMessage
27-
from pyiceberg.expressions import BooleanExpression
27+
from pyiceberg.expressions import BooleanExpression, SerializableBooleanExpression
2828
from pyiceberg.manifest import DataFileContent, FileFormat
2929
from pyiceberg.typedef import IcebergBaseModel
3030

@@ -192,7 +192,7 @@ class PlanTableScanRequest(IcebergBaseModel):
192192

193193
snapshot_id: int | None = Field(alias="snapshot-id", default=None)
194194
select: list[str] | None = Field(default=None)
195-
filter: BooleanExpression | None = Field(default=None)
195+
filter: SerializableBooleanExpression | None = Field(default=None)
196196
case_sensitive: bool = Field(alias="case-sensitive", default=True)
197197
use_snapshot_schema: bool = Field(alias="use-snapshot-schema", default=False)
198198
start_snapshot_id: int | None = Field(alias="start-snapshot-id", default=None)

tests/integration/test_rest_scan_planning_integration.py

Lines changed: 164 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,61 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint:disable=redefined-outer-name
18+
from datetime import date, datetime, time, timedelta, timezone
19+
from decimal import Decimal
1820
from typing import Any
21+
from uuid import uuid4
1922

2023
import pyarrow as pa
2124
import pytest
2225
from pyspark.sql import SparkSession
2326

2427
from pyiceberg.catalog import Catalog, load_catalog
2528
from pyiceberg.catalog.rest import RestCatalog
26-
from pyiceberg.expressions import And, BooleanExpression, EqualTo, GreaterThan, LessThan
29+
from pyiceberg.exceptions import NoSuchTableError
30+
from pyiceberg.expressions import (
31+
And,
32+
BooleanExpression,
33+
EqualTo,
34+
GreaterThan,
35+
GreaterThanOrEqual,
36+
In,
37+
IsNull,
38+
LessThan,
39+
LessThanOrEqual,
40+
Not,
41+
NotEqualTo,
42+
NotIn,
43+
NotNull,
44+
Or,
45+
StartsWith,
46+
)
2747
from pyiceberg.partitioning import PartitionField, PartitionSpec
2848
from pyiceberg.schema import Schema
2949
from pyiceberg.table import ALWAYS_TRUE, Table
30-
from pyiceberg.transforms import IdentityTransform
31-
from pyiceberg.types import LongType, NestedField, StringType
50+
from pyiceberg.transforms import (
51+
IdentityTransform,
52+
)
53+
from pyiceberg.types import (
54+
BinaryType,
55+
BooleanType,
56+
DateType,
57+
DecimalType,
58+
DoubleType,
59+
FixedType,
60+
LongType,
61+
NestedField,
62+
StringType,
63+
TimestampType,
64+
TimestamptzType,
65+
TimeType,
66+
UUIDType,
67+
)
3268

3369

3470
@pytest.fixture(scope="session")
3571
def scan_catalog() -> Catalog:
36-
return load_catalog(
72+
catalog = load_catalog(
3773
"local",
3874
**{
3975
"type": "rest",
@@ -44,13 +80,15 @@ def scan_catalog() -> Catalog:
4480
"rest-scan-planning-enabled": "true",
4581
},
4682
)
83+
catalog.create_namespace_if_not_exists("default")
84+
return catalog
4785

4886

4987
def recreate_table(catalog: Catalog, identifier: str, **kwargs: Any) -> Table:
5088
"""Drop table if exists and create a new one."""
5189
try:
5290
catalog.drop_table(identifier)
53-
except Exception:
91+
except NoSuchTableError:
5492
pass
5593
return catalog.create_table(identifier, **kwargs)
5694

@@ -186,3 +224,124 @@ def test_rest_scan_with_partitioning(scan_catalog: RestCatalog, session_catalog:
186224
)
187225
finally:
188226
scan_catalog.drop_table(identifier)
227+
228+
229+
@pytest.mark.integration
230+
def test_rest_scan_primitive_types(scan_catalog: RestCatalog, session_catalog: Catalog) -> None:
231+
identifier = "default.test_primitives"
232+
233+
schema = Schema(
234+
NestedField(1, "bool_col", BooleanType()),
235+
NestedField(2, "long_col", LongType()),
236+
NestedField(3, "double_col", DoubleType()),
237+
NestedField(4, "decimal_col", DecimalType(10, 2)),
238+
NestedField(5, "string_col", StringType()),
239+
NestedField(6, "date_col", DateType()),
240+
NestedField(7, "time_col", TimeType()),
241+
NestedField(8, "timestamp_col", TimestampType()),
242+
NestedField(9, "timestamptz_col", TimestamptzType()),
243+
NestedField(10, "uuid_col", UUIDType()),
244+
NestedField(11, "fixed_col", FixedType(16)),
245+
NestedField(12, "binary_col", BinaryType()),
246+
)
247+
248+
table = recreate_table(scan_catalog, identifier, schema=schema)
249+
250+
now = datetime.now()
251+
now_tz = datetime.now(tz=timezone.utc)
252+
today = date.today()
253+
uuid1, uuid2, uuid3 = uuid4(), uuid4(), uuid4()
254+
255+
arrow_table = pa.Table.from_pydict(
256+
{
257+
"bool_col": [True, False, True],
258+
"long_col": [100, 200, 300],
259+
"double_col": [1.11, 2.22, 3.33],
260+
"decimal_col": [Decimal("1.23"), Decimal("4.56"), Decimal("7.89")],
261+
"string_col": ["a", "b", "c"],
262+
"date_col": [today, today - timedelta(days=1), today - timedelta(days=2)],
263+
"time_col": [time(8, 30, 0), time(12, 0, 0), time(18, 45, 30)],
264+
"timestamp_col": [now, now - timedelta(hours=1), now - timedelta(hours=2)],
265+
"timestamptz_col": [now_tz, now_tz - timedelta(hours=1), now_tz - timedelta(hours=2)],
266+
"uuid_col": [uuid1.bytes, uuid2.bytes, uuid3.bytes],
267+
"fixed_col": [b"0123456789abcdef", b"abcdef0123456789", b"fedcba9876543210"],
268+
"binary_col": [b"hello", b"world", b"test"],
269+
},
270+
schema=schema.as_arrow(),
271+
)
272+
table.append(arrow_table)
273+
274+
try:
275+
_assert_remote_scan_matches_local_scan(table, session_catalog, identifier)
276+
finally:
277+
scan_catalog.drop_table(identifier)
278+
279+
280+
@pytest.mark.integration
281+
def test_rest_scan_complex_filters(scan_catalog: RestCatalog, session_catalog: Catalog) -> None:
282+
identifier = "default.test_complex_filters"
283+
284+
schema = Schema(
285+
NestedField(1, "id", LongType()),
286+
NestedField(2, "name", StringType()),
287+
NestedField(3, "value", LongType()),
288+
NestedField(4, "optional", StringType(), required=False),
289+
)
290+
291+
table = recreate_table(scan_catalog, identifier, schema=schema)
292+
293+
table.append(
294+
pa.Table.from_pydict(
295+
{
296+
"id": list(range(1, 21)),
297+
"name": [f"item_{i}" for i in range(1, 21)],
298+
"value": [i * 100 for i in range(1, 21)],
299+
"optional": [None if i % 3 == 0 else f"opt_{i}" for i in range(1, 21)],
300+
}
301+
)
302+
)
303+
304+
try:
305+
filters = [
306+
EqualTo("id", 10),
307+
NotEqualTo("id", 10),
308+
GreaterThan("value", 1000),
309+
GreaterThanOrEqual("value", 1000),
310+
LessThan("value", 500),
311+
LessThanOrEqual("value", 500),
312+
In("id", [1, 5, 10, 15]),
313+
NotIn("id", [1, 5, 10, 15]),
314+
IsNull("optional"),
315+
NotNull("optional"),
316+
StartsWith("name", "item_1"),
317+
And(GreaterThan("id", 5), LessThan("id", 15)),
318+
Or(EqualTo("id", 1), EqualTo("id", 20)),
319+
Not(EqualTo("id", 10)),
320+
]
321+
322+
for filter_expr in filters:
323+
_assert_remote_scan_matches_local_scan(table, session_catalog, identifier, row_filter=filter_expr)
324+
finally:
325+
scan_catalog.drop_table(identifier)
326+
327+
328+
@pytest.mark.integration
329+
def test_rest_scan_empty_table(scan_catalog: RestCatalog, session_catalog: Catalog) -> None:
330+
identifier = "default.test_empty_table"
331+
332+
schema = Schema(
333+
NestedField(1, "id", LongType()),
334+
NestedField(2, "data", StringType()),
335+
)
336+
337+
table = recreate_table(scan_catalog, identifier, schema=schema)
338+
339+
try:
340+
rest_tasks = list(table.scan().plan_files())
341+
local_table = session_catalog.load_table(identifier)
342+
local_tasks = list(local_table.scan().plan_files())
343+
344+
assert len(rest_tasks) == 0
345+
assert len(local_tasks) == 0
346+
finally:
347+
scan_catalog.drop_table(identifier)

0 commit comments

Comments
 (0)