Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
173ec59
feat: add structured UDT literal support with dual encoding
benbellick Nov 12, 2025
f0e13ef
test: add test to ensure struct-UDTs work with type parameters
benbellick Nov 21, 2025
f5b6341
test: improve UDT literal test by making opaque-ness explicit
benbellick Nov 21, 2025
8c54706
refactor: rename UserDefined{kind} to UserDefined{kind}Literal
benbellick Nov 24, 2025
36c81f0
feat: first class support for TypeParameters
benbellick Nov 24, 2025
680ed29
docs: update docstring to explain use of Parameter
benbellick Nov 24, 2025
6ac7aab
test: add test showing all Type.Parameters preserved
benbellick Nov 24, 2025
b2063d0
fix: correctly handle type conversion with paramterized types
benbellick Nov 24, 2025
ce7985f
refactor: revert to simpler binary construction for any rep UDT in ca…
benbellick Nov 24, 2025
e8cb862
fix: ensure nullability preserved in calcite roundtrip
benbellick Nov 25, 2025
77f0bd7
chore: addressing some PR comments
benbellick Nov 26, 2025
4b4f5ee
chore: addressing more PR comments
benbellick Nov 26, 2025
3732cf9
chore: simplify documentation
benbellick Nov 26, 2025
a216f3e
Merge branch 'main' into benbellick/handle-structured-udt2
benbellick Dec 4, 2025
f61f68c
Merge branch 'main' into benbellick/handle-structured-udt2
benbellick Dec 16, 2025
02f7920
Merge branch 'main' into benbellick/handle-structured-udt2
benbellick Jan 5, 2026
d009af4
Apply suggestions from code review
benbellick Jan 5, 2026
76cd0f0
add interface for UserDefinedAnyValue
benbellick Jan 5, 2026
a60e0b7
Revert "add interface for UserDefinedAnyValue"
benbellick Jan 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.NestedStruct expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
85 changes: 79 additions & 6 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -693,21 +693,94 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
* Base interface for user-defined literals.
*
* <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
*
* <ul>
* <li>As {@code google.protobuf.Any} - see {@link UserDefinedAnyLiteral}
* <li>As {@code Literal.Struct} - see {@link UserDefinedStructLiteral}
* </ul>
*/
interface UserDefinedLiteral extends Literal {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Release Notes
We should call out that we don't construct UserDefinedLiterals anymore.

String urn();

String name();

List<io.substrait.type.Type.Parameter> typeParameters();
}

/**
* User-defined literal with value encoded as {@link com.google.protobuf.Any}.
*
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
*/
@Value.Immutable
abstract class UserDefinedLiteral implements Literal {
public abstract ByteString value();
abstract class UserDefinedAnyLiteral implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public abstract List<io.substrait.type.Type.Parameter> typeParameters();

public abstract com.google.protobuf.Any value();

@Override
public Type.UserDefined getType() {
return Type.UserDefined.builder()
.nullable(nullable())
.urn(urn())
.name(name())
.typeParameters(typeParameters())
.build();
}

public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() {
return ImmutableExpression.UserDefinedAnyLiteral.builder();
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

/**
* User-defined literal with value encoded as {@link
* io.substrait.proto.Expression.Literal.Struct}.
*
* <p>This encoding uses a structured list of fields to represent the literal value.
*/
@Value.Immutable
abstract class UserDefinedStructLiteral implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public Type getType() {
return Type.withNullability(nullable()).userDefined(urn(), name());
public abstract List<io.substrait.type.Type.Parameter> typeParameters();

public abstract List<Literal> fields();

@Override
public Type.UserDefined getType() {
return Type.UserDefined.builder()
.nullable(nullable())
.urn(urn())
.name(name())
.typeParameters(typeParameters())
.build();
}

public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
return ImmutableExpression.UserDefinedLiteral.builder();
public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() {
return ImmutableExpression.UserDefinedStructLiteral.builder();
}

@Override
Expand Down
46 changes: 42 additions & 4 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,51 @@ public static Expression.NestedStruct nestedStruct(boolean nullable, Expression.
return Expression.NestedStruct.builder().nullable(nullable).addFields(fields).build();
}

public static Expression.UserDefinedLiteral userDefinedLiteral(
boolean nullable, String urn, String name, Any value) {
return Expression.UserDefinedLiteral.builder()
/**
* Create a UserDefinedAnyLiteral with google.protobuf.Any representation.
*
* @param nullable whether the literal is nullable
* @param urn the URN of the user-defined type
* @param name the name of the user-defined type
* @param typeParameters the type parameters for the user-defined type (can be an empty list)
* @param value the value, encoded as google.protobuf.Any
*/
public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny(
boolean nullable,
String urn,
String name,
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
Any value) {
return Expression.UserDefinedAnyLiteral.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.addAllTypeParameters(typeParameters)
.value(value)
.build();
}

/**
* Create a UserDefinedStructLiteral with Struct representation.
*
* @param nullable whether the literal is nullable
* @param urn the URN of the user-defined type
* @param name the name of the user-defined type
* @param typeParameters the type parameters for the user-defined type (can be an empty list)
* @param fields the fields, as a list of Literal values
*/
public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct(
boolean nullable,
String urn,
String name,
java.util.List<io.substrait.type.Type.Parameter> typeParameters,
java.util.List<Expression.Literal> fields) {
return Expression.UserDefinedStructLiteral.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.value(value.toByteString())
.addAllTypeParameters(typeParameters)
.addAllFields(fields)
.build();
}

Expand Down
14 changes: 12 additions & 2 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,24 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
R visit(Expression.NestedStruct expr, C context) throws E;

/**
* Visit a user-defined literal.
* Visit a user-defined any literal.
*
* @param expr the user-defined literal
* @param context visitation context
* @return visit result
* @throws E on visit failure
*/
R visit(Expression.UserDefinedLiteral expr, C context) throws E;
R visit(Expression.UserDefinedAnyLiteral expr, C context) throws E;

/**
* Visit a user-defined struct literal.
*
* @param expr the user-defined literal
* @param context visitation context
* @return visit result
* @throws E on visit failure
*/
R visit(Expression.UserDefinedStructLiteral expr, C context) throws E;

/**
* Visit a switch expression.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.substrait.expression.proto;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -377,21 +375,51 @@ public Expression visit(

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
try {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exception doesn't happen anymore because we don't parse the Any here. Instead, we have a reference to the pre-parsed proto directly.

bldr.setNullable(expr.nullable())
.setUserDefined(
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.setValue(Any.parseFrom(expr.value())))
.build();
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException(e);
}
Expression.Literal.UserDefined.Builder userDefinedBuilder =
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.addAllTypeParameters(
expr.typeParameters().stream()
.map(typeProtoConverter::toProto)
.collect(java.util.stream.Collectors.toList()))
.setValue(expr.value());

bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
});
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedStructLiteral expr,
EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
Expression.Literal.Struct structLiteral =
Expression.Literal.Struct.newBuilder()
.addAllFields(
expr.fields().stream()
.map(this::toLiteral)
.collect(java.util.stream.Collectors.toList()))
.build();

Expression.Literal.UserDefined.Builder userDefinedBuilder =
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.addAllTypeParameters(
expr.typeParameters().stream()
.map(typeProtoConverter::toProto)
.collect(java.util.stream.Collectors.toList()))
.setStruct(structLiteral);

bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
{
io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
literal.getUserDefined();

SimpleExtension.Type type =
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
return ExpressionCreator.userDefinedLiteral(
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
String urn = type.urn();
String name = type.name();
List<io.substrait.type.Type.Parameter> typeParameters =
userDefinedLiteral.getTypeParametersList().stream()
.map(protoTypeConverter::from)
.collect(Collectors.toList());

switch (userDefinedLiteral.getValCase()) {
case VALUE:
return ExpressionCreator.userDefinedLiteralAny(
literal.getNullable(), urn, name, typeParameters, userDefinedLiteral.getValue());
case STRUCT:
return ExpressionCreator.userDefinedLiteralStruct(
literal.getNullable(),
urn,
name,
typeParameters,
userDefinedLiteral.getStruct().getFieldsList().stream()
.map(this::from)
.collect(Collectors.toList()));
case VAL_NOT_SET:
throw new IllegalStateException(
"UserDefined literal has no value (neither 'value' nor 'struct' is set)");
default:
throw new IllegalStateException(
"Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
}
}
default:
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog {
"extension:io.substrait:functions_rounding_decimal";
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types";

public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
loadDefaultCollection();
Expand All @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
.map(c -> String.format("/functions_%s.yaml", c))
.collect(Collectors.toList());

defaultFiles.add("/extension_types.yaml");

return SimpleExtension.load(defaultFiles);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ public Optional<Expression> visit(Expression.NestedStruct expr, EmptyVisitationC

@Override
public Optional<Expression> visit(
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {
Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(
Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E {
return visitLiteral(expr);
}

Expand Down
Loading