Skip to content

Commit c909f06

Browse files
authored
fix: add nullability check for NamedStruct in builder (#104)
ensures that the nullability of a NamedStruct is NULLABILITY_REQUIRED --------- Signed-off-by: MBWhite <whitemat@uk.ibm.com>
1 parent 42e979b commit c909f06

File tree

15 files changed

+101
-13
lines changed

15 files changed

+101
-13
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,8 @@ codegen-extensions:
1414
lint:
1515
uvx ruff@0.11.11 check
1616

17+
lint_fix:
18+
uvx ruff@0.11.11 check --fix
19+
1720
format:
1821
uvx ruff@0.11.11 format

src/substrait/builders/plan.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def _merge_extensions(*objs):
3434
def read_named_table(
3535
names: Union[str, Iterable[str]], named_struct: stt.NamedStruct
3636
) -> UnboundPlan:
37+
if named_struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE:
38+
raise Exception("NamedStruct must not contain a nullable struct")
39+
elif named_struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED:
40+
named_struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED
41+
3742
def resolve(registry: ExtensionRegistry) -> stp.Plan:
3843
_names = [names] if isinstance(names, str) else names
3944

src/substrait/builders/type.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,9 @@ def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type:
257257

258258

259259
def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct:
260+
if struct.struct.nullability is stt.Type.NULLABILITY_NULLABLE:
261+
raise Exception("NamedStruct must not contain a nullable struct")
262+
elif struct.struct.nullability is stt.Type.NULLABILITY_UNSPECIFIED:
263+
struct.struct.nullability = stt.Type.NULLABILITY_REQUIRED
264+
260265
return stt.NamedStruct(names=names, struct=struct.struct)

tests/builders/extended_expression/test_aggregate_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
1313
stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
1414
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
15-
]
15+
],
16+
nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
1617
)
1718

1819
named_struct = stt.NamedStruct(

tests/builders/extended_expression/test_scalar_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_sclar_add():
6868
extensions=[
6969
ste.SimpleExtensionDeclaration(
7070
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
71-
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
71+
extension_uri_reference=1, function_anchor=1, name="test_func:i8"
7272
)
7373
)
7474
],

tests/builders/plan/test_aggregate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
registry = ExtensionRegistry(load_default_extensions=False)
2929
registry.register_extension_dict(yaml.safe_load(content), uri="test_uri")
3030

31-
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
31+
struct = stt.Type.Struct(
32+
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
33+
)
3234

3335
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
3436

tests/builders/plan/test_cross.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77

88
registry = ExtensionRegistry(load_default_extensions=False)
99

10-
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
10+
struct = stt.Type.Struct(
11+
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
12+
)
1113

1214
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
1315

1416
named_struct_2 = stt.NamedStruct(
1517
names=["fk_id", "name"],
16-
struct=stt.Type.Struct(types=[i64(nullable=False), string()]),
18+
struct=stt.Type.Struct(
19+
types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED
20+
),
1721
)
1822

1923

tests/builders/plan/test_fetch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
registry = ExtensionRegistry(load_default_extensions=False)
1010

11-
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
11+
struct = stt.Type.Struct(
12+
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
13+
)
1214

1315
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
1416

tests/builders/plan/test_filter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
registry = ExtensionRegistry(load_default_extensions=False)
1010

11-
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
11+
struct = stt.Type.Struct(
12+
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
13+
)
1214

1315
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
1416

tests/builders/plan/test_join.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@
88

99
registry = ExtensionRegistry(load_default_extensions=False)
1010

11-
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
11+
struct = stt.Type.Struct(
12+
types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
13+
)
1214

1315
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
1416

1517
named_struct_2 = stt.NamedStruct(
1618
names=["fk_id", "name"],
17-
struct=stt.Type.Struct(types=[i64(nullable=False), string()]),
19+
struct=stt.Type.Struct(
20+
types=[i64(nullable=False), string()], nullability=stt.Type.NULLABILITY_REQUIRED
21+
),
1822
)
1923

2024

0 commit comments

Comments
 (0)