diff --git a/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/IcebergWork.java b/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/IcebergWork.java index 1762d4b81a0..796efc0fc70 100644 --- a/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/IcebergWork.java +++ b/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/IcebergWork.java @@ -33,8 +33,10 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.util.Base64; import java.util.Objects; import java.util.StringJoiner; @@ -92,7 +94,8 @@ public IcebergWorkDeserializer() { public IcebergWork deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { JsonNode node = p.getCodec().readTree(p); String scanTaskString = node.get(IcebergWorkSerializer.SCAN_TASK_FIELD).asText(); - try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(Base64.getDecoder().decode(scanTaskString)))) { + try (ObjectInputStream ois = new ScanTaskObjectInputStream( + new ByteArrayInputStream(Base64.getDecoder().decode(scanTaskString)))) { Object scanTask = ois.readObject(); return new IcebergWork((CombinedScanTask) scanTask); } catch (ClassNotFoundException e) { @@ -103,6 +106,35 @@ public IcebergWork deserialize(JsonParser p, DeserializationContext ctxt) throws } } + private static class ScanTaskObjectInputStream extends ObjectInputStream { + + ScanTaskObjectInputStream(InputStream inputStream) throws IOException { + super(inputStream); + } + + @Override + protected Class resolveClass(ObjectStreamClass cls) throws IOException, ClassNotFoundException { + final String className = cls.getName(); + if (isValidPackage(className)) { + return super.resolveClass(cls); + } + final Class resolvedClass = super.resolveClass(cls); + if ((resolvedClass.isArray() && + (resolvedClass.getComponentType().isPrimitive() || + isValidPackage(resolvedClass.getComponentType().getName()))) + || resolvedClass.isPrimitive()) { + return resolvedClass; + } + throw new IOException("Rejected deserialization of unexpected class: " + className); + } + + private boolean isValidPackage(final String className) { + return className.startsWith("org.apache.iceberg.") || + className.startsWith("org.apache.drill.") || + className.startsWith("java."); + } + } + /** * Special serializer for {@link IcebergWork} class that serializes * {@code scanTask} field to byte array string created using {@link java.io.Serializable} diff --git a/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/format/IcebergFormatPluginConfig.java b/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/format/IcebergFormatPluginConfig.java index 9e53803883c..5acd9457026 100644 --- a/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/format/IcebergFormatPluginConfig.java +++ b/contrib/format-iceberg/src/main/java/org/apache/drill/exec/store/iceberg/format/IcebergFormatPluginConfig.java @@ -45,6 +45,8 @@ public class IcebergFormatPluginConfig implements FormatPluginConfig { private final Boolean ignoreResiduals; + private final Boolean allowAnyClassToBeLoaded; + private final Long snapshotId; private final Long snapshotAsOfTime; @@ -60,6 +62,7 @@ public IcebergFormatPluginConfig( this.caseSensitive = builder.caseSensitive; this.includeColumnStats = builder.includeColumnStats; this.ignoreResiduals = builder.ignoreResiduals; + this.allowAnyClassToBeLoaded = builder.allowAnyClassToBeLoaded; this.snapshotId = builder.snapshotId; this.snapshotAsOfTime = builder.snapshotAsOfTime; this.fromSnapshotId = builder.fromSnapshotId; @@ -100,6 +103,10 @@ public Boolean getIgnoreResiduals() { return this.ignoreResiduals; } + public Boolean getAllowAnyClassToBeLoaded() { + return this.allowAnyClassToBeLoaded; + } + public Long getSnapshotId() { return this.snapshotId; } @@ -130,6 +137,7 @@ public boolean equals(Object o) { && Objects.equals(caseSensitive, that.caseSensitive) && Objects.equals(includeColumnStats, that.includeColumnStats) && Objects.equals(ignoreResiduals, that.ignoreResiduals) + && Objects.equals(allowAnyClassToBeLoaded, that.allowAnyClassToBeLoaded) && Objects.equals(snapshotId, that.snapshotId) && Objects.equals(snapshotAsOfTime, that.snapshotAsOfTime) && Objects.equals(fromSnapshotId, that.fromSnapshotId) @@ -138,8 +146,8 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(properties, snapshot, caseSensitive, includeColumnStats, - ignoreResiduals, snapshotId, snapshotAsOfTime, fromSnapshotId, toSnapshotId); + return Objects.hash(properties, snapshot, caseSensitive, includeColumnStats, ignoreResiduals, + allowAnyClassToBeLoaded, snapshotId, snapshotAsOfTime, fromSnapshotId, toSnapshotId); } @JsonPOJOBuilder(withPrefix = "") @@ -152,6 +160,8 @@ public static class IcebergFormatPluginConfigBuilder { private Boolean ignoreResiduals; + private Boolean allowAnyClassToBeLoaded; + private Long snapshotId; private Long snapshotAsOfTime; @@ -180,6 +190,11 @@ public IcebergFormatPluginConfigBuilder ignoreResiduals(Boolean ignoreResiduals) return this; } + public IcebergFormatPluginConfigBuilder allowAnyClassToBeLoaded(Boolean allowAnyClassToBeLoaded) { + this.allowAnyClassToBeLoaded = allowAnyClassToBeLoaded; + return this; + } + public IcebergFormatPluginConfigBuilder snapshotId(Long snapshotId) { this.snapshotId = snapshotId; return this;