From 3d944efbd542eb5f2cc06c493371ce55736f81ca Mon Sep 17 00:00:00 2001 From: ashley-taylor Date: Sat, 5 Jul 2025 22:35:25 +1200 Subject: [PATCH 1/2] AVRO-4165: ability to specify AvroEncode on a class --- .../org/apache/avro/reflect/AvroEncode.java | 2 +- .../avro/reflect/FieldAccessReflect.java | 4 +- .../org/apache/avro/reflect/ReflectData.java | 37 +++- .../avro/reflect/ReflectDatumReader.java | 14 ++ .../avro/reflect/ReflectDatumWriter.java | 12 +- .../apache/avro/reflect/ReflectionUtil.java | 25 +++ .../apache/avro/reflect/TestAvroEncode.java | 180 ++++++++++++++++++ 7 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 lang/java/avro/src/test/java/org/apache/avro/reflect/TestAvroEncode.java diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java index 225f247a9ed..b9d7f00a707 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java @@ -30,7 +30,7 @@ * file. Use of {@link org.apache.avro.io.ValidatingEncoder} is recommended. */ @Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.FIELD) +@Target({ ElementType.FIELD, ElementType.TYPE }) public @interface AvroEncode { Class> using(); } diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java index df258f9d50d..72d0563290b 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java @@ -28,7 +28,7 @@ class FieldAccessReflect extends FieldAccess { @Override protected FieldAccessor getAccessor(Field field) { - AvroEncode enc = field.getAnnotation(AvroEncode.class); + AvroEncode enc = ReflectionUtil.getAvroEncode(field); if (enc != null) try { return new ReflectionBasesAccessorCustomEncoded(field, enc.using().getDeclaredConstructor().newInstance()); @@ -47,7 +47,7 @@ public ReflectionBasedAccessor(Field field) { this.field = field; this.field.setAccessible(true); isStringable = field.isAnnotationPresent(Stringable.class); - isCustomEncoded = field.isAnnotationPresent(AvroEncode.class); + isCustomEncoded = ReflectionUtil.getAvroEncode(field) != null; } @Override diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java index aa15ee8f46d..20ad9a7811f 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java @@ -69,6 +69,9 @@ public class ReflectData extends SpecificData { private static final String STRING_OUTER_PARENT_REFERENCE = "this$0"; + // holds a wrapper so null entries will have a cached value + private final ConcurrentMap encoderCache = new ConcurrentHashMap<>(); + /** * Always false since custom coders are not available for {@link ReflectData}. */ @@ -864,7 +867,7 @@ private static Field[] getFields(Class recordClass, boolean excludeJava) { /** Create a schema for a field. */ protected Schema createFieldSchema(Field field, Map names) { - AvroEncode enc = field.getAnnotation(AvroEncode.class); + AvroEncode enc = ReflectionUtil.getAvroEncode(field); if (enc != null) try { return enc.using().getDeclaredConstructor().newInstance().getSchema(); @@ -1042,4 +1045,36 @@ public Object newRecord(Object old, Schema schema) { } return super.newRecord(old, schema); } + + public CustomEncoding getCustomEncoding(Schema schema) { + + return this.encoderCache.computeIfAbsent(schema, this::populateEncoderCache).get(); + } + + private CustomEncodingWrapper populateEncoderCache(Schema schema) { + var enc = ReflectionUtil.getAvroEncode(getClass(schema)); + if (enc != null) { + try { + return new CustomEncodingWrapper(enc.using().getDeclaredConstructor().newInstance()); + } catch (Exception e) { + throw new AvroRuntimeException("Could not instantiate custom Encoding"); + } + } + return new CustomEncodingWrapper(null); + } + + private class CustomEncodingWrapper { + + private final CustomEncoding customEncoding; + + private CustomEncodingWrapper(CustomEncoding customEncoding) { + this.customEncoding = customEncoding; + } + + public CustomEncoding get() { + return customEncoding; + } + + } + } diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumReader.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumReader.java index 2a8fcee9f26..7ba8e4827c6 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumReader.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumReader.java @@ -73,6 +73,10 @@ public ReflectDatumReader(ReflectData data) { super(data); } + private ReflectData getReflectData() { + return (ReflectData) getSpecificData(); + } + @Override protected Object newArray(Object old, int size, Schema schema) { Class collectionClass = ReflectData.getClassProp(schema, SpecificData.CLASS_PROP); @@ -251,6 +255,16 @@ protected Object readBytes(Object old, Schema s, Decoder in) throws IOException } } + @Override + protected Object read(Object old, Schema expected, ResolvingDecoder in) throws IOException { + CustomEncoding encoder = getReflectData().getCustomEncoding(expected); + if (encoder != null) { + return encoder.read(old, in); + } else { + return super.read(old, expected, in); + } + } + @Override protected Object readInt(Object old, Schema expected, Decoder in) throws IOException { Object value = in.readInt(); diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumWriter.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumWriter.java index 25555d99e47..b9b083fd6b2 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumWriter.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectDatumWriter.java @@ -61,6 +61,10 @@ protected ReflectDatumWriter(ReflectData reflectData) { super(reflectData); } + private ReflectData getReflectData() { + return (ReflectData) getSpecificData(); + } + /** * Called to write a array. May be overridden for alternate array * representations. @@ -158,7 +162,13 @@ else if (datum instanceof Map && ReflectData.isNonStringMapSchema(schema)) { datum = ((Optional) datum).orElse(null); } try { - super.write(schema, datum, out); + + CustomEncoding encoder = getReflectData().getCustomEncoding(schema); + if (encoder != null) { + encoder.write(datum, out); + } else { + super.write(schema, datum, out); + } } catch (NullPointerException e) { // improve error message throw npe(e, " in " + schema.getFullName()); } diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java index 4fa52d0345e..c374facf756 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java @@ -24,6 +24,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; @@ -188,4 +189,28 @@ public static Function getConstructorAsFunction(Class parameterC } } + protected static AvroEncode getAvroEncode(Field field) { + var enc = field.getAnnotation(AvroEncode.class); + if (enc != null) { + return enc; + } else { + return getAvroEncode(field.getType()); + } + } + + protected static AvroEncode getAvroEncode(Class clazz) { + if (clazz == null) { + return null; + } + AvroEncode enc = clazz.getAnnotation(AvroEncode.class); + if (enc != null) { + return enc; + } + // try superclasses + Class superclass = clazz.getSuperclass(); + if (superclass != null && superclass != Object.class) { + return getAvroEncode(superclass); + } + return null; + } } diff --git a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestAvroEncode.java b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestAvroEncode.java new file mode 100644 index 00000000000..daee2a39a96 --- /dev/null +++ b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestAvroEncode.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.avro.reflect; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; + +import org.apache.avro.AvroTypeException; +import org.apache.avro.Schema; +import org.apache.avro.io.Decoder; +import org.apache.avro.io.DecoderFactory; +import org.apache.avro.io.Encoder; +import org.apache.avro.io.EncoderFactory; +import org.junit.jupiter.api.Test; + +public class TestAvroEncode { + EncoderFactory factory = new EncoderFactory(); + + @Test + void testWithinClass() throws IOException { + + var wrapper = new Wrapper(new R1("test")); + + var read = readWrite(wrapper); + + assertEquals("test", wrapper.getR1().getValue()); + assertEquals("test used this", read.getR1().getValue()); + } + + @Test + void testDirect() throws IOException { + + var r1 = new R1("test"); + + var read = readWrite(r1); + + assertEquals("test", r1.getValue()); + assertEquals("test used this", read.getValue()); + } + + @Test + void testFieldAnnotationTakesPrecedence() throws IOException { + + var wrapper = new OtherWrapper(new R1("test")); + + var read = readWrite(wrapper); + + assertEquals("test", wrapper.getR1().getValue()); + assertEquals("test used other", read.getR1().getValue()); + } + + public static class Wrapper { + + private R1 r1; + + public Wrapper() { + } + + public Wrapper(R1 r1) { + this.r1 = r1; + } + + public R1 getR1() { + return r1; + } + + public void setR1(R1 r1) { + this.r1 = r1; + } + + } + + public static class OtherWrapper { + @AvroEncode(using = R1EncodingOther.class) + private R1 r1; + + public OtherWrapper() { + } + + public OtherWrapper(R1 r1) { + this.r1 = r1; + } + + public R1 getR1() { + return r1; + } + + public void setR1(R1 r1) { + this.r1 = r1; + } + + } + + @AvroEncode(using = R1Encoding.class) + public static class R1 { + + private final String value; + + public R1(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + } + + public static class R1Encoding extends CustomEncoding { + + { + schema = Schema.createRecord("R1", null, null, false, + Arrays.asList(new Schema.Field("value", Schema.create(Schema.Type.STRING), null, null))); + } + + @Override + protected void write(Object datum, Encoder out) throws IOException { + if (datum instanceof R1) { + out.writeString(((R1) datum).getValue()); + } else { + throw new AvroTypeException("Expected R1, got " + datum.getClass()); + } + + } + + @Override + protected R1 read(Object reuse, Decoder in) throws IOException { + return new R1(in.readString() + " used this"); + } + } + + public static class R1EncodingOther extends CustomEncoding { + + { + schema = Schema.createRecord("R1", null, null, false, + Arrays.asList(new Schema.Field("value", Schema.create(Schema.Type.STRING), null, null))); + } + + @Override + protected void write(Object datum, Encoder out) throws IOException { + if (datum instanceof R1) { + out.writeString(((R1) datum).getValue()); + } else { + throw new AvroTypeException("Expected R1, got " + datum.getClass()); + } + } + + @Override + protected R1 read(Object reuse, Decoder in) throws IOException { + return new R1(in.readString() + " used other"); + } + } + + T readWrite(T object) throws IOException { + var schema = new ReflectData().getSchema(object.getClass()); + ReflectDatumWriter writer = new ReflectDatumWriter<>(schema); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writer.write(object, factory.directBinaryEncoder(out, null)); + ReflectDatumReader reader = new ReflectDatumReader<>(schema); + return reader.read(null, DecoderFactory.get().binaryDecoder(out.toByteArray(), null)); + } +} From 40829063960e9acfd78f2aac6ec53da1e93e05e1 Mon Sep 17 00:00:00 2001 From: ashley-taylor Date: Sat, 19 Jul 2025 09:38:15 +1200 Subject: [PATCH 2/2] PR feedback --- .../main/java/org/apache/avro/reflect/AvroEncode.java | 2 ++ .../java/org/apache/avro/reflect/ReflectData.java | 2 +- .../java/org/apache/avro/reflect/ReflectionUtil.java | 11 +---------- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java index b9d7f00a707..b4a021dce79 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/AvroEncode.java @@ -18,6 +18,7 @@ package org.apache.avro.reflect; import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; @@ -30,6 +31,7 @@ * file. Use of {@link org.apache.avro.io.ValidatingEncoder} is recommended. */ @Retention(RetentionPolicy.RUNTIME) +@Inherited @Target({ ElementType.FIELD, ElementType.TYPE }) public @interface AvroEncode { Class> using(); diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java index 20ad9a7811f..4b993c6fddc 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java @@ -1063,7 +1063,7 @@ private CustomEncodingWrapper populateEncoderCache(Schema schema) { return new CustomEncodingWrapper(null); } - private class CustomEncodingWrapper { + private static class CustomEncodingWrapper { private final CustomEncoding customEncoding; diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java index c374facf756..3221d91d1f2 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectionUtil.java @@ -202,15 +202,6 @@ protected static AvroEncode getAvroEncode(Class clazz) { if (clazz == null) { return null; } - AvroEncode enc = clazz.getAnnotation(AvroEncode.class); - if (enc != null) { - return enc; - } - // try superclasses - Class superclass = clazz.getSuperclass(); - if (superclass != null && superclass != Object.class) { - return getAvroEncode(superclass); - } - return null; + return clazz.getAnnotation(AvroEncode.class); } }