From bd3724ec17a1c9412c903545d61d6e2e11e6a72e Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Fri, 27 Dec 2024 22:37:22 +0000 Subject: [PATCH 01/15] use ByteBuddy code generation to write proto to parquet faster --- parquet-protobuf/pom.xml | 8 + .../parquet/proto/ByteBuddyCodeGen.java | 2729 +++++++++++++++++ .../parquet/proto/ProtoWriteSupport.java | 146 +- 3 files changed, 2881 insertions(+), 2 deletions(-) create mode 100644 parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml index f704295eff..e547cf6946 100644 --- a/parquet-protobuf/pom.xml +++ b/parquet-protobuf/pom.xml @@ -34,6 +34,7 @@ 3.25.5 2.50.0 1.4.3 + 1.14.18 @@ -41,6 +42,13 @@ https://parquet.apache.org + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + compile + true + org.mockito mockito-core diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java new file mode 100644 index 0000000000..1f869c6de9 --- /dev/null +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -0,0 +1,2729 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.proto; + +import static org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Reflection; + +import com.google.common.collect.MapMaker; +import com.google.protobuf.BoolValue; +import com.google.protobuf.ByteString; +import com.google.protobuf.BytesValue; +import com.google.protobuf.Descriptors; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.FloatValue; +import com.google.protobuf.Int32Value; +import com.google.protobuf.Int64Value; +import com.google.protobuf.Message; +import com.google.protobuf.MessageOrBuilder; +import com.google.protobuf.StringValue; +import com.google.protobuf.Timestamp; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; +import com.google.protobuf.util.Timestamps; +import com.google.type.Date; +import com.google.type.TimeOfDay; +import java.lang.invoke.LambdaMetafactory; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Type; +import java.time.LocalDate; +import java.time.LocalTime; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.field.FieldDescription; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.DynamicType; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.dynamic.scaffold.InstrumentedType; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.SuperMethodCall; +import net.bytebuddy.implementation.bytecode.ByteCodeAppender; +import net.bytebuddy.implementation.bytecode.Removal; +import net.bytebuddy.implementation.bytecode.StackManipulation; +import net.bytebuddy.implementation.bytecode.StackSize; +import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.collection.ArrayAccess; +import net.bytebuddy.implementation.bytecode.collection.ArrayFactory; +import net.bytebuddy.implementation.bytecode.constant.IntegerConstant; +import net.bytebuddy.implementation.bytecode.constant.JavaConstantValue; +import net.bytebuddy.implementation.bytecode.constant.TextConstant; +import net.bytebuddy.implementation.bytecode.member.FieldAccess; +import net.bytebuddy.implementation.bytecode.member.MethodInvocation; +import net.bytebuddy.implementation.bytecode.member.MethodReturn; +import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess; +import net.bytebuddy.jar.asm.Handle; +import net.bytebuddy.jar.asm.Label; +import net.bytebuddy.jar.asm.MethodVisitor; +import net.bytebuddy.jar.asm.Opcodes; +import net.bytebuddy.matcher.ElementMatchers; +import net.bytebuddy.utility.JavaConstant; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.io.api.RecordConsumer; +import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Codegen; +import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Implementations; +import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; +import org.apache.parquet.schema.MessageType; + +class ByteBuddyCodeGen { + private static final AtomicLong BYTE_BUDDY_CLASS_SEQUENCE = new AtomicLong(); + + private static final GenerateMessageClasses GeneratedMessageV3 = + GenerateMessageClasses.resolve("com.google.protobuf.GeneratedMessageV3"); + private static final GenerateMessageClasses GeneratedMessage = + GenerateMessageClasses.resolve("com.google.protobuf.GeneratedMessage"); + + static class CodeGenException extends RuntimeException { + public CodeGenException() { + super(); + } + + public CodeGenException(String message) { + super(message); + } + + public CodeGenException(String message, Throwable cause) { + super(message, cause); + } + + public CodeGenException(Throwable cause) { + super(cause); + } + } + + static boolean isGeneratedMessage(Class protoMessage) { + return protoMessage != null + && (GeneratedMessage.isGeneratedMessage(protoMessage) + || GeneratedMessageV3.isGeneratedMessage(protoMessage)); + } + + static boolean isExtendableMessage(Class protoMessage) { + return protoMessage != null + && (GeneratedMessage.isExtendableMessage(protoMessage) + || GeneratedMessageV3.isExtendableMessage(protoMessage)); + } + + static class GenerateMessageClasses { + private final Class classGeneratedMessage; + private final Class classExtendableMessage; + + private GenerateMessageClasses(Class classGeneratedMessage, Class classExtendableMessage) { + this.classGeneratedMessage = classGeneratedMessage; + this.classExtendableMessage = classExtendableMessage; + } + + static GenerateMessageClasses resolve(String generatedMessageClassName) { + Optional> generatedMessage = ReflectionUtil.classForName(generatedMessageClassName); + Optional> extendableMessage = + ReflectionUtil.classForName(generatedMessageClassName + "$ExtendableMessage"); + if (generatedMessage.isPresent() && extendableMessage.isPresent()) { + return new GenerateMessageClasses(generatedMessage.get(), extendableMessage.get()); + } else { + return new GenerateMessageClasses(null, null); + } + } + + public boolean isGeneratedMessage(Class clazz) { + return classGeneratedMessage != null && clazz != null && classGeneratedMessage.isAssignableFrom(clazz); + } + + public boolean isExtendableMessage(Class clazz) { + return classExtendableMessage != null && clazz != null && classExtendableMessage.isAssignableFrom(clazz); + } + } + + static boolean isByteBuddyAvailable(boolean failIfNot) { + try { + Class.forName("net.bytebuddy.ByteBuddy", false, ByteBuddyCodeGen.class.getClassLoader()); + return true; + } catch (ClassNotFoundException e) { + if (failIfNot) { + throw new IllegalStateException("ByteBuddy is not available", e); + } + return false; + } + } + + static class CodeGenUtils { + // resolve reflection methods early, so tests would fail fast should anything is changed in interfaces/classes + static final ResolvedReflection Reflection = new ResolvedReflection(); + + static class ResolvedReflection { + + final RecordConsumerMethods RecordConsumer = new RecordConsumerMethods(); + final ByteBuddyMessageWritersMethods ByteBuddyProto3FastMessageWriter = + new ByteBuddyMessageWritersMethods(); + final FieldWriterMethods FieldWriter = new FieldWriterMethods(); + + static class RecordConsumerMethods { + final Method startField = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "startField", String.class, int.class); + final Method endField = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "endField", String.class, int.class); + final Method startGroup = ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "startGroup"); + final Method endGroup = ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "endGroup"); + final Method addInteger = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addInteger", int.class); + final Method addLong = ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addLong", long.class); + final Method addBoolean = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addBoolean", boolean.class); + final Method addBinary = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addBinary", Binary.class); + final Method addFloat = ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addFloat", float.class); + final Method addDouble = + ReflectionUtil.getDeclaredMethod(RecordConsumer.class, "addDouble", double.class); + + final Map, Method> PRIMITIVES = initPrimitives(); + + private Map, Method> initPrimitives() { + Map, Method> m = new HashMap<>(); + m.put(int.class, addInteger); + m.put(long.class, addLong); + m.put(boolean.class, addBoolean); + m.put(float.class, addFloat); + m.put(double.class, addDouble); + return Collections.unmodifiableMap(m); + } + + private RecordConsumerMethods() {} + } + + static class ByteBuddyMessageWritersMethods { + final Method getRecordConsumer = ReflectionUtil.getDeclaredMethod( + WriteSupport.ByteBuddyMessageWriters.class, "getRecordConsumer"); + final Method enumNameNumberPairs = ReflectionUtil.getDeclaredMethod( + WriteSupport.ByteBuddyMessageWriters.class, "enumNameNumberPairs", String.class); + + private ByteBuddyMessageWritersMethods() {} + } + + static class FieldWriterMethods { + final Method writeRawValue = ReflectionUtil.getDeclaredMethod( + ProtoWriteSupport.FieldWriter.class, "writeRawValue", Object.class); + } + + private ResolvedReflection() {} + } + + static class Codegen { + public static StackManipulation incIntVar(LocalVar var, int inc) { + int offset = var.offset(); + return new StackManipulation.AbstractBase() { + @Override + public Size apply(MethodVisitor methodVisitor, Implementation.Context implementationContext) { + methodVisitor.visitIincInsn(offset, inc); + return Size.ZERO; + } + }; + } + + private static StackManipulation jumpTo(int jumpInst, Label label) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply(MethodVisitor methodVisitor, Implementation.Context implementationContext) { + methodVisitor.visitJumpInsn(jumpInst, label); + return Size.ZERO; + } + }; + } + + private static StackManipulation visitLabel(Label label) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply(MethodVisitor methodVisitor, Implementation.Context implementationContext) { + methodVisitor.visitLabel(label); + return Size.ZERO; + } + }; + } + + private static Implementation returnVoid() { + return new Implementations(MethodReturn.VOID); + } + + public static StackManipulation castLongToInt() { + return castPrimitive(Opcodes.L2I); + } + + public static StackManipulation castIntToLong() { + return castPrimitive(Opcodes.I2L); + } + + public static StackManipulation castPrimitive(int opcode) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply(MethodVisitor methodVisitor, Implementation.Context implementationContext) { + methodVisitor.visitInsn(opcode); + return Size.ZERO; + } + }; + } + + public static StackManipulation invokeMethod(Method method) { + return MethodInvocation.invoke(new MethodDescription.ForLoadedMethod(method)); + } + + public static StackManipulation invokeProtoMethod( + Class proto3MessageOrBuilderInterface, + String name, + Descriptors.FieldDescriptor fieldDescriptor, + Class... parameters) { + return invokeMethod(ReflectionUtil.getDeclaredMethod( + proto3MessageOrBuilderInterface, fieldDescriptor, name, parameters)); + } + + public static StackManipulation storeRecordConsumer(LocalVar recordConsumerVar) { + return new StackManipulation.Compound( + MethodVariableAccess.loadThis(), + invokeMethod(Reflection.ByteBuddyProto3FastMessageWriter.getRecordConsumer), + recordConsumerVar.store()); + } + } + + static class Implementations implements Implementation { + private final List implementations = new ArrayList<>(); + private final List ongoing = new ArrayList<>(); + + private Implementation compound; + + public Implementations() {} + + public Implementations(StackManipulation... stackManipulations) { + add(stackManipulations); + } + + @Override + public ByteCodeAppender appender(Target implementationTarget) { + return compound.appender(implementationTarget); + } + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + if (compound != null) { + throw new IllegalStateException(); + } + flushOngoing(); + compound = new Compound(implementations); + return compound.prepare(instrumentedType); + } + + public Implementations add(Implementation... implementations) { + flushOngoing(); + this.implementations.addAll(Arrays.asList(implementations)); + return this; + } + + public Implementations add(ByteCodeAppender... appenders) { + return add(new Simple(appenders)); + } + + public Implementations add(StackManipulation... stackManipulations) { + ongoing.addAll(Arrays.asList(stackManipulations)); + return this; + } + + private void flushOngoing() { + if (!ongoing.isEmpty()) { + implementations.add(new Simple(ongoing.toArray(new StackManipulation[0]))); + ongoing.clear(); + } + } + } + + static class LocalVar implements AutoCloseable { + private final LocalVars vars; + private final TypeDescription typeDescription; + private final Class clazz; + private final int stackSize; + + private int refCount; + + private int offset; + + public LocalVar(Class clazz, TypeDescription typeDescription, LocalVars vars) { + this.clazz = clazz; + this.typeDescription = typeDescription; + this.vars = vars; + this.stackSize = StackSize.of(typeDescription); + } + + public LocalVars vars() { + return vars; + } + + public int offset() { + assertRegistered(); + return offset; + } + + public TypeDescription typeDescription() { + return typeDescription; + } + + public StackManipulation load() { + return MethodVariableAccess.of(typeDescription()).loadFrom(offset()); + } + + public StackManipulation store() { + return MethodVariableAccess.of(typeDescription()).storeAt(offset()); + } + + public Class clazz() { + if (clazz == null) { + throw new IllegalStateException(); + } + return clazz; + } + + private int stackSize() { + return stackSize; + } + + public LocalVar register() { + vars.register(this); + return this; + } + + public LocalVar alias() { + assertRegistered(); + refCount += 1; + return this; + } + + public LocalVar unregister() { + int index = assertRegistered(); + refCount -= 1; + if (refCount == 0) { + if (index != vars.vars.size() - 1) { + throw new IllegalStateException("cannot deregister var " + this + " from " + vars.vars); + } + vars.vars.remove(this); + } + return this; + } + + private int assertRegistered() { + int index = getIndex(); + if (index < 0) { + throw new IllegalStateException("not registered"); + } + return index; + } + + private int getIndex() { + return vars.vars.indexOf(this); + } + + @Override + public void close() { + unregister(); + } + + @Override + public String toString() { + return "LocalVar{" + "vars=" + + vars + ", typeDescription=" + + typeDescription + ", stackSize=" + + stackSize + ", offset=" + + offset + '}'; + } + } + + static class LocalVars { + private final List frame = new ArrayList<>(); + private final List vars = new ArrayList<>(); + private int maxSize; + + public LocalVar register(LocalVar var) { + if (vars.contains(var)) { + throw new IllegalStateException("cannot register var twice: " + var + ", " + vars); + } + int offset = + vars.isEmpty() ? 0 : vars.get(vars.size() - 1).offset + vars.get(vars.size() - 1).stackSize; + vars.add(var); + var.offset = offset; + var.refCount = 1; + + maxSize = Math.max(maxSize, getSize()); + return var; + } + + public StackManipulation frameSame1(Class varOnStack) { + List currTypes = types(); + try { + return new StackManipulation.AbstractBase() { + @Override + public Size apply(MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .same1(methodVisitor, TypeDescription.ForLoadedType.of(varOnStack), currTypes); + return Size.ZERO; + } + }; + } finally { + this.frame.clear(); + this.frame.addAll(currTypes); + } + } + + public StackManipulation frameEmptyStack() { + List currTypes = types(); + List frame = new ArrayList<>(this.frame); + try { + if (currTypes.size() < frame.size()) { + int commonLength = commonTypesLength(currTypes, frame); + if (commonLength < currTypes.size() || frame.size() - currTypes.size() > 3) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .full(methodVisitor, Collections.emptyList(), currTypes); + return Size.ZERO; + } + }; + } else { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .chop(methodVisitor, frame.size() - currTypes.size(), currTypes); + return Size.ZERO; + } + }; + } + } else if (currTypes.size() == frame.size()) { + int commonLength = commonTypesLength(currTypes, frame); + if (commonLength != currTypes.size()) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .full(methodVisitor, Collections.emptyList(), currTypes); + return Size.ZERO; + } + }; + } else { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext.getFrameGeneration().same(methodVisitor, currTypes); + return Size.ZERO; + } + }; + } + } else { + int commonLength = commonTypesLength(currTypes, frame); + if (commonLength < frame.size() || currTypes.size() - frame.size() > 3) { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .full(methodVisitor, Collections.emptyList(), currTypes); + return Size.ZERO; + } + }; + } else { + return new StackManipulation.AbstractBase() { + @Override + public Size apply( + MethodVisitor methodVisitor, Implementation.Context implementationContext) { + implementationContext + .getFrameGeneration() + .append( + methodVisitor, + currTypes.subList(frame.size(), currTypes.size()), + frame); + return Size.ZERO; + } + }; + } + } + } finally { + this.frame.clear(); + this.frame.addAll(currTypes); + } + } + + private int commonTypesLength(List a, List b) { + int len = Math.min(a.size(), b.size()); + for (int i = 0; i < len; i++) { + if (!Objects.equals(a.get(i), b.get(i))) { + return i; + } + } + return len; + } + + public LocalVar register(TypeDescription typeDescription) { + LocalVar var = new LocalVar(null, typeDescription, this); + return register(var); + } + + public LocalVar register(Class clazz) { + LocalVar var = new LocalVar(clazz, TypeDescription.ForLoadedType.of(clazz), this); + return register(var); + } + + public Implementation asImplementation() { + return new Implementation.Simple(new ByteCodeAppender() { + @Override + public Size apply( + MethodVisitor methodVisitor, + Implementation.Context implementationContext, + MethodDescription instrumentedMethod) { + return new Size(0, maxSize); + } + }); + } + + private int getSize() { + int size = 0; + for (LocalVar var : vars) { + size += var.stackSize(); + } + return size; + } + + private List types() { + List types = new ArrayList<>(); + for (LocalVar var : vars) { + types.add(var.typeDescription); + } + return types; + } + } + } + + static class ReflectionUtil { + + static Optional> getMessageOrBuilderInterface( + Class messageClass) { + return Stream.of(messageClass) + .filter(Objects::nonNull) + .filter(ByteBuddyCodeGen::isGeneratedMessage) + .flatMap(x -> Arrays.stream(x.getInterfaces())) + .filter(MessageOrBuilder.class::isAssignableFrom) + .map(x -> (Class) x) + .findFirst(); + } + + static Method getDeclaredMethod(Class clazz, String name, Class... parameterTypes) { + try { + return clazz.getDeclaredMethod(name, parameterTypes); + } catch (NoSuchMethodException e) { + throw new CodeGenException(e); + } + } + + static Method getDeclaredMethod( + Class protoClazz, Descriptors.FieldDescriptor fieldDescriptor, String name, Class... parameters) { + return getDeclaredMethod( + protoClazz, name.replace("{}", getFieldNameForMethod(fieldDescriptor)), parameters); + } + + static Method getDeclaredMethodByName(Class clazz, String name) { + for (Method method : clazz.getDeclaredMethods()) { + if (name.equals(method.getName())) { + return method; + } + } + throw new CodeGenException("no such method on class " + clazz + ": " + name); + } + + static Method getDeclaredMethodByName( + Class clazz, Descriptors.FieldDescriptor fieldDescriptor, String name) { + return getDeclaredMethodByName(clazz, name.replace("{}", getFieldNameForMethod(fieldDescriptor))); + } + + // almost the same as com.google.protobuf.Descriptors.FieldDescriptor#fieldNameToJsonName + // but capitalizing the first letter after each last digit + static String getFieldNameForMethod(Descriptors.FieldDescriptor fieldDescriptor) { + String name = fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.GROUP + ? fieldDescriptor.getMessageType().getName() + : fieldDescriptor.getName(); + final int length = name.length(); + StringBuilder result = new StringBuilder(length); + boolean isNextUpperCase = false; + for (int i = 0; i < length; i++) { + char ch = name.charAt(i); + if (ch == '_') { + isNextUpperCase = true; + } else if ('0' <= ch && ch <= '9') { + isNextUpperCase = true; + result.append(ch); + } else if (isNextUpperCase || i == 0) { + // This closely matches the logic for ASCII characters in: + // http://google3/google/protobuf/descriptor.cc?l=249-251&rcl=228891689 + if ('a' <= ch && ch <= 'z') { + ch = (char) (ch - 'a' + 'A'); + } + result.append(ch); + isNextUpperCase = false; + } else { + result.append(ch); + } + } + return result.toString(); + } + + static Constructor getConstructor(Class clazz, Class... parameterTypes) { + try { + return clazz.getConstructor(parameterTypes); + } catch (NoSuchMethodException e) { + throw new CodeGenException(e); + } + } + + static T newInstance(Constructor constructor, Object... initParams) { + try { + return constructor.newInstance(initParams); + } catch (InstantiationException | IllegalAccessException e) { + throw new CodeGenException(e); + } catch (InvocationTargetException e) { + if (e.getCause() instanceof CodeGenException) { + throw (CodeGenException) e.getCause(); + } + throw new CodeGenException(e.getCause()); + } + } + + static Optional> classForName(String className) { + try { + return Optional.of(Class.forName(className, false, ByteBuddyCodeGen.class.getClassLoader())); + } catch (ClassNotFoundException e) { + return Optional.empty(); + } + } + } + + static class WriteSupport { + // in order to avoid class generation for the same proto descriptors, cache implementations. + private static final Map.MessageWriter>> + WRITERS_CACHE = new MapMaker().weakValues().makeMap(); + + private static final Consumer.MessageWriter> NOOP_WRITER_PATCHER = messageWriter -> {}; + private static final Consumer.MessageWriter> REVERT_WRITER_PATCHER = messageWriter -> { + Queue.FieldWriter> queue = new ArrayDeque<>(); + queue.add(messageWriter); + + while (!queue.isEmpty()) { + ProtoWriteSupport.FieldWriter fw = queue.poll(); + if (fw instanceof ProtoWriteSupport.MessageWriter) { + ((ProtoWriteSupport.MessageWriter) fw) + .setAlternativeMessageWriter(ProtoWriteSupport.MessageFieldsWriter.NOOP); + queue.addAll(Arrays.asList(((ProtoWriteSupport.MessageWriter) fw).fieldWriters)); + } else if (fw instanceof ProtoWriteSupport.ArrayWriter) { + queue.add(((ProtoWriteSupport.ArrayWriter) fw).fieldWriter); + } else if (fw instanceof ProtoWriteSupport.RepeatedWriter) { + queue.add(((ProtoWriteSupport.RepeatedWriter) fw).fieldWriter); + } else if (fw instanceof ProtoWriteSupport.MapWriter) { + queue.add(((ProtoWriteSupport.MapWriter) fw).keyWriter); + queue.add(((ProtoWriteSupport.MapWriter) fw).valueWriter); + } + } + }; + + static class MessageFieldsWritersCacheKey { + private final MessageType rootSchema; + private final Class protoMessage; + private final boolean writeSpecsCompliant; + private final boolean protoReflectionForExtendable; + + MessageFieldsWritersCacheKey( + MessageType rootSchema, + Class protoMessage, + boolean writeSpecsCompliant, + boolean protoReflectionForExtendable) { + this.rootSchema = rootSchema; + this.protoMessage = protoMessage; + this.writeSpecsCompliant = writeSpecsCompliant; + this.protoReflectionForExtendable = protoReflectionForExtendable; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + MessageFieldsWritersCacheKey that = (MessageFieldsWritersCacheKey) o; + return writeSpecsCompliant == that.writeSpecsCompliant + && protoReflectionForExtendable == that.protoReflectionForExtendable + && Objects.equals(rootSchema, that.rootSchema) + && Objects.equals(protoMessage, that.protoMessage); + } + + @Override + public int hashCode() { + return Objects.hash(rootSchema, protoMessage, writeSpecsCompliant, protoReflectionForExtendable); + } + } + + static void tryApplyAlternativeMessageFieldsWriters( + ProtoWriteSupport.MessageWriter rootMessageWriter, + MessageType rootSchema, + Class protoMessage, + Descriptors.Descriptor descriptor, + ProtoWriteSupport.CodegenMode codegenMode) { + + if (!codegenMode.tryCodeGen(protoMessage)) { + return; + } + + MessageFieldsWritersCacheKey cacheKey = new MessageFieldsWritersCacheKey( + rootSchema, + protoMessage, + rootMessageWriter.getProtoWriteSupport().isWriteSpecsCompliant(), + codegenMode.protobufReflectionForExtensions()); + + try { + Consumer.MessageWriter> messageFieldsWriterPatcher = WRITERS_CACHE.computeIfAbsent( + cacheKey, + unused -> createMessageFieldsWriterPatcher( + rootMessageWriter, protoMessage, descriptor, codegenMode)); + messageFieldsWriterPatcher.accept(rootMessageWriter); + } catch (Throwable t) { + if (!codegenMode.ignoreCodeGenException()) { + throw t; + } + REVERT_WRITER_PATCHER.accept(rootMessageWriter); + } + } + + private static Consumer.MessageWriter> createMessageFieldsWriterPatcher( + ProtoWriteSupport.MessageWriter rootMessageWriter, + Class protoMessage, + Descriptors.Descriptor descriptor, + ProtoWriteSupport.CodegenMode codegenMode) { + return new ByteBuddyMessageWritersCodeGen(rootMessageWriter, protoMessage, descriptor, codegenMode) + .getPatcher(); + } + + static class Field { + private final FieldScanner fieldScanner; + private final Field parent; + private final ProtoWriteSupport.FieldWriter fieldWriter; + + private final Descriptors.FieldDescriptor fieldDescriptor; // can be null for root MessageWriter + private final Descriptors.Descriptor messageType; // filled for Message fields (incl. Map) + + private final String parquetFieldName; + private final int parquetFieldIndex; + + private Type reflectionType; + private Object codeGenerationBasicType; + private Object codeGenerationKey; + + private List children; + private Field mapKey; + private Field mapValue; + + private Field( + FieldScanner fieldScanner, + Field parent, + ProtoWriteSupport.FieldWriter fieldWriter, + Descriptors.FieldDescriptor fieldDescriptor, + String parquetFieldName, + int parquetFieldIndex) { + this.fieldScanner = fieldScanner; + this.parent = parent; + this.fieldWriter = fieldWriter; + this.fieldDescriptor = fieldDescriptor; + this.messageType = fieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.MESSAGE + ? fieldDescriptor.getMessageType() + : null; + this.parquetFieldName = parquetFieldName; + this.parquetFieldIndex = parquetFieldIndex; + } + + private Field( + FieldScanner fieldScanner, + ProtoWriteSupport.MessageWriter messageWriter, + Class protoMessage, + Descriptors.Descriptor messageType) { + this.fieldScanner = fieldScanner; + this.parent = null; + this.fieldWriter = messageWriter; + this.fieldDescriptor = null; + this.messageType = messageType; + this.reflectionType = protoMessage; + this.parquetFieldName = null; + this.parquetFieldIndex = -1; + } + + public String getParquetFieldName() { + return parquetFieldName; + } + + public int getParquetFieldIndex() { + return parquetFieldIndex; + } + + public Field getParent() { + return parent; + } + + @Override + public String toString() { + List path = new ArrayList<>(); + Field p = this; + while (p != null) { + path.add(p.getParquetFieldName()); + p = p.getParent(); + } + Collections.reverse(path); + return String.valueOf(path); + } + + public Descriptors.Descriptor getMessageType() { + return messageType; + } + + // helps codegen to deal with particular java getter for a proto field + public Type getReflectionType() { + if (reflectionType == null) { + reflectionType = initReflectionType(); + } + return reflectionType; + } + + public Class getMessageOrBuilderInterface() { + if (!isProtoMessage()) { + throw new CodeGenException(); + } + return ReflectionUtil.getMessageOrBuilderInterface((Class) getReflectionType()) + .get(); + } + + public boolean isList() { + return !isMap() && fieldDescriptor != null && fieldDescriptor.isRepeated(); + } + + private Type initReflectionType() { + // parent is always not null here + if (isMap()) { + return initMapReflectionType(); + } else if (parent.isMap()) { + MapReflectionType mapReflectionType = (MapReflectionType) parent.getReflectionType(); + return fieldDescriptor.getIndex() == 0 ? mapReflectionType.key() : mapReflectionType.value(); + } else { + return initRegularFieldReflectionType(); + } + } + + private Type initRegularFieldReflectionType() { + Class clazz; + Class parentProtoMessage = (Class) parent.getReflectionType(); + if (fieldDescriptor.isRepeated()) { + clazz = ReflectionUtil.getDeclaredMethod(parentProtoMessage, fieldDescriptor, "get{}", int.class) + .getReturnType(); + } else { + clazz = ReflectionUtil.getDeclaredMethod(parentProtoMessage, fieldDescriptor, "get{}") + .getReturnType(); + } + if (fieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.ENUM) { + return new EnumReflectionType(clazz, fieldDescriptor); + } + return clazz; + } + + private Type initMapReflectionType() { + Class parentProtoMessage = (Class) parent.getReflectionType(); + Method method = + ReflectionUtil.getDeclaredMethodByName(parentProtoMessage, fieldDescriptor, "get{}OrThrow"); + Descriptors.FieldDescriptor valueFieldDescriptor = + fieldDescriptor.getMessageType().getFields().get(1); + Type valueType; + if (valueFieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.ENUM) { + valueType = new EnumReflectionType(method.getReturnType(), valueFieldDescriptor); + } else { + valueType = method.getReturnType(); + } + return new MapReflectionType(method.getParameterTypes()[0], valueType); + } + + // helps codegen to identify unique methods and supporting fields to write messages, map entries and enums + public Object getCodeGenerationElementKey() { + if (codeGenerationKey == null) { + codeGenerationKey = initCodeGenerationKey(); + } + return codeGenerationKey; + } + + private Object initCodeGenerationKey() { + if (isMessage() || (isMap() && mapValue().isMessage())) { + List key = new ArrayList<>(); + key.add(getCodeGenerationBasicType()); + for (Field child : getChildren()) { + if (child.isProtoMessage() + || (child.isMap() && child.mapValue().isProtoMessage())) { + key.add(child.getCodeGenerationElementKey()); + } + } + return key; + } + if (isBinaryMessage() || (isMap() && mapValue().isBinaryMessage())) { + return getCodeGenerationBasicType(); + } + if (isMap()) { + return getCodeGenerationBasicType(); + } + if (isEnum()) { + // for enums extra fields have to be prepared and their content depend on Enum type itself, not on + // the declaring message type + return getFieldDescriptor().getEnumType(); + } + throw new CodeGenException("no code generation is allowed for this field"); + } + + private Object getCodeGenerationBasicType() { + if (codeGenerationBasicType == null) { + codeGenerationBasicType = initCodeGenerationBasicType(); + } + return codeGenerationBasicType; + } + + private Object initCodeGenerationBasicType() { + if (isMap()) { + Object keyType = mapKey().getCodeGenerationBasicType(); + Object valueType = mapValue().getCodeGenerationBasicType(); + return Arrays.asList(keyType, valueType); + } else if (isProtoMessage()) { + return Arrays.asList(isBinaryMessage() ? "binary_message" : "message", getMessageType()); + } else if (isEnum()) { + return Arrays.asList( + getFieldDescriptor().getEnumType(), + getFieldDescriptor().legacyEnumFieldTreatedAsClosed()); + } else { + return getFieldDescriptor().getJavaType(); + } + } + + public Descriptors.FieldDescriptor getFieldDescriptor() { + return fieldDescriptor; + } + + public List getChildren() { + if (children == null) { + children = initChildren(); + } + return children; + } + + private List initChildren() { + if (isMessage()) { + ProtoWriteSupport.FieldWriter[] fieldWriters = getMessageWriter().fieldWriters; + return resolveChildFields(fieldWriters); + } else if (isMap()) { + return Arrays.asList(mapKey(), mapValue()); + } else { + return Collections.emptyList(); + } + } + + private List resolveChildFields(ProtoWriteSupport.FieldWriter[] fieldWriters) { + List fieldDescriptors = messageType.getFields(); + int fieldsCount = fieldWriters.length; + List result = new ArrayList<>(fieldsCount); + for (int i = 0; i < fieldsCount; i++) { + result.add(resolveField(fieldWriters[i], fieldDescriptors.get(i))); + } + return result; + } + + public boolean isMessage() { + // this does not include Map and Message fields written as binary + return isProtoMessage() && fieldWriter instanceof ProtoWriteSupport.MessageWriter; + } + + public boolean isBinaryMessage() { + return isProtoMessage() && fieldWriter instanceof ProtoWriteSupport.BinaryWriter; + } + + public boolean isProtoMessage() { + return !isMap() && !isProtoWrapper() && messageType != null; + } + + private ProtoWriteSupport.MessageWriter getMessageWriter() { + if (!isMessage()) { + throw new CodeGenException(); + } + return (ProtoWriteSupport.MessageWriter) fieldWriter; + } + + public boolean isFieldWriterFallbackTransition() { + // track only those 'protobuf reflection writers that are children of codegen writers' + Field parent = getParent(); + while (parent != null) { + if (parent.isMessage()) { + break; + } + parent = parent.getParent(); + } + + return (parent != null && !parent.isFieldWriterFallbackTransition() && isFieldWriterFallback()) + || (parent == null && isFieldWriterFallback()); + } + + private boolean isFieldWriterFallback() { + if (isBinaryMessage()) return true; + if (isMessage() && fieldScanner.isFieldWriterFallbackForExtendable() && isExtendableMessage()) + return true; + return false; + } + + private boolean isExtendableMessage() { + if (!isMessage()) { + throw new CodeGenException(); + } + Class protoMessage = (Class) getReflectionType(); + return ByteBuddyCodeGen.isExtendableMessage(protoMessage); + } + + public boolean isMap() { + // fieldDescriptor is null for root message which is message, not map. + return fieldDescriptor != null && fieldDescriptor.isMapField(); + } + + private Field mapKey() { + if (mapKey == null) { + mapKey = initMapKey(); + } + return mapKey; + } + + private Field initMapKey() { + if (!isMap()) { + throw new CodeGenException(); + } + if (fieldWriter instanceof ProtoWriteSupport.MessageWriter) { + return resolveField( + ((ProtoWriteSupport.MessageWriter) fieldWriter).fieldWriters[0], + messageType.getFields().get(0)); + } else if (fieldWriter instanceof ProtoWriteSupport.MapWriter) { + return resolveField( + ((ProtoWriteSupport.MapWriter) fieldWriter).keyWriter, + messageType.getFields().get(0)); + } else { + throw new CodeGenException(); + } + } + + private Field mapValue() { + if (mapValue == null) { + mapValue = initMapValue(); + } + return mapValue; + } + + private Field initMapValue() { + if (!isMap()) { + throw new CodeGenException(); + } + if (fieldWriter instanceof ProtoWriteSupport.MessageWriter) { + return resolveField( + ((ProtoWriteSupport.MessageWriter) fieldWriter).fieldWriters[1], + messageType.getFields().get(1)); + } else if (fieldWriter instanceof ProtoWriteSupport.MapWriter) { + return resolveField( + ((ProtoWriteSupport.MapWriter) fieldWriter).valueWriter, + messageType.getFields().get(1)); + } else { + throw new CodeGenException(); + } + } + + public boolean isEnum() { + return fieldWriter instanceof ProtoWriteSupport.EnumWriter; + } + + private Field resolveField( + ProtoWriteSupport.FieldWriter fieldWriter, Descriptors.FieldDescriptor fieldDescriptor) { + return resolveField(fieldWriter, fieldDescriptor, fieldWriter); + } + + private Field resolveField( + ProtoWriteSupport.FieldWriter fieldWriter, + Descriptors.FieldDescriptor fieldDescriptor, + ProtoWriteSupport.FieldWriter parquetFieldInfo) { + if (fieldWriter instanceof ProtoWriteSupport.ArrayWriter) { + return resolveField( + ((ProtoWriteSupport.ArrayWriter) fieldWriter).fieldWriter, fieldDescriptor, fieldWriter); + } else if (fieldWriter instanceof ProtoWriteSupport.RepeatedWriter) { + return resolveField( + ((ProtoWriteSupport.RepeatedWriter) fieldWriter).fieldWriter, + fieldDescriptor, + fieldWriter); + } else { + return new Field( + fieldScanner, + this, + fieldWriter, + fieldDescriptor, + parquetFieldInfo.fieldName, + parquetFieldInfo.index); + } + } + + public boolean isOptional() { + return !isMap() && !isList() && fieldDescriptor != null && fieldDescriptor.hasPresence(); + } + + public boolean isPrimitive() { + switch (fieldDescriptor.getJavaType()) { + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case BOOLEAN: + return true; + default: + return false; + } + } + + public boolean isBinary() { + return fieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.BYTE_STRING; + } + + public boolean isString() { + return fieldDescriptor.getJavaType() == Descriptors.FieldDescriptor.JavaType.STRING; + } + + public boolean isProtoWrapper() { + return fieldWriter instanceof ProtoWriteSupport.BytesValueWriter + || fieldWriter instanceof ProtoWriteSupport.StringValueWriter + || fieldWriter instanceof ProtoWriteSupport.BoolValueWriter + || fieldWriter instanceof ProtoWriteSupport.UInt32ValueWriter + || fieldWriter instanceof ProtoWriteSupport.Int32ValueWriter + || fieldWriter instanceof ProtoWriteSupport.UInt64ValueWriter + || fieldWriter instanceof ProtoWriteSupport.Int64ValueWriter + || fieldWriter instanceof ProtoWriteSupport.FloatValueWriter + || fieldWriter instanceof ProtoWriteSupport.DoubleValueWriter + || fieldWriter instanceof ProtoWriteSupport.TimeWriter + || fieldWriter instanceof ProtoWriteSupport.DateWriter + || fieldWriter instanceof ProtoWriteSupport.TimestampWriter; + } + } + + static final class MapReflectionType implements Type { + private final Type key; + private final Type value; + + public MapReflectionType(Type key, Type value) { + this.key = key; + this.value = value; + } + + public Type key() { + return key; + } + + public Type value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + MapReflectionType that = (MapReflectionType) o; + return Objects.equals(key, that.key) && Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public String toString() { + return "MapReflectionType{" + "key=" + key + ", value=" + value + '}'; + } + } + + static final class EnumReflectionType implements Type { + private final Class clazz; + private final boolean enumSupportsUnknownValues; // determines if Enum actually supports unknown values + private final boolean + fieldSupportsUnknownValues; // only used to help identify which getter to use for enums + + public EnumReflectionType(Class clazz, Descriptors.FieldDescriptor enumField) { + this.clazz = clazz; + this.enumSupportsUnknownValues = !enumField.getEnumType().isClosed(); + this.fieldSupportsUnknownValues = !enumField.legacyEnumFieldTreatedAsClosed(); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + EnumReflectionType that = (EnumReflectionType) o; + return enumSupportsUnknownValues == that.enumSupportsUnknownValues + && fieldSupportsUnknownValues == that.fieldSupportsUnknownValues + && Objects.equals(clazz, that.clazz); + } + + @Override + public int hashCode() { + return Objects.hash(clazz, enumSupportsUnknownValues, fieldSupportsUnknownValues); + } + + @Override + public String toString() { + return "EnumReflectionType{" + "clazz=" + + clazz + ", enumSupportsUnknownValues=" + + enumSupportsUnknownValues + ", fieldSupportsUnknownValues=" + + fieldSupportsUnknownValues + '}'; + } + } + + interface FieldVisitor { + void visitField(Field field); + } + + static class FieldScanner { + private final boolean fieldWriterFallbackForExtendable; + + private FieldScanner(boolean fieldWriterFallbackForExtendable) { + this.fieldWriterFallbackForExtendable = fieldWriterFallbackForExtendable; + } + + public boolean isFieldWriterFallbackForExtendable() { + return fieldWriterFallbackForExtendable; + } + + public void scan( + ProtoWriteSupport.MessageWriter messageWriter, + Class protoMessage, + Descriptors.Descriptor messageType, + FieldVisitor visitor) { + scan(new Field(this, messageWriter, protoMessage, messageType), visitor); + } + + public void scan(Field startField, FieldVisitor visitor) { + Queue queue = new ArrayDeque<>(); + queue.add(startField); + + while (!queue.isEmpty()) { + Field field = queue.poll(); + visitor.visitField(field); + queue.addAll(field.getChildren()); + } + } + } + + static class ByteBuddyMessageWritersCodeGen { + private final FieldScanner fieldScanner; + private final Class protoMessage; + private final Descriptors.Descriptor descriptor; + private final ProtoWriteSupport protoWriteSupport; + + private final Map codeGenMessageWriters = new LinkedHashMap<>(); + private final Map mapWriters = new LinkedHashMap<>(); + private final Map fieldWriterFallbacks = new LinkedHashMap<>(); + private final Map enumFields = new LinkedHashMap<>(); + + private DynamicType.Builder classBuilder; + private final Class byteBuddyMessageWritersClass; + + public ByteBuddyMessageWritersCodeGen( + ProtoWriteSupport.MessageWriter messageWriter, + Class protoMessage, + Descriptors.Descriptor descriptor, + ProtoWriteSupport.CodegenMode codegenMode) { + this.protoWriteSupport = messageWriter.getProtoWriteSupport(); + this.fieldScanner = new FieldScanner(codegenMode.protobufReflectionForExtensions()); + this.protoMessage = protoMessage; + this.descriptor = descriptor; + + collectCodeGenElements(messageWriter, protoMessage, descriptor); + + if (mapWriters.isEmpty() && codeGenMessageWriters.isEmpty()) { + byteBuddyMessageWritersClass = null; + return; + } + + classBuilder = new ByteBuddy() + .subclass(ByteBuddyMessageWriters.class) + .name(ByteBuddyMessageWriters.class.getName() + "$Generated$" + + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); + + registerEnumFields(); + registerFallbackFieldWriterFields(); + generateConstructor(); + generateMethods(); + overrideSetFallbackFieldWriters(); + + DynamicType.Unloaded unloaded = classBuilder.make(); + + // use to debug codegen + // try { + // unloaded.saveIn(new java.io.File("generated_debug")); + // } catch (Exception e) { + // } + + byteBuddyMessageWritersClass = unloaded.load( + null, ClassLoadingStrategy.UsingLookup.of(MethodHandles.lookup())) + .getLoaded(); + } + + private void registerFallbackFieldWriterFields() { + for (CodeGenFieldWriterFallback fieldWriterFallback : fieldWriterFallbacks.values()) { + classBuilder = classBuilder.define(fieldWriterFallback.fieldWriter()); + } + } + + private void overrideSetFallbackFieldWriters() { + if (fieldWriterFallbacks.isEmpty()) { + return; + } + classBuilder = classBuilder + .method(ElementMatchers.named("setFallbackFieldWriters")) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar fieldWriters = + localVars.register(ProtoWriteSupport.FieldWriter[].class)) { + for (CodeGenFieldWriterFallback fieldWriterFallback : + fieldWriterFallbacks.values()) { + add( + MethodVariableAccess.loadThis(), + fieldWriters.load(), + IntegerConstant.forValue(fieldWriterFallback.getId()), + ArrayAccess.REFERENCE.load(), + FieldAccess.forField(fieldWriterFallback.fieldWriter()) + .write()); + } + } + } + + add(Codegen.returnVoid()); + } + }); + } + + private void generateMethods() { + for (CodeGenMessageWriter codeGenMessageWriter : codeGenMessageWriters.values()) { + classBuilder = classBuilder + .define(codeGenMessageWriter.getMethodDescription()) + .intercept(new WriteAllFieldsForMessageImplementation(codeGenMessageWriter.getField())); + } + + for (CodeGenMapWriter codeGenMapWriter : mapWriters.values()) { + classBuilder = classBuilder + .define(codeGenMapWriter.writeMapEntry()) + .intercept(new WriteAllFieldsForMapEntryImplementation(codeGenMapWriter.getField())); + } + } + + private void generateConstructor() { + classBuilder = classBuilder + .constructor(ElementMatchers.any()) + .intercept(SuperMethodCall.INSTANCE.andThen(new ByteBuddyMessageWritersConstructor())); + } + + private void registerEnumFields() { + for (CodeGenEnum enumField : enumFields.values()) { + classBuilder = classBuilder.define(enumField.enumNumberPairs()); + classBuilder = classBuilder.define(enumField.enumDescriptor()); + classBuilder = classBuilder.define(enumField.enumValues()); + } + } + + private void collectCodeGenElements( + ProtoWriteSupport.MessageWriter messageWriter, + Class protoMessage, + Descriptors.Descriptor descriptor) { + fieldScanner.scan(messageWriter, protoMessage, descriptor, new FieldVisitor() { + @Override + public void visitField(Field field) { + if (field.isFieldWriterFallback()) { + if (field.isFieldWriterFallbackTransition()) { + addCodeGenElement(field, fieldWriterFallbacks, CodeGenFieldWriterFallback::new); + } + } else if (field.isMessage()) { + addCodeGenElement(field, codeGenMessageWriters, CodeGenMessageWriter::new); + } else if (field.isMap()) { + addCodeGenElement(field, mapWriters, CodeGenMapWriter::new); + } else if (field.isEnum()) { + addCodeGenElement(field, enumFields, CodeGenEnum::new); + } + } + }); + } + + private class ByteBuddyMessageWritersConstructor extends Implementations { + + public ByteBuddyMessageWritersConstructor() { + // final Map enumNameNumberPairs; + // final Descriptors.EnumDescriptor enumDescriptor; + // final List enumValues; + + for (CodeGenEnum enumField : enumFields.values()) { + add( + MethodVariableAccess.loadThis(), + MethodVariableAccess.loadThis(), + new TextConstant(enumField.getEnumTypeFullName()), + Codegen.invokeMethod(Reflection.ByteBuddyProto3FastMessageWriter.enumNameNumberPairs), + FieldAccess.forField(enumField.enumNumberPairs()) + .write()); + + add( + MethodVariableAccess.loadThis(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(enumField.getEnumClass(), "getDescriptor")), + FieldAccess.forField(enumField.enumDescriptor()).write()); + + add( + MethodVariableAccess.loadThis(), + MethodVariableAccess.loadThis(), + FieldAccess.forField(enumField.enumDescriptor()).read(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Descriptors.EnumDescriptor.class, "getValues")), + ArrayFactory.forType(TypeDescription.Generic.OfNonGenericType.ForLoadedType.of( + Descriptors.EnumValueDescriptor.class)) + .withValues(Collections.emptyList()), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(List.class, "toArray", Object[].class)), + TypeCasting.to( + TypeDescription.ForLoadedType.of(Descriptors.EnumValueDescriptor[].class)), + FieldAccess.forField(enumField.enumValues()).write()); + } + + add(Codegen.returnVoid()); + } + } + + private abstract class FastMessageWriterMethodBase extends Implementations { + final CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + add(localVars.asImplementation()); + return super.prepare(instrumentedType); + } + + abstract class RegularFieldWriterTemplate extends Implementations { + final Field field; + final LocalVar recordConsumerVar; + + RegularFieldWriterTemplate(Field field, LocalVar recordConsumerVar) { + this.field = field; + this.recordConsumerVar = recordConsumerVar; + } + + String getterMethodTemplate() { + return "get{}"; + } + + Implementation fieldGetConvertWrite(LocalVar proto3MessageOrBuilder) { + if (field.isList()) { + Label afterIfCountGreaterThanZero = new Label(); + try (LocalVar countVar = localVars.register(int.class)) { + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + "get{}Count", + field.getFieldDescriptor()), + countVar.store(), + countVar.load(), + Codegen.jumpTo(Opcodes.IFLE, afterIfCountGreaterThanZero), + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.startGroup), + recordConsumerVar.load(), + new TextConstant("list"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + } + + Label nextIteration = new Label(); + Label afterForLoop = new Label(); + try (LocalVar iVar = localVars.register(int.class)) { + add( + IntegerConstant.forValue(0), + iVar.store(), + Codegen.visitLabel(nextIteration), + localVars.frameEmptyStack(), + iVar.load(), + countVar.load(), + Codegen.jumpTo(Opcodes.IF_ICMPGE, afterForLoop)); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.startGroup), + recordConsumerVar.load(), + new TextConstant("element"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + } + + writeRepeatedRawValue(proto3MessageOrBuilder, iVar); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + new TextConstant("element"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.endField), + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.endGroup)); + } + + add(Codegen.incIntVar(iVar, 1), Codegen.jumpTo(Opcodes.GOTO, nextIteration)); + } + + add(Codegen.visitLabel(afterForLoop), localVars.frameEmptyStack()); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + new TextConstant("list"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.endField), + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.endGroup)); + } + + add( + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.endField)); + } + add(Codegen.visitLabel(afterIfCountGreaterThanZero), localVars.frameEmptyStack()); + } else { + Label afterEndField = new Label(); + if (field.isOptional()) { + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), "has{}", field.getFieldDescriptor()), + Codegen.jumpTo(Opcodes.IFEQ, afterEndField)); + } + + add( + recordConsumerVar.load(), + new TextConstant(field.fieldWriter.fieldName), + IntegerConstant.forValue(field.fieldWriter.index), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + + writeRawValue(proto3MessageOrBuilder); + + add( + recordConsumerVar.load(), + new TextConstant(field.fieldWriter.fieldName), + IntegerConstant.forValue(field.fieldWriter.index), + Codegen.invokeMethod(Reflection.RecordConsumer.endField)); + + if (field.isOptional()) { + add(Codegen.visitLabel(afterEndField), localVars.frameEmptyStack()); + } + } + + return this; + } + + void loadRepeatedValueOnStack(LocalVar proto3MessageOrBuilder, LocalVar iVar) { + add( + proto3MessageOrBuilder.load(), + iVar.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + getterMethodTemplate(), + field.getFieldDescriptor(), + int.class)); + } + + void writeRepeatedRawValue(LocalVar proto3MessageOrBuilder, LocalVar iVar) { + beforeLoadValueOnStack(); + loadRepeatedValueOnStack(proto3MessageOrBuilder, iVar); + convertRawValueAndWrite(); + afterConvertRawValue(); + } + + Implementation writeFromVar(LocalVar var) { + add( + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + beforeLoadValueOnStack(); + add(var.load()); + convertRawValueAndWrite(); + afterConvertRawValue(); + add( + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.endField)); + return this; + } + + void loadSingleValueOnStack(LocalVar proto3MessageOrBuilder) { + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + getterMethodTemplate(), + field.getFieldDescriptor())); + } + + void beforeLoadValueOnStack() { + add(recordConsumerVar.load()); + } + + void afterConvertRawValue() {} + + void writeRawValue(LocalVar proto3MessageOrBuilder) { + beforeLoadValueOnStack(); + loadSingleValueOnStack(proto3MessageOrBuilder); + convertRawValueAndWrite(); + afterConvertRawValue(); + } + + abstract void convertRawValueAndWrite(); + } + + private Implementation writePrimitiveField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new PrimitiveFieldWriter(field, recordConsumerVar) + .fieldGetConvertWrite(proto3MessageOrBuilder); + } + + private Implementation writeBinaryField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new BinaryFieldWriter(field, recordConsumerVar).fieldGetConvertWrite(proto3MessageOrBuilder); + } + + private Implementation writeStringField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new StringFieldWriter(field, recordConsumerVar).fieldGetConvertWrite(proto3MessageOrBuilder); + } + + private Implementation writeProtoWrapperField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new ProtoWrapperFieldWriter(field, recordConsumerVar) + .fieldGetConvertWrite(proto3MessageOrBuilder); + } + + private Implementation writeEnumField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new EnumFieldWriter(field, recordConsumerVar).fieldGetConvertWrite(proto3MessageOrBuilder); + } + + private Implementation writeMessageField( + LocalVar proto3MessageOrBuilder, LocalVar recordConsumerVar, Field field) { + return new MessageFieldWriter(field, recordConsumerVar) + .fieldGetConvertWrite(proto3MessageOrBuilder); + } + + class PrimitiveFieldWriter extends RegularFieldWriterTemplate { + public PrimitiveFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + + @Override + void convertRawValueAndWrite() { + add(Codegen.invokeMethod( + Reflection.RecordConsumer.PRIMITIVES.get((Class) field.getReflectionType()))); + } + } + + class MessageFieldWriter extends RegularFieldWriterTemplate { + + MessageFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + + String getterMethodTemplate() { + return "get{}OrBuilder"; + } + + @Override + void convertRawValueAndWrite() { + if (!field.isFieldWriterFallbackTransition()) { + CodeGenMessageWriter codeGenMessageWriter = + codeGenMessageWriters.get(field.getCodeGenerationElementKey()); + if (codeGenMessageWriter == null) { + throw new CodeGenException("field: " + field); + } + MethodDescription methodDescription = codeGenMessageWriter.getMethodDescription(); + add(MethodInvocation.invoke(methodDescription)); + } else { + add(Codegen.invokeMethod(Reflection.FieldWriter.writeRawValue)); + } + } + + @Override + void loadRepeatedValueOnStack(LocalVar proto3MessageOrBuilder, LocalVar iVar) { + add( + proto3MessageOrBuilder.load(), + iVar.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + getterMethodTemplate(), + field.getFieldDescriptor(), + int.class)); + } + + @Override + void beforeLoadValueOnStack() { + if (!field.isFieldWriterFallbackTransition()) { + startGroup(); + add(MethodVariableAccess.loadThis()); + } else { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(fieldWriterFallbacks + .get(field.getCodeGenerationElementKey()) + .fieldWriter()) + .read()); + } + } + + @Override + void loadSingleValueOnStack(LocalVar proto3MessageOrBuilder) { + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + getterMethodTemplate(), + field.getFieldDescriptor())); + } + + @Override + void afterConvertRawValue() { + if (!field.isFieldWriterFallbackTransition()) { + endGroup(); + } + } + + void startGroup() { + add(recordConsumerVar.load(), Codegen.invokeMethod(Reflection.RecordConsumer.startGroup)); + } + + void endGroup() { + add(recordConsumerVar.load(), Codegen.invokeMethod(Reflection.RecordConsumer.endGroup)); + } + + Implementation writeMessageFieldsInternal(LocalVar proto3MessageOrBuilder) { + + if (!field.isFieldWriterFallbackTransition()) { + for (Field child : field.getChildren()) { + if (child.isProtoMessage()) { + add(writeMessageField(proto3MessageOrBuilder, recordConsumerVar, child)); + } else if (child.isMap()) { + add(writeMapField(child, proto3MessageOrBuilder)); + } else { + add(writeNonMessageRegularField(child, proto3MessageOrBuilder)); + } + } + } else { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(fieldWriterFallbacks + .get(field.getCodeGenerationElementKey()) + .fieldWriter()) + .read(), + proto3MessageOrBuilder.load(), + Codegen.invokeMethod(Reflection.FieldWriter.writeRawValue)); + } + + return this; + } + + private Implementation writeMapField(Field field, LocalVar proto3MessageOrBuilder) { + return new Implementations() { + { + CodeGenMapWriter codeGenMapWriter = mapWriters.get(field.getCodeGenerationElementKey()); + MethodDescription methodDescription = codeGenMapWriter.writeMapEntry(); + + Class[] parameters = codeGenMapWriter.writeMapEntryParameters(); + + TypeDescription keyType = TypeDescription.ForLoadedType.of(parameters[0]) + .asBoxed(); + TypeDescription valueType = TypeDescription.ForLoadedType.of(parameters[1]) + .asBoxed(); + + Label after = new Label(); + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + "get{}Count", + field.getFieldDescriptor()), + Codegen.jumpTo(Opcodes.IFLE, after), + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.startGroup), + recordConsumerVar.load(), + new TextConstant("key_value"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.startField)); + } + + add( + proto3MessageOrBuilder.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + codeGenMapWriter.getter(), + field.getFieldDescriptor()), + MethodVariableAccess.loadThis()); + + add(new StackManipulation.Simple(new StackManipulation.Simple.Dispatcher() { + @Override + public StackManipulation.Size apply( + MethodVisitor methodVisitor, Context implementationContext) { + methodVisitor.visitInvokeDynamicInsn( + "accept", + "(" + + classBuilder + .toTypeDescription() + .getDescriptor() + ")Ljava/util/function/BiConsumer;", + new Handle( + Opcodes.H_INVOKESTATIC, + "java/lang/invoke/LambdaMetafactory", + "metafactory", + "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", + false), + JavaConstantValue.Visitor.INSTANCE.onMethodType( + JavaConstant.MethodType.of( + void.class, Object.class, Object.class)), + new Handle( + Opcodes.H_INVOKEVIRTUAL, + classBuilder + .toTypeDescription() + .getInternalName(), + methodDescription.getInternalName(), + methodDescription.getDescriptor(), + false), + JavaConstantValue.Visitor.INSTANCE.onMethodType( + JavaConstant.MethodType.of( + TypeDescription.ForLoadedType.of(void.class), + keyType, + valueType))); + return StackManipulation.Size.ZERO; + } + })); + add(Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(Map.class, "forEach", BiConsumer.class))); + + if (protoWriteSupport.isWriteSpecsCompliant()) { + add( + recordConsumerVar.load(), + new TextConstant("key_value"), + IntegerConstant.forValue(0), + Codegen.invokeMethod(Reflection.RecordConsumer.endField), + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.endGroup)); + } + + add( + recordConsumerVar.load(), + new TextConstant(field.getParquetFieldName()), + IntegerConstant.forValue(field.getParquetFieldIndex()), + Codegen.invokeMethod(Reflection.RecordConsumer.endField)); + + add(Codegen.visitLabel(after), localVars.frameEmptyStack()); + } + }; + } + + protected Implementation writeNonMessageRegularField(Field field, LocalVar proto3MessageOrBuilder) { + if (field.isPrimitive()) { + return writePrimitiveField(proto3MessageOrBuilder, recordConsumerVar, field); + } else if (field.isBinary()) { + return writeBinaryField(proto3MessageOrBuilder, recordConsumerVar, field); + } else if (field.isString()) { + return writeStringField(proto3MessageOrBuilder, recordConsumerVar, field); + } else if (field.isProtoWrapper()) { + return writeProtoWrapperField(proto3MessageOrBuilder, recordConsumerVar, field); + } else if (field.isEnum()) { + return writeEnumField(proto3MessageOrBuilder, recordConsumerVar, field); + } + throw new CodeGenException("field: " + field); + } + } + + class BinaryFieldWriter extends RegularFieldWriterTemplate { + BinaryFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + /* + ByteString byteString = (ByteString) value; + Binary binary = Binary.fromConstantByteArray(byteString.toByteArray()); + recordConsumer.addBinary(binary); + */ + + @Override + void convertRawValueAndWrite() { + add( + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(ByteString.class, "toByteArray")), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Binary.class, "fromConstantByteArray", byte[].class)), + Codegen.invokeMethod(Reflection.RecordConsumer.addBinary)); + } + } + + class StringFieldWriter extends RegularFieldWriterTemplate { + StringFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + + /* + Binary binaryString = Binary.fromString((String) value); + recordConsumer.addBinary(binaryString); + */ + @Override + void convertRawValueAndWrite() { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(Binary.class, "fromString", String.class)), + Codegen.invokeMethod(Reflection.RecordConsumer.addBinary)); + } + } + + class ProtoWrapperFieldWriter extends RegularFieldWriterTemplate { + ProtoWrapperFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + + @Override + void convertRawValueAndWrite() { + ProtoWriteSupport.FieldWriter fieldWriter = field.fieldWriter; + if (fieldWriter instanceof ProtoWriteSupport.BytesValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(BytesValue.class, "getValue")), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(ByteString.class, "toByteArray")), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Binary.class, "fromConstantByteArray", byte[].class)), + Codegen.invokeMethod(Reflection.RecordConsumer.addBinary)); + } else if (fieldWriter instanceof ProtoWriteSupport.StringValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(StringValue.class, "getValue")), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(Binary.class, "fromString", String.class)), + Codegen.invokeMethod(Reflection.RecordConsumer.addBinary)); + } else if (fieldWriter instanceof ProtoWriteSupport.BoolValueWriter) { + add( + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(BoolValue.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addBoolean)); + } else if (fieldWriter instanceof ProtoWriteSupport.UInt32ValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(UInt32Value.class, "getValue")), + Codegen.castIntToLong(), + Codegen.invokeMethod(Reflection.RecordConsumer.addLong)); + } else if (fieldWriter instanceof ProtoWriteSupport.Int32ValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(Int32Value.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addInteger)); + } else if (fieldWriter instanceof ProtoWriteSupport.UInt64ValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(UInt64Value.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addLong)); + } else if (fieldWriter instanceof ProtoWriteSupport.Int64ValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(Int64Value.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addLong)); + } else if (fieldWriter instanceof ProtoWriteSupport.FloatValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(FloatValue.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addFloat)); + } else if (fieldWriter instanceof ProtoWriteSupport.DoubleValueWriter) { + add( + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(DoubleValue.class, "getValue")), + Codegen.invokeMethod(Reflection.RecordConsumer.addDouble)); + } else if (fieldWriter instanceof ProtoWriteSupport.TimeWriter) { + try (LocalVar timeOfDay = localVars.register(TimeOfDay.class)) { + add( + timeOfDay.store(), + timeOfDay.load(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(TimeOfDay.class, "getHours")), + timeOfDay.load(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(TimeOfDay.class, "getMinutes")), + timeOfDay.load(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(TimeOfDay.class, "getSeconds")), + timeOfDay.load(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(TimeOfDay.class, "getNanos")), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + LocalTime.class, "of", int.class, int.class, int.class, int.class)), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(LocalTime.class, "toNanoOfDay")), + Codegen.invokeMethod(Reflection.RecordConsumer.addLong)); + } + } else if (fieldWriter instanceof ProtoWriteSupport.DateWriter) { + try (LocalVar date = localVars.register(Date.class)) { + add( + date.store(), + date.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(Date.class, "getYear")), + date.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(Date.class, "getMonth")), + date.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(Date.class, "getDay")), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + LocalDate.class, "of", int.class, int.class, int.class)), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(LocalDate.class, "toEpochDay")), + Codegen.castLongToInt(), + Codegen.invokeMethod(Reflection.RecordConsumer.addInteger)); + } + } else if (fieldWriter instanceof ProtoWriteSupport.TimestampWriter) { + add( + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Timestamps.class, "toNanos", Timestamp.class)), + Codegen.invokeMethod(Reflection.RecordConsumer.addLong)); + } else { + throw new IllegalStateException(); + } + } + } + + class EnumFieldWriter extends RegularFieldWriterTemplate { + + EnumFieldWriter(Field field, LocalVar recordConsumerVar) { + super(field, recordConsumerVar); + } + + String getterMethodTemplate() { + return "get{}" + (supportsUnknownValues() ? "Value" : ""); + } + + boolean supportsUnknownValues() { + EnumReflectionType enumReflectionType = (EnumReflectionType) field.getReflectionType(); + return enumReflectionType.enumSupportsUnknownValues + && enumReflectionType.fieldSupportsUnknownValues; + } + + @Override + void loadRepeatedValueOnStack(LocalVar proto3MessageOrBuilder, LocalVar iVar) { + add( + proto3MessageOrBuilder.load(), + iVar.load(), + Codegen.invokeProtoMethod( + proto3MessageOrBuilder.clazz(), + getterMethodTemplate(), + field.getFieldDescriptor(), + int.class)); + } + + @Override + void beforeLoadValueOnStack() {} + + @Override + void convertRawValueAndWrite() { + if (supportsUnknownValues()) { + convertRawValueAndWriteWithUnknownValues(); + } else { + convertRawValueAndWriteWithoutUnknownValues(); + } + } + + /** + * int enumNumber = messageOrBuilder.getEnumValue(); + * ProtocolMessageEnum enum_ = forNumber.apply(enumNumber); + * Enum javaEnum = (Enum) enum_; + * Descriptors.EnumValueDescriptor enumValueDescriptor; + * if (javaEnum != null) { + * enumValueDescriptor = enumValues.get(javaEnum.ordinal()); + * } else { + * enumValueDescriptor = enumDescriptor.findValueByNumberCreatingIfUnknown(enumNumber); + * } + */ + void convertRawValueAndWriteWithUnknownValues() { + CodeGenEnum codeGenEnum = enumFields.get(field.getCodeGenerationElementKey()); + Class enumClass = codeGenEnum.clazz; + + try (LocalVar enumNumber = localVars.register(int.class)) { + add( + enumNumber.store(), + enumNumber.load(), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(enumClass, "forNumber", int.class))); + try (LocalVar enumRef = localVars.register(enumClass)) { + add(enumRef.store(), enumRef.load()); + Label ifEnumRefIsNull = new Label(); + Label afterEnumValueResolved = new Label(); + add( + Codegen.jumpTo(Opcodes.IFNULL, ifEnumRefIsNull), + MethodVariableAccess.loadThis(), + FieldAccess.forField(codeGenEnum.enumValues()) + .read(), + enumRef.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(Enum.class, "ordinal")), + ArrayAccess.REFERENCE.load(), + Codegen.jumpTo(Opcodes.GOTO, afterEnumValueResolved), + Codegen.visitLabel(ifEnumRefIsNull), + localVars.frameEmptyStack(), + MethodVariableAccess.loadThis(), + FieldAccess.forField(codeGenEnum.enumDescriptor()) + .read(), + enumNumber.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Descriptors.EnumDescriptor.class, + "findValueByNumberCreatingIfUnknown", + int.class)), + Codegen.visitLabel(afterEnumValueResolved), + localVars.frameSame1(Descriptors.EnumValueDescriptor.class)); + + writeEnumValueDesc(codeGenEnum); + } + } + } + + /** + * Enum javaEnum = messageOrBuilder.getEnum(); + * enumValueDescriptor = enumValues.get(javaEnum.ordinal()); + */ + void convertRawValueAndWriteWithoutUnknownValues() { + CodeGenEnum codeGenEnum = enumFields.get(field.getCodeGenerationElementKey()); + Class enumClass = codeGenEnum.clazz; + + try (LocalVar enumRef = localVars.register(enumClass)) { + add( + enumRef.store(), + MethodVariableAccess.loadThis(), + FieldAccess.forField(codeGenEnum.enumValues()) + .read(), + enumRef.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(Enum.class, "ordinal")), + ArrayAccess.REFERENCE.load()); + + writeEnumValueDesc(codeGenEnum); + } + } + + private void writeEnumValueDesc(CodeGenEnum codeGenEnum) { + try (LocalVar enumValueDesc = localVars.register(Descriptors.EnumValueDescriptor.class)) { + add(enumValueDesc.store()); + try (LocalVar enumValueDescName = localVars.register(String.class)) { + add( + enumValueDesc.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Descriptors.EnumValueDescriptor.class, "getName")), + enumValueDescName.store(), + enumValueDescName.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Binary.class, "fromString", String.class))); + try (LocalVar binary = localVars.register(Binary.class)) { + add( + binary.store(), + recordConsumerVar.load(), + binary.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.addBinary), + MethodVariableAccess.loadThis(), + FieldAccess.forField(codeGenEnum.enumNumberPairs()) + .read(), + enumValueDescName.load(), + enumValueDesc.load(), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Descriptors.EnumValueDescriptor.class, "getNumber")), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Integer.class, "valueOf", int.class)), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod( + Map.class, "putIfAbsent", Object.class, Object.class)), + Removal.SINGLE); + } + } + } + } + } + } + + class WriteAllFieldsForMessageImplementation extends FastMessageWriterMethodBase { + WriteAllFieldsForMessageImplementation(Field field) { + + try (LocalVar thisLocalVar = localVars.register(classBuilder.toTypeDescription())) { + writeMessageFields(field); + } + } + + private void writeMessageFields(Field field) { + Class messageOrBuilderInterface = field.getMessageOrBuilderInterface(); + + try (LocalVar messageOrBuilderArg = localVars.register(messageOrBuilderInterface)) { + localVars.frameEmptyStack(); + + try (LocalVar proto3MessageOrBuilder = messageOrBuilderArg.alias()) { + try (LocalVar recordConsumerVar = localVars.register(RecordConsumer.class)) { + add(Codegen.storeRecordConsumer(recordConsumerVar)); + add(new MessageFieldWriter(field, recordConsumerVar) + .writeMessageFieldsInternal(proto3MessageOrBuilder)); + } + } + add(Codegen.returnVoid()); + } + } + } + + class WriteAllFieldsForMapEntryImplementation extends FastMessageWriterMethodBase { + + WriteAllFieldsForMapEntryImplementation(Field field) { + CodeGenMapWriter codeGenMapWriter = mapWriters.get(field.getCodeGenerationElementKey()); + Class[] methodParameters = codeGenMapWriter.writeMapEntryParameters(); + try (LocalVar thisLocalVar = localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar key = localVars.register(methodParameters[0])) { + try (LocalVar value = localVars.register(methodParameters[1])) { + try (LocalVar recordConsumerVar = localVars.register(RecordConsumer.class)) { + add(Codegen.storeRecordConsumer(recordConsumerVar)); + add( + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.startGroup)); + add(writeFromVar(field.mapKey(), key, recordConsumerVar)); + add(writeFromVar(field.mapValue(), value, recordConsumerVar)); + add( + recordConsumerVar.load(), + Codegen.invokeMethod(Reflection.RecordConsumer.endGroup)); + add(Codegen.returnVoid()); + } + } + } + } + } + + Implementation writeFromVar(Field field, LocalVar val, LocalVar recordConsumer) { + if (field.isEnum()) { + return new EnumFieldWriter(field, recordConsumer).writeFromVar(val); + } else if (field.isProtoMessage()) { + return new MessageFieldWriter(field, recordConsumer).writeFromVar(val); + } else if (field.isString()) { + return new StringFieldWriter(field, recordConsumer).writeFromVar(val); + } else if (field.isBinary()) { + return new BinaryFieldWriter(field, recordConsumer).writeFromVar(val); + } else if (field.isProtoWrapper()) { + return new ProtoWrapperFieldWriter(field, recordConsumer).writeFromVar(val); + } else if (field.isPrimitive()) { + return new PrimitiveFieldWriter(field, recordConsumer).writeFromVar(val); + } + throw new CodeGenException("field: " + field); + } + } + + static class CodeGenElement { + private final int id; + private final Field field; + + public CodeGenElement(int id, Field field) { + this.id = id; + this.field = field; + } + + public Field getField() { + return field; + } + + public int getId() { + return id; + } + } + + class CodeGenMessageWriter extends CodeGenElement { + private final Class messageOrBuilderInterface; + + public CodeGenMessageWriter(int id, Field field) { + super(id, field); + this.messageOrBuilderInterface = field.getMessageOrBuilderInterface(); + } + + public String getMethodName() { + return "writeAllFields$" + getId(); + } + + public Class getMethodParameterType() { + return messageOrBuilderInterface; + } + + public MethodDescription getMethodDescription() { + return new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + getMethodName(), + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(TypeDescription.Generic.OfNonGenericType.ForLoadedType.of( + getMethodParameterType())))); + } + } + + class CodeGenEnum extends CodeGenElement { + private final String enumTypeFullName; + private final Class clazz; + + public CodeGenEnum(int id, Field field) { + super(id, field); + enumTypeFullName = field.getFieldDescriptor().getEnumType().getFullName(); + clazz = ((EnumReflectionType) field.getReflectionType()).clazz; + } + + public Class getEnumClass() { + return clazz; + } + + public String getEnumTypeFullName() { + return enumTypeFullName; + } + + // final Map enumNameNumberPairs; + public FieldDescription enumNumberPairs() { + return new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "enumNameNumberPairs$" + getId(), + Modifier.PRIVATE | Modifier.FINAL, + TypeDescription.Generic.Builder.parameterizedType( + Map.class, String.class, Integer.class) + .build())); + } + + // final Descriptors.EnumDescriptor enumDescriptor + public FieldDescription enumDescriptor() { + return new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "enumDescriptor$" + getId(), + Modifier.PRIVATE | Modifier.FINAL, + new TypeDescription.Generic.OfNonGenericType.ForLoadedType( + Descriptors.EnumDescriptor.class))); + } + + // final List enumValues + public FieldDescription enumValues() { + return new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "enumValues$" + getId(), + Modifier.PRIVATE | Modifier.FINAL, + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of( + Descriptors.EnumValueDescriptor[].class))); + } + } + + class CodeGenMapWriter extends CodeGenElement { + public CodeGenMapWriter(int id, Field field) { + super(id, field); + } + + public String getMethodName() { + return "writeMapEntry$" + getId(); + } + + public MethodDescription writeMapEntry() { + Class[] parameters = writeMapEntryParameters(); + return new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + getMethodName(), + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Arrays.asList( + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parameters[0]), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parameters[1])))); + } + + public Class[] writeMapEntryParameters() { + MapReflectionType mapReflectionType = + (MapReflectionType) getField().getReflectionType(); + Class keyType = (Class) mapReflectionType.key(); + Class valueType = getValueType(mapReflectionType); + return new Class[] {keyType, valueType}; + } + + public String getter() { + MapReflectionType mapReflectionType = + (MapReflectionType) getField().getReflectionType(); + boolean isEnumAndSupportsUnknownValues = false; + if (mapReflectionType.value() instanceof EnumReflectionType) { + EnumReflectionType enumReflectionType = (EnumReflectionType) mapReflectionType.value(); + isEnumAndSupportsUnknownValues = enumReflectionType.enumSupportsUnknownValues + && enumReflectionType.fieldSupportsUnknownValues; + } + return "get{}" + (isEnumAndSupportsUnknownValues ? "Value" : "") + "Map"; + } + + private Class getValueType(MapReflectionType mapReflectionType) { + Class valueType; + if (mapReflectionType.value() instanceof EnumReflectionType) { + EnumReflectionType enumReflectionType = (EnumReflectionType) mapReflectionType.value(); + if (enumReflectionType.enumSupportsUnknownValues + && enumReflectionType.fieldSupportsUnknownValues) { + valueType = int.class; + } else { + valueType = enumReflectionType.clazz; + } + } else { + valueType = (Class) mapReflectionType.value(); + } + return valueType; + } + } + + class CodeGenFieldWriterFallback extends CodeGenElement { + + public CodeGenFieldWriterFallback(int id, Field field) { + super(id, field); + } + + public FieldDescription fieldWriter() { + return new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + getFieldName(), + Modifier.PRIVATE, + // TODO: create more specific FieldWriter + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of( + ProtoWriteSupport.FieldWriter.class))); + } + + public String getFieldName() { + return "fieldWriter$" + getId(); + } + } + + static class GeneratedElementsInfo { + private final FieldScanner fieldScanner; + private final Class protoMessage; + private final Descriptors.Descriptor messageType; + private final Map codeGenMessageWriters; + private final Map fallbackFieldWriters; + + public GeneratedElementsInfo( + FieldScanner fieldScanner, + Class protoMessage, + Descriptors.Descriptor messageType, + Map codeGenMessageWriters, + Map fallbackFieldWriters) { + this.fieldScanner = fieldScanner; + this.messageType = messageType; + this.protoMessage = protoMessage; + this.codeGenMessageWriters = codeGenMessageWriters; + this.fallbackFieldWriters = fallbackFieldWriters; + } + + public void scan(ProtoWriteSupport.MessageWriter messageWriter, FieldVisitor fieldVisitor) { + fieldScanner.scan(messageWriter, protoMessage, messageType, fieldVisitor); + } + } + + private GeneratedElementsInfo getGeneratedElementsInfo( + Class generatedClass) { + Map codeGenMessageWriters = new LinkedHashMap<>(); + Map protoReflectionMessageWriters = new LinkedHashMap<>(); + + for (Map.Entry key2CodeGenMessageWriterEntry : + this.codeGenMessageWriters.entrySet()) { + codeGenMessageWriters.put( + key2CodeGenMessageWriterEntry.getKey(), + ReflectionUtil.getDeclaredMethod( + generatedClass, + key2CodeGenMessageWriterEntry.getValue().getMethodName(), + key2CodeGenMessageWriterEntry.getValue().getMethodParameterType())); + } + + for (Map.Entry key2CodeGenProtoReflectionMessageWriterEntry : + this.fieldWriterFallbacks.entrySet()) { + protoReflectionMessageWriters.put( + key2CodeGenProtoReflectionMessageWriterEntry.getKey(), + key2CodeGenProtoReflectionMessageWriterEntry + .getValue() + .getId()); + } + + return new GeneratedElementsInfo( + fieldScanner, protoMessage, descriptor, codeGenMessageWriters, protoReflectionMessageWriters); + } + + private static void addCodeGenElement( + Field field, Map registry, BiFunction codeElementConstructor) { + registry.computeIfAbsent( + field.getCodeGenerationElementKey(), + unused -> codeElementConstructor.apply(registry.size(), field)); + } + + public Consumer.MessageWriter> getPatcher() { + if (byteBuddyMessageWritersClass == null) { + return NOOP_WRITER_PATCHER; + } + return new ByteBuddyMessageWritersPatcher( + ReflectionUtil.getConstructor( + byteBuddyMessageWritersClass, + ProtoWriteSupport.MessageWriter.class, + GeneratedElementsInfo.class), + getGeneratedElementsInfo(byteBuddyMessageWritersClass)); + } + } + + static class ByteBuddyMessageWritersPatcher implements Consumer.MessageWriter> { + private final Constructor byteBuddyMessageWritersConstructor; + private final ByteBuddyMessageWritersCodeGen.GeneratedElementsInfo generatedElementsInfo; + + ByteBuddyMessageWritersPatcher( + Constructor byteBuddyMessageWritersConstructor, + ByteBuddyMessageWritersCodeGen.GeneratedElementsInfo generatedElementsInfo) { + this.byteBuddyMessageWritersConstructor = byteBuddyMessageWritersConstructor; + this.generatedElementsInfo = generatedElementsInfo; + } + + @Override + public void accept(ProtoWriteSupport.MessageWriter messageWriter) { + ReflectionUtil.newInstance(byteBuddyMessageWritersConstructor, messageWriter, generatedElementsInfo); + } + } + + // this is subclassed with ByteBuddy, overriding the constructor, setProtoReflectionMessageWriters and adding + // new fields and methods + abstract static class ByteBuddyMessageWriters { + private final ProtoWriteSupport protoWriteSupport; + private final Map fastMessageWriters = new LinkedHashMap<>(); + + public ByteBuddyMessageWriters( + ProtoWriteSupport.MessageWriter rootMessageWriter, + ByteBuddyMessageWritersCodeGen.GeneratedElementsInfo generatedElementsInfo) { + this.protoWriteSupport = rootMessageWriter.getProtoWriteSupport(); + final ProtoWriteSupport.FieldWriter[] fallbackFieldWriters = + new ProtoWriteSupport.FieldWriter[generatedElementsInfo.fallbackFieldWriters.size()]; + + // assign alternative message writers and collect protobuf reflection message writers + generatedElementsInfo.scan(rootMessageWriter, new FieldVisitor() { + @Override + public void visitField(Field field) { + if (field.isFieldWriterFallback()) { + if (field.isFieldWriterFallbackTransition()) { + Object key = field.getCodeGenerationElementKey(); + int id = generatedElementsInfo.fallbackFieldWriters.get(key); + if (fallbackFieldWriters[id] == null) { + fallbackFieldWriters[id] = field.fieldWriter; + } + } + } else if (field.isMessage()) { + Object key = field.getCodeGenerationElementKey(); + Method method = generatedElementsInfo.codeGenMessageWriters.get(key); + field.getMessageWriter().setAlternativeMessageWriter(getFastMessageWriter(method)); + } + } + }); + + for (ProtoWriteSupport.FieldWriter fieldWriter : fallbackFieldWriters) { + if (fieldWriter == null) { + throw new CodeGenException(); + } + } + setFallbackFieldWriters(fallbackFieldWriters); + } + + // the implementation needs to assign the passed array to fields + public void setFallbackFieldWriters(ProtoWriteSupport.FieldWriter[] fieldWriters) {} + + // used from the generated methods to load record consumer in a local variable + public RecordConsumer getRecordConsumer() { + return protoWriteSupport.getRecordConsumer(); + } + + // used from the constructor, when assigning the maps for enums + public Map enumNameNumberPairs(String enumTypeFullName) { + return protoWriteSupport.getProtoEnumBookKeeper().get(enumTypeFullName); + } + + public ProtoWriteSupport.MessageFieldsWriter getFastMessageWriter(Method method) { + return fastMessageWriters.computeIfAbsent(method, k -> { + MethodHandles.Lookup lookup = MethodHandles.lookup(); + Class messageOrBuilderInterface = method.getParameterTypes()[0]; + try { + Consumer consumer = + (Consumer) LambdaMetafactory.metafactory( + lookup, + "accept", + MethodType.methodType(Consumer.class, this.getClass()), + MethodType.methodType(void.class, Object.class), + lookup.unreflect(method), + MethodType.methodType(void.class, messageOrBuilderInterface)) + .getTarget() + .bindTo(this) + .invokeExact(); + return new ProtoWriteSupport.MessageFieldsWriter() { + @Override + public boolean writeAllFields(MessageOrBuilder messageOrBuilder) { + if (!messageOrBuilderInterface.isInstance(messageOrBuilder)) { + return false; + } + consumer.accept(messageOrBuilder); + return true; + } + }; + } catch (Throwable e) { + throw new CodeGenException(e); + } + }); + } + } + } +} diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 637f6fda91..034ef1f18f 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -83,8 +83,11 @@ public class ProtoWriteSupport extends WriteSupport< public static final String PB_UNWRAP_PROTO_WRAPPERS = "parquet.proto.unwrapProtoWrappers"; + public static final String PB_CODEGEN = "parquet.proto.codegen"; + private boolean writeSpecsCompliant = false; private boolean unwrapProtoWrappers = false; + private CodegenMode codegenMode = CodegenMode.SUPPORT_COMPATIBLE; private RecordConsumer recordConsumer; private Class protoMessage; private Descriptor descriptor; @@ -127,6 +130,98 @@ public static void setUnwrapProtoWrappers(Configuration configuration, boolean u configuration.setBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers); } + public static void setCodegenMode(Configuration configuration, CodegenMode codegenMode) { + configuration.setEnum(PB_CODEGEN, codegenMode); + } + + /** + * OFF - avoid any code generation + * SUPPORT - attempt to use code generation where available + * REQUIRED - must use code generation and fail for codegen errors (bugs) + */ + public enum CodegenMode { + OFF { + @Override + public boolean ignoreCodeGenException() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean protobufReflectionForExtensions() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tryCodeGen(Class protoClass) { + return false; + } + }, + + // use Protobuf reflection for Extendable protos in order to throw UnsupportedOperationException if there is an + // extension field set. + SUPPORT_COMPATIBLE { + @Override + public boolean ignoreCodeGenException() { + return true; + } + + @Override + public boolean protobufReflectionForExtensions() { + return true; + } + + @Override + public boolean tryCodeGen(Class protoClass) { + return ByteBuddyCodeGen.isGeneratedMessage(protoClass) && ByteBuddyCodeGen.isByteBuddyAvailable(false); + } + }, + + // include code generation for Extendable protos, though without writing extension fields and without reporting + // errors. + SUPPORT_ALL { + @Override + public boolean ignoreCodeGenException() { + return SUPPORT_COMPATIBLE.ignoreCodeGenException(); + } + + @Override + public boolean protobufReflectionForExtensions() { + return false; + } + + @Override + public boolean tryCodeGen(Class protoClass) { + return SUPPORT_COMPATIBLE.tryCodeGen(protoClass); + } + }, + + REQUIRED_ALL { + @Override + public boolean ignoreCodeGenException() { + return false; + } + + @Override + public boolean protobufReflectionForExtensions() { + return false; + } + + @Override + public boolean tryCodeGen(Class protoClass) { + if (!ByteBuddyCodeGen.isGeneratedMessage(protoClass)) { + throw new IllegalStateException("protoClass is a GeneratedMessage: " + protoClass); + } + return ByteBuddyCodeGen.isByteBuddyAvailable(true); + } + }; + + public abstract boolean ignoreCodeGenException(); + + public abstract boolean protobufReflectionForExtensions(); + + public abstract boolean tryCodeGen(Class protoClass); + } + /** * Writes Protocol buffer to parquet file. * @@ -180,11 +275,15 @@ public WriteContext init(ParquetConfiguration configuration) { unwrapProtoWrappers = configuration.getBoolean(PB_UNWRAP_PROTO_WRAPPERS, unwrapProtoWrappers); writeSpecsCompliant = configuration.getBoolean(PB_SPECS_COMPLIANT_WRITE, writeSpecsCompliant); + codegenMode = CodegenMode.valueOf(configuration.get(PB_CODEGEN, codegenMode.name())); MessageType rootSchema = new ProtoSchemaConverter(configuration).convert(descriptor); validatedMapping(descriptor, rootSchema); this.messageWriter = new MessageWriter(descriptor, rootSchema); + ByteBuddyCodeGen.WriteSupport.tryApplyAlternativeMessageFieldsWriters( + messageWriter, rootSchema, protoMessage, descriptor, codegenMode); + extraMetaData.put(ProtoReadSupport.PB_DESCRIPTOR, descriptor.toProto().toString()); extraMetaData.put(PB_SPECS_COMPLIANT_WRITE, String.valueOf(writeSpecsCompliant)); extraMetaData.put(PB_UNWRAP_PROTO_WRAPPERS, String.valueOf(unwrapProtoWrappers)); @@ -260,6 +359,8 @@ class MessageWriter extends FieldWriter { final FieldWriter[] fieldWriters; + MessageFieldsWriter alternativeMessageWriter = MessageFieldsWriter.NOOP; + @SuppressWarnings("unchecked") MessageWriter(Descriptor descriptor, GroupType schema) { List fields = descriptor.getFields(); @@ -284,6 +385,10 @@ class MessageWriter extends FieldWriter { } } + ProtoWriteSupport getProtoWriteSupport() { + return ProtoWriteSupport.this; + } + private FieldWriter createWriter(FieldDescriptor fieldDescriptor, Type type) { switch (fieldDescriptor.getJavaType()) { @@ -444,6 +549,10 @@ final void writeField(Object value) { } private void writeAllFields(MessageOrBuilder pb) { + if (alternativeMessageWriter.writeAllFields(pb)) { + return; + } + Descriptor messageDescriptor = pb.getDescriptorForType(); Descriptors.FileDescriptor.Syntax syntax = messageDescriptor.getFile().getSyntax(); @@ -485,6 +594,10 @@ private void writeAllFields(MessageOrBuilder pb) { } } } + + void setAlternativeMessageWriter(MessageFieldsWriter alternativeMessageWriter) { + this.alternativeMessageWriter = alternativeMessageWriter; + } } class ArrayWriter extends FieldWriter { @@ -601,8 +714,8 @@ final void writeRawValue(Object value) { class MapWriter extends FieldWriter { - private final FieldWriter keyWriter; - private final FieldWriter valueWriter; + final FieldWriter keyWriter; + final FieldWriter valueWriter; public MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) { super(); @@ -796,4 +909,33 @@ private FieldWriter unknownType(FieldDescriptor fieldDescriptor) { + fieldDescriptor.getJavaType() + "\"."; throw new InvalidRecordException(exceptionMsg); } + + /** + * A plugin for {@link MessageWriter#writeAllFields(MessageOrBuilder)} that is potentially + * capable to write MessageOrBuilder fields faster. + */ + public interface MessageFieldsWriter { + MessageFieldsWriter NOOP = messageOrBuilder -> false; + + /** + * Performs all the steps that {@link MessageWriter#writeAllFields(MessageOrBuilder)} + * would normally do, but faster. + * @param messageOrBuilder + * @return true if this writer has written fields of the passed messageOrBuilder + * false otherwise + */ + boolean writeAllFields(MessageOrBuilder messageOrBuilder); + } + + RecordConsumer getRecordConsumer() { + return recordConsumer; + } + + Map> getProtoEnumBookKeeper() { + return protoEnumBookKeeper; + } + + boolean isWriteSpecsCompliant() { + return writeSpecsCompliant; + } } From ad6822c08d54f49f08f4ad2991cff9a4137ade2c Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Tue, 14 Jan 2025 00:24:18 +0000 Subject: [PATCH 02/15] run unit tests in different codegen modes --- .../parquet/proto/ByteBuddyCodeGen.java | 6 +- .../parquet/proto/ProtoWriteSupport.java | 10 +- .../proto/ProtoInputOutputFormatTest.java | 62 +++++++++--- .../parquet/proto/ProtoParquetWriterTest.java | 61 ++++++++++-- .../parquet/proto/ProtoWriteSupportTest.java | 96 +++++++++++++++---- 5 files changed, 192 insertions(+), 43 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 1f869c6de9..318723ed86 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -174,7 +174,7 @@ static boolean isByteBuddyAvailable(boolean failIfNot) { return true; } catch (ClassNotFoundException e) { if (failIfNot) { - throw new IllegalStateException("ByteBuddy is not available", e); + throw new UnsupportedOperationException("ByteBuddy is not available", e); } return false; } @@ -1197,6 +1197,10 @@ private Field resolveField( fieldDescriptor, fieldWriter); } else { + if (!Objects.equals(parquetFieldInfo.fieldName, fieldDescriptor.getName())) { + throw new CodeGenException("fields mismatch: parquetFieldInfo: " + parquetFieldInfo.fieldName + + ", fieldDescriptor: " + fieldDescriptor); + } return new Field( fieldScanner, this, diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 034ef1f18f..6a10e47d18 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -87,7 +87,7 @@ public class ProtoWriteSupport extends WriteSupport< private boolean writeSpecsCompliant = false; private boolean unwrapProtoWrappers = false; - private CodegenMode codegenMode = CodegenMode.SUPPORT_COMPATIBLE; + private CodegenMode codegenMode = CodegenMode.DEFAULT; private RecordConsumer recordConsumer; private Class protoMessage; private Descriptor descriptor; @@ -209,12 +209,18 @@ public boolean protobufReflectionForExtensions() { @Override public boolean tryCodeGen(Class protoClass) { if (!ByteBuddyCodeGen.isGeneratedMessage(protoClass)) { - throw new IllegalStateException("protoClass is a GeneratedMessage: " + protoClass); + throw new UnsupportedOperationException("protoClass is not a GeneratedMessage: " + protoClass); } return ByteBuddyCodeGen.isByteBuddyAvailable(true); } }; + public static final CodegenMode DEFAULT = CodegenMode.SUPPORT_COMPATIBLE; + + public static CodegenMode orDefault(CodegenMode codegenMode) { + return codegenMode == null ? DEFAULT : codegenMode; + } + public abstract boolean ignoreCodeGenException(); public abstract boolean protobufReflectionForExtensions(); diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java index 57ad4d4f08..f3acd00adf 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java @@ -28,6 +28,9 @@ import com.google.protobuf.Message; import com.google.protobuf.Timestamp; import com.google.protobuf.util.Timestamps; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -38,9 +41,33 @@ import org.apache.parquet.proto.utils.ReadUsingMR; import org.apache.parquet.proto.utils.WriteUsingMR; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class ProtoInputOutputFormatTest { + @Parameterized.Parameters(name = "codegenMode: {0}") + public static Collection data() { + List data = new ArrayList<>(); + + List codegenModes = + new ArrayList<>(Arrays.asList(ProtoWriteSupport.CodegenMode.values())); + codegenModes.add(null); + + for (ProtoWriteSupport.CodegenMode codegenMode : codegenModes) { + data.add(new Object[] {codegenMode}); + } + + return data; + } + + private final ProtoWriteSupport.CodegenMode codegenMode; + + public ProtoInputOutputFormatTest(ProtoWriteSupport.CodegenMode codegenMode) { + this.codegenMode = codegenMode; + } + /** * Writes Protocol Buffer using first MR job, reads written file using * second job and compares input and output. @@ -241,7 +268,7 @@ public void testRepeatedIntMessageClassSchemaCompliant() throws Exception { .addRepeatedInt(2) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -264,7 +291,7 @@ public void testProto3RepeatedIntMessageClassSchemaCompliant() throws Exception .addRepeatedInt(2) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -327,7 +354,7 @@ public void testMapIntMessageClassSchemaCompliant() throws Exception { .putMapInt(2, 234) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -350,7 +377,7 @@ public void testProto3MapIntMessageClassSchemaCompliant() throws Exception { .putMapInt(2, 234) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -419,7 +446,7 @@ public void testRepeatedInnerMessageClassSchemaCompliant() throws Exception { TestProtobuf.InnerMessage.newBuilder().setTwo("two").build()) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -444,7 +471,7 @@ public void testProto3RepeatedInnerMessageClassSchemaCompliant() throws Exceptio TestProto3.InnerMessage.newBuilder().setTwo("two").build()) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); @@ -463,7 +490,7 @@ public void testProto3Defaults() throws Exception { TestProto3.SchemaConverterAllDatatypes msgEmpty = TestProto3.SchemaConverterAllDatatypes.newBuilder().build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty); @@ -508,7 +535,7 @@ public void testProto3AllTypes() throws Exception { TestProto3.SchemaConverterAllDatatypes dataBuilt = data.build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(dataBuilt); @@ -574,7 +601,7 @@ public void testProto3AllTypesMultiple() throws Exception { input[i] = d.build(); } - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(input); @@ -606,7 +633,7 @@ public void testProto3RepeatedMessages() throws Exception { top.addInnerBuilder().setTwo("Second inner"); top.addInnerBuilder().setThree("Third inner"); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); Path outputPath = new WriteUsingMR(conf).write(top.build()); @@ -643,7 +670,7 @@ public void testProto3TimestampMessageClass() throws Exception { TestProto3.DateTimeMessage msgNonEmpty = TestProto3.DateTimeMessage.newBuilder().setTimestamp(timestamp).build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); ReadUsingMR readUsingMR = new ReadUsingMR(); @@ -665,7 +692,7 @@ public void testProto3WrappedMessageClass() throws Exception { .setWrappedBool(BoolValue.of(true)) .build(); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); conf.setBoolean(ProtoWriteSupport.PB_UNWRAP_PROTO_WRAPPERS, true); Path outputPath = new WriteUsingMR(conf).write(msgEmpty, msgNonEmpty); ReadUsingMR readUsingMR = new ReadUsingMR(); @@ -681,9 +708,16 @@ public void testProto3WrappedMessageClass() throws Exception { /** * Runs job that writes input to file and then job reading data back. */ - public static List runMRJobs(Message... messages) throws Exception { - Path outputPath = new WriteUsingMR().write(messages); + public List runMRJobs(Message... messages) throws Exception { + Path outputPath = new WriteUsingMR(updateConfiguration(new Configuration())).write(messages); List result = new ReadUsingMR().read(outputPath); return result; } + + private Configuration updateConfiguration(Configuration configuration) { + if (codegenMode != null) { + ProtoWriteSupport.setCodegenMode(configuration, codegenMode); + } + return configuration; + } } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoParquetWriterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoParquetWriterTest.java index da84bf0047..efa3a3ae1b 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoParquetWriterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoParquetWriterTest.java @@ -21,9 +21,15 @@ import static org.apache.parquet.proto.TestUtils.readMessages; import static org.apache.parquet.proto.TestUtils.someTemporaryFilePath; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import com.google.protobuf.Descriptors; import com.google.protobuf.DynamicMessage; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumSet; import java.util.List; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -31,8 +37,33 @@ import org.apache.parquet.hadoop.ParquetWriter; import org.apache.parquet.proto.test.TestProto3; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class ProtoParquetWriterTest { + + @Parameterized.Parameters(name = "codegenMode: {0}") + public static Collection data() { + List data = new ArrayList<>(); + + List codegenModes = + new ArrayList<>(Arrays.asList(ProtoWriteSupport.CodegenMode.values())); + codegenModes.add(null); + + for (ProtoWriteSupport.CodegenMode codegenMode : codegenModes) { + data.add(new Object[] {codegenMode}); + } + + return data; + } + + private final ProtoWriteSupport.CodegenMode codegenMode; + + public ProtoParquetWriterTest(ProtoWriteSupport.CodegenMode codegenMode) { + this.codegenMode = codegenMode; + } + @Test public void testProtoParquetWriterWithDynamicMessage() throws Exception { Path file = someTemporaryFilePath(); @@ -41,12 +72,23 @@ public void testProtoParquetWriterWithDynamicMessage() throws Exception { msg.setOne("oneValue"); DynamicMessage dynamicMessage = DynamicMessage.newBuilder(msg.build()).build(); - Configuration conf = new Configuration(); - ParquetWriter writer = ProtoParquetWriter.builder(file) - .withDescriptor(descriptor) - .withConf(conf) - .withWriteMode(ParquetFileWriter.Mode.OVERWRITE) - .build(); + Configuration conf = updateConfiguration(new Configuration()); + + ProtoWriteSupport.CodegenMode codegenModeOrDefault = ProtoWriteSupport.CodegenMode.orDefault(codegenMode); + EnumSet failingModes = EnumSet.of(ProtoWriteSupport.CodegenMode.REQUIRED_ALL); + + ParquetWriter writer; + try { + writer = ProtoParquetWriter.builder(file) + .withDescriptor(descriptor) + .withConf(conf) + .withWriteMode(ParquetFileWriter.Mode.OVERWRITE) + .build(); + } catch (UnsupportedOperationException e) { + assertTrue("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); + return; + } + assertFalse("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); writer.write(dynamicMessage); writer.close(); @@ -58,4 +100,11 @@ public void testProtoParquetWriterWithDynamicMessage() throws Exception { assertEquals(getFirst.getTwo(), ""); assertEquals(getFirst.getThree(), ""); } + + private Configuration updateConfiguration(Configuration configuration) { + if (codegenMode != null) { + ProtoWriteSupport.setCodegenMode(configuration, codegenMode); + } + return configuration; + } } diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java index 360da8b741..e5dd29e5a4 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoWriteSupportTest.java @@ -42,6 +42,10 @@ import java.io.IOException; import java.time.LocalDate; import java.time.LocalTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumSet; import java.util.List; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -52,14 +56,45 @@ import org.apache.parquet.proto.test.TestProtobuf; import org.apache.parquet.proto.test.Trees; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.mockito.InOrder; import org.mockito.Mockito; +@RunWith(Parameterized.class) public class ProtoWriteSupportTest { + @Parameterized.Parameters(name = "codegenMode: {0}") + public static Collection data() { + List data = new ArrayList<>(); + + List codegenModes = + new ArrayList<>(Arrays.asList(ProtoWriteSupport.CodegenMode.values())); + codegenModes.add(null); + + for (ProtoWriteSupport.CodegenMode codegenMode : codegenModes) { + data.add(new Object[] {codegenMode}); + } + + return data; + } + + private final ProtoWriteSupport.CodegenMode codegenMode; + + public ProtoWriteSupportTest(ProtoWriteSupport.CodegenMode codegenMode) { + this.codegenMode = codegenMode; + } + private ProtoWriteSupport createReadConsumerInstance( Class cls, RecordConsumer readConsumerMock) { - return createReadConsumerInstance(cls, readConsumerMock, new Configuration()); + return createReadConsumerInstance(cls, readConsumerMock, updateConfiguration(new Configuration())); + } + + private Configuration updateConfiguration(Configuration configuration) { + if (codegenMode != null) { + ProtoWriteSupport.setCodegenMode(configuration, codegenMode); + } + return configuration; } private ProtoWriteSupport createReadConsumerInstance( @@ -126,7 +161,19 @@ public void testProto3SimplestDynamicMessage() throws Exception { Descriptors.Descriptor descriptor = TestProto3.InnerMessage.getDescriptor(); ProtoWriteSupport instance = new ProtoWriteSupport(descriptor); - instance.init(new Configuration()); + + Configuration configuration = updateConfiguration(new Configuration()); + + ProtoWriteSupport.CodegenMode codegenModeOrDefault = ProtoWriteSupport.CodegenMode.orDefault(codegenMode); + EnumSet failingModes = EnumSet.of(ProtoWriteSupport.CodegenMode.REQUIRED_ALL); + + try { + instance.init(configuration); + } catch (UnsupportedOperationException e) { + assertTrue("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); + return; + } + assertFalse("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); instance.prepareForWrite(readConsumerMock); TestProto3.InnerMessage.Builder msg = TestProto3.InnerMessage.newBuilder(); @@ -219,7 +266,7 @@ public void testRepeatedIntMessage() throws Exception { @Test public void testRepeatedIntMessageEmptySpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.RepeatedIntMessage.class, readConsumerMock, conf); @@ -255,7 +302,7 @@ public void testRepeatedIntMessageEmpty() throws Exception { @Test public void testProto3RepeatedIntMessageSpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.RepeatedIntMessage.class, readConsumerMock, conf); @@ -318,7 +365,7 @@ public void testProto3RepeatedIntMessage() throws Exception { @Test public void testProto3RepeatedIntMessageEmptySpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.RepeatedIntMessage.class, readConsumerMock, conf); @@ -354,7 +401,7 @@ public void testProto3RepeatedIntMessageEmpty() throws Exception { @Test public void testMapIntMessageSpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock, conf); @@ -438,7 +485,7 @@ public void testMapIntMessage() throws Exception { @Test public void testMapIntMessageEmptySpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.MapIntMessage.class, readConsumerMock, conf); @@ -472,7 +519,7 @@ public void testMapIntMessageEmpty() throws Exception { @Test public void testProto3MapIntMessageSpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock, conf); @@ -556,7 +603,7 @@ public void testProto3MapIntMessage() throws Exception { @Test public void testProto3MapIntMessageEmptySpecsCompliant() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.MapIntMessage.class, readConsumerMock, conf); @@ -620,7 +667,7 @@ public void testRepeatedInnerMessageMessage_message() throws Exception { @Test public void testRepeatedInnerMessageSpecsCompliantMessage_message() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock, conf); @@ -694,7 +741,7 @@ public void testProto3RepeatedInnerMessageMessage_message() throws Exception { @Test public void testProto3RepeatedInnerMessageSpecsCompliantMessage_message() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.TopMessage.class, readConsumerMock, conf); @@ -737,7 +784,7 @@ public void testProto3RepeatedInnerMessageSpecsCompliantMessage_message() throws @Test public void testRepeatedInnerMessageSpecsCompliantMessage_scalar() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProtobuf.TopMessage.class, readConsumerMock, conf); @@ -871,7 +918,7 @@ public void testProto3RepeatedInnerMessageMessage_scalar() throws Exception { @Test public void testProto3RepeatedInnerMessageSpecsCompliantMessage_scalar() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setWriteSpecsCompliant(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.TopMessage.class, readConsumerMock, conf); @@ -990,7 +1037,7 @@ public void testProto3OptionalInnerMessage() throws Exception { Mockito.verifyNoMoreInteractions(readConsumerMock); } - @Test(expected = UnsupportedOperationException.class) + @Test public void testMessageWithExtensions() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); ProtoWriteSupport instance = @@ -1002,7 +1049,16 @@ public void testMessageWithExtensions() throws Exception { // will cause an exception. msg.setExtension(TestProtobuf.Airplane.wingSpan, 50); - instance.write(msg.build()); + ProtoWriteSupport.CodegenMode codegenModeOrDefault = ProtoWriteSupport.CodegenMode.orDefault(codegenMode); + EnumSet failingModes = + EnumSet.of(ProtoWriteSupport.CodegenMode.OFF, ProtoWriteSupport.CodegenMode.SUPPORT_COMPATIBLE); + try { + instance.write(msg.build()); + } catch (UnsupportedOperationException e) { + assertTrue("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); + return; + } + assertFalse("codegenMode: " + codegenMode, failingModes.contains(codegenModeOrDefault)); } @Test @@ -1066,7 +1122,7 @@ public void testMessageOneOfRoundTrip() throws IOException { @Test public void testMessageRecursion() { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoSchemaConverter.setMaxRecursion(conf, 1); ProtoWriteSupport spyWriter = createReadConsumerInstance(Trees.BinaryTree.class, readConsumerMock, conf); @@ -1120,7 +1176,7 @@ public void testMessageRecursion() { @Test public void testRepeatedRecursion() { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoSchemaConverter.setMaxRecursion(conf, 1); ProtoWriteSupport spyWriter = createReadConsumerInstance(Trees.WideTree.class, readConsumerMock, conf); @@ -1175,7 +1231,7 @@ public void testRepeatedRecursion() { @Test public void testMapRecursion() { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoSchemaConverter.setMaxRecursion(conf, 1); ProtoWriteSupport spyWriter = createReadConsumerInstance(Value.class, readConsumerMock, conf); @@ -1238,7 +1294,7 @@ public void testProto3DateTimeMessageUnwrapped() throws Exception { LocalTime time = LocalTime.of(15, 4, 3, 748_000_000); RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setUnwrapProtoWrappers(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.DateTimeMessage.class, readConsumerMock, conf); @@ -1315,7 +1371,7 @@ public void testProto3DateTimeMessageRoundTrip() throws Exception { @Test public void testProto3WrappedMessageUnwrapped() throws Exception { RecordConsumer readConsumerMock = Mockito.mock(RecordConsumer.class); - Configuration conf = new Configuration(); + Configuration conf = updateConfiguration(new Configuration()); ProtoWriteSupport.setUnwrapProtoWrappers(conf, true); ProtoWriteSupport instance = createReadConsumerInstance(TestProto3.WrappedMessage.class, readConsumerMock, conf); From 1d780f6af1eaa44c3b308ebcbb80cf932c3ed55c Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Wed, 22 Jan 2025 22:31:43 +0000 Subject: [PATCH 03/15] fix CICD errors: making some methods public to make code generation compatible with java8 without hacks --- .../parquet/proto/ByteBuddyCodeGen.java | 26 +++++++--- .../parquet/proto/ProtoWriteSupport.java | 52 +++++++++---------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 318723ed86..db19f60b69 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -73,6 +73,7 @@ import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; import net.bytebuddy.dynamic.scaffold.InstrumentedType; import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.MethodCall; import net.bytebuddy.implementation.SuperMethodCall; import net.bytebuddy.implementation.bytecode.ByteCodeAppender; import net.bytebuddy.implementation.bytecode.Removal; @@ -101,7 +102,7 @@ import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; import org.apache.parquet.schema.MessageType; -class ByteBuddyCodeGen { +public class ByteBuddyCodeGen { private static final AtomicLong BYTE_BUDDY_CLASS_SEQUENCE = new AtomicLong(); private static final GenerateMessageClasses GeneratedMessageV3 = @@ -185,6 +186,7 @@ static class CodeGenUtils { static final ResolvedReflection Reflection = new ResolvedReflection(); static class ResolvedReflection { + final Method MethodHandles_lookup = ReflectionUtil.getDeclaredMethod(MethodHandles.class, "lookup"); final RecordConsumerMethods RecordConsumer = new RecordConsumerMethods(); final ByteBuddyMessageWritersMethods ByteBuddyProto3FastMessageWriter = @@ -741,7 +743,7 @@ static Optional> classForName(String className) { } } - static class WriteSupport { + public static class WriteSupport { // in order to avoid class generation for the same proto descriptors, cache implementations. private static final Map.MessageWriter>> WRITERS_CACHE = new MapMaker().weakValues().makeMap(); @@ -1390,6 +1392,7 @@ public ByteBuddyMessageWritersCodeGen( classBuilder = new ByteBuddy() .subclass(ByteBuddyMessageWriters.class) + .modifiers(Visibility.PUBLIC) .name(ByteBuddyMessageWriters.class.getName() + "$Generated$" + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); @@ -1398,6 +1401,7 @@ public ByteBuddyMessageWritersCodeGen( generateConstructor(); generateMethods(); overrideSetFallbackFieldWriters(); + overrideGetLookup(); DynamicType.Unloaded unloaded = classBuilder.make(); @@ -1408,10 +1412,16 @@ public ByteBuddyMessageWritersCodeGen( // } byteBuddyMessageWritersClass = unloaded.load( - null, ClassLoadingStrategy.UsingLookup.of(MethodHandles.lookup())) + this.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) .getLoaded(); } + private void overrideGetLookup() { + classBuilder = classBuilder + .method(ElementMatchers.named("getLookup")) + .intercept(MethodCall.invoke(Reflection.MethodHandles_lookup)); + } + private void registerFallbackFieldWriterFields() { for (CodeGenFieldWriterFallback fieldWriterFallback : fieldWriterFallbacks.values()) { classBuilder = classBuilder.define(fieldWriterFallback.fieldWriter()); @@ -1465,7 +1475,8 @@ private void generateMethods() { private void generateConstructor() { classBuilder = classBuilder .constructor(ElementMatchers.any()) - .intercept(SuperMethodCall.INSTANCE.andThen(new ByteBuddyMessageWritersConstructor())); + .intercept(SuperMethodCall.INSTANCE.andThen(new ByteBuddyMessageWritersConstructor())) + .modifiers(Visibility.PUBLIC); } private void registerEnumFields() { @@ -2645,14 +2656,16 @@ public void accept(ProtoWriteSupport.MessageWriter messageWriter) { // this is subclassed with ByteBuddy, overriding the constructor, setProtoReflectionMessageWriters and adding // new fields and methods - abstract static class ByteBuddyMessageWriters { + public abstract static class ByteBuddyMessageWriters { private final ProtoWriteSupport protoWriteSupport; private final Map fastMessageWriters = new LinkedHashMap<>(); + private final MethodHandles.Lookup lookup; public ByteBuddyMessageWriters( ProtoWriteSupport.MessageWriter rootMessageWriter, ByteBuddyMessageWritersCodeGen.GeneratedElementsInfo generatedElementsInfo) { this.protoWriteSupport = rootMessageWriter.getProtoWriteSupport(); + this.lookup = getLookup(); final ProtoWriteSupport.FieldWriter[] fallbackFieldWriters = new ProtoWriteSupport.FieldWriter[generatedElementsInfo.fallbackFieldWriters.size()]; @@ -2699,7 +2712,6 @@ public Map enumNameNumberPairs(String enumTypeFullName) { public ProtoWriteSupport.MessageFieldsWriter getFastMessageWriter(Method method) { return fastMessageWriters.computeIfAbsent(method, k -> { - MethodHandles.Lookup lookup = MethodHandles.lookup(); Class messageOrBuilderInterface = method.getParameterTypes()[0]; try { Consumer consumer = @@ -2728,6 +2740,8 @@ public boolean writeAllFields(MessageOrBuilder messageOrBuilder) { } }); } + + protected abstract MethodHandles.Lookup getLookup(); } } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index 6a10e47d18..c9efe5e6b1 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -327,7 +327,7 @@ private Map enumMetadata() { return enumMetadata; } - class FieldWriter { + public class FieldWriter { String fieldName; int index = -1; @@ -345,7 +345,7 @@ void setIndex(int index) { /** * Used for writing repeated fields */ - void writeRawValue(Object value) {} + public void writeRawValue(Object value) {} /** * Used for writing nonrepeated (optional, required) fields @@ -538,7 +538,7 @@ void writeTopLevelMessage(Object value) { * Writes message as part of repeated field. It cannot start field */ @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.startGroup(); writeAllFields((MessageOrBuilder) value); recordConsumer.endGroup(); @@ -614,7 +614,7 @@ class ArrayWriter extends FieldWriter { } @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { throw new UnsupportedOperationException("Array has no raw value"); } @@ -657,7 +657,7 @@ class RepeatedWriter extends FieldWriter { } @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { throw new UnsupportedOperationException("Array has no raw value"); } @@ -697,7 +697,7 @@ private void validatedMapping(Descriptor descriptor, GroupType parquetSchema) { class StringWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { Binary binaryString = Binary.fromString((String) value); recordConsumer.addBinary(binaryString); } @@ -705,7 +705,7 @@ final void writeRawValue(Object value) { class IntWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.addInteger((Integer) value); } } @@ -713,7 +713,7 @@ final void writeRawValue(Object value) { class LongWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.addLong((Long) value); } } @@ -730,7 +730,7 @@ public MapWriter(FieldWriter keyWriter, FieldWriter valueWriter) { } @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { Collection collection = (Collection) value; if (collection.isEmpty()) { return; @@ -761,14 +761,14 @@ final void writeRawValue(Object value) { class FloatWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.addFloat((Float) value); } } class DoubleWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.addDouble((Double) value); } } @@ -786,7 +786,7 @@ public EnumWriter(Descriptors.EnumDescriptor enumType) { } @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { Descriptors.EnumValueDescriptor enumValueDesc = (Descriptors.EnumValueDescriptor) value; Binary binary = Binary.fromString(enumValueDesc.getName()); recordConsumer.addBinary(binary); @@ -796,14 +796,14 @@ final void writeRawValue(Object value) { class BooleanWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { recordConsumer.addBoolean((Boolean) value); } } class BinaryWriter extends FieldWriter { @Override - final void writeRawValue(Object value) { + public final void writeRawValue(Object value) { // Non-ByteString values can happen when recursions gets truncated. ByteString byteString = value instanceof ByteString ? (ByteString) value @@ -819,7 +819,7 @@ final void writeRawValue(Object value) { class TimestampWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { Timestamp timestamp = (Timestamp) value; recordConsumer.addLong(Timestamps.toNanos(timestamp)); } @@ -827,7 +827,7 @@ void writeRawValue(Object value) { class DateWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { Date date = (Date) value; LocalDate localDate = LocalDate.of(date.getYear(), date.getMonth(), date.getDay()); recordConsumer.addInteger((int) localDate.toEpochDay()); @@ -836,7 +836,7 @@ void writeRawValue(Object value) { class TimeWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { com.google.type.TimeOfDay timeOfDay = (com.google.type.TimeOfDay) value; LocalTime localTime = LocalTime.of( timeOfDay.getHours(), timeOfDay.getMinutes(), timeOfDay.getSeconds(), timeOfDay.getNanos()); @@ -846,56 +846,56 @@ void writeRawValue(Object value) { class DoubleValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addDouble(((DoubleValue) value).getValue()); } } class FloatValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addFloat(((FloatValue) value).getValue()); } } class Int64ValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addLong(((Int64Value) value).getValue()); } } class UInt64ValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addLong(((UInt64Value) value).getValue()); } } class Int32ValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addInteger(((Int32Value) value).getValue()); } } class UInt32ValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addLong(((UInt32Value) value).getValue()); } } class BoolValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { recordConsumer.addBoolean(((BoolValue) value).getValue()); } } class StringValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { Binary binaryString = Binary.fromString(((StringValue) value).getValue()); recordConsumer.addBinary(binaryString); } @@ -903,7 +903,7 @@ void writeRawValue(Object value) { class BytesValueWriter extends FieldWriter { @Override - void writeRawValue(Object value) { + public void writeRawValue(Object value) { byte[] byteArray = ((BytesValue) value).getValue().toByteArray(); Binary binary = Binary.fromConstantByteArray(byteArray); recordConsumer.addBinary(binary); From 5b3b74749275f487c93a3238f2b16d4fb6fa1ebc Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Thu, 23 Jan 2025 02:34:36 +0000 Subject: [PATCH 04/15] benchmark ProtoWriteSupport with ByteBuddy --- parquet-benchmarks/pom.xml | 29 +++ .../benchmarks/ProtoDataGenerator.java | 218 ++++++++++++++++++ .../benchmarks/ProtoWriteBenchmarks.java | 50 ++++ .../parquet/benchmarks/WriteBenchmarks.java | 3 +- .../src/main/protobuf/messages.proto | 216 +++++++++++++++++ 5 files changed, 515 insertions(+), 1 deletion(-) create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoWriteBenchmarks.java create mode 100644 parquet-benchmarks/src/main/protobuf/messages.proto diff --git a/parquet-benchmarks/pom.xml b/parquet-benchmarks/pom.xml index 77df2c101d..7246e1a754 100644 --- a/parquet-benchmarks/pom.xml +++ b/parquet-benchmarks/pom.xml @@ -34,6 +34,8 @@ 1.37 parquet-benchmarks + 3.25.5 + 1.14.18 @@ -52,6 +54,11 @@ parquet-common ${project.version} + + org.apache.parquet + parquet-protobuf + ${project.version} + org.apache.hadoop hadoop-client @@ -82,10 +89,32 @@ slf4j-api ${slf4j.version} + + net.bytebuddy + byte-buddy + ${byte-buddy.version} + + + com.github.os72 + protoc-jar-maven-plugin + 3.11.4 + + + generate-sources + generate-sources + + run + + + com.google.protobuf:protoc:${protobuf.version} + + + + org.apache.maven.plugins maven-compiler-plugin diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java new file mode 100644 index 0000000000..23eed270df --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.benchmarks; + +import static java.util.UUID.randomUUID; +import static org.apache.parquet.benchmarks.BenchmarkConstants.DICT_PAGE_SIZE; +import static org.apache.parquet.benchmarks.BenchmarkUtils.exists; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.function.IntFunction; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.benchmarks.Messages.Test1; +import org.apache.parquet.benchmarks.Messages.Test100Int32; +import org.apache.parquet.benchmarks.Messages.Test30Int32; +import org.apache.parquet.benchmarks.Messages.Test30String; +import org.apache.parquet.column.ParquetProperties; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.metadata.CompressionCodecName; +import org.apache.parquet.proto.ProtoParquetWriter; +import org.apache.parquet.proto.ProtoWriteSupport; + +public class ProtoDataGenerator extends DataGenerator { + + private final Class protoClass; + private final ProtoWriteSupport.CodegenMode codegenMode; + private final RecordGeneratorFactory recordGeneratorFactory; + + public ProtoDataGenerator(Class protoClass, ProtoWriteSupport.CodegenMode codegenMode) { + this.protoClass = protoClass; + this.codegenMode = codegenMode; + this.recordGeneratorFactory = (RecordGeneratorFactory) GENERATORS.get(protoClass); + } + + private interface RecordGeneratorFactory { + RecordGenerator newRecordGenerator(int fixedLenByteArraySize); + } + + private interface RecordGenerator extends IntFunction {} + + private static final RecordGeneratorFactory TEST1 = fixedLenByteArraySize -> { + final Test1.Builder builder = Test1.newBuilder(); + String fixedLenStr = generateFixedLenStr(fixedLenByteArraySize); + + return i -> builder.setBinaryField(ByteString.copyFromUtf8(randomUUID().toString())) + .setInt32Field(i) + .setInt64Field(64L) + .setBooleanField(true) + .setFloatField(1.0f) + .setDoubleField(2.0d) + .setStringField(fixedLenStr); + }; + + private static final RecordGeneratorFactory TEST_30_INT32 = fixedLenByteArraySize -> { + final Test30Int32.Builder builder = Test30Int32.newBuilder(); + + return i -> builder + .setField1(i) + .setField2(i) + .setField3(i) + .setField4(i) + .setField5(i) + .setField6(i) + .setField7(i) + .setField8(i) + .setField9(i) + .setField10(i) + .setField11(i) + .setField12(i) + .setField13(i) + .setField14(i) + .setField15(i) + .setField16(i) + .setField17(i) + .setField18(i) + .setField19(i) + .setField20(i) + .setField21(i) + .setField22(i) + .setField23(i) + .setField24(i) + .setField25(i) + .setField26(i) + .setField27(i) + .setField28(i) + .setField29(i) + .setField30(i); + }; + + private static final RecordGeneratorFactory TEST_100_INT32 = fixedLenByteArraySize -> { + final Test100Int32.Builder builder = Test100Int32.newBuilder(); + + return i -> builder + .setF1(i).setF2(i).setF3(i).setF4(i).setF5(i).setF6(i).setF7(i).setF8(i).setF9(i).setF10(i) + .setF11(i).setF12(i).setF13(i).setF14(i).setF15(i).setF16(i).setF17(i).setF18(i).setF19(i).setF20(i) + .setF21(i).setF22(i).setF23(i).setF24(i).setF25(i).setF26(i).setF27(i).setF28(i).setF29(i).setF30(i) + .setF31(i).setF32(i).setF33(i).setF34(i).setF35(i).setF36(i).setF37(i).setF38(i).setF39(i).setF40(i) + .setF41(i).setF42(i).setF43(i).setF44(i).setF45(i).setF46(i).setF47(i).setF48(i).setF49(i).setF50(i) + .setF51(i).setF52(i).setF53(i).setF54(i).setF55(i).setF56(i).setF57(i).setF58(i).setF59(i).setF60(i) + .setF61(i).setF62(i).setF63(i).setF64(i).setF65(i).setF66(i).setF67(i).setF68(i).setF69(i).setF70(i) + .setF71(i).setF72(i).setF73(i).setF74(i).setF75(i).setF76(i).setF77(i).setF78(i).setF79(i).setF80(i) + .setF81(i).setF82(i).setF83(i).setF84(i).setF85(i).setF86(i).setF87(i).setF88(i).setF89(i).setF90(i) + .setF91(i).setF92(i).setF93(i).setF94(i).setF95(i).setF96(i).setF97(i).setF98(i).setF99(i).setF100(i); + }; + + private static final RecordGeneratorFactory TEST_30_STRING = fixedLenByteArraySize -> { + final Test30String.Builder builder = Test30String.newBuilder(); + + return i -> builder + .setField1("setField1:" + i) + .setField2("setField2:" + i) + .setField3("setField3:" + i) + .setField4("setField4:" + i) + .setField5("setField5:" + i) + .setField6("setField6:" + i) + .setField7("setField7:" + i) + .setField8("setField8:" + i) + .setField9("setField9:" + i) + .setField10("setField10:" + i) + .setField11("setField11:" + i) + .setField12("setField12:" + i) + .setField13("setField13:" + i) + .setField14("setField14:" + i) + .setField15("setField15:" + i) + .setField16("setField16:" + i) + .setField17("setField17:" + i) + .setField18("setField18:" + i) + .setField19("setField19:" + i) + .setField20("setField20:" + i) + .setField21("setField21:" + i) + .setField22("setField22:" + i) + .setField23("setField23:" + i) + .setField24("setField24:" + i) + .setField25("setField25:" + i) + .setField26("setField26:" + i) + .setField27("setField27:" + i) + .setField28("setField28:" + i) + .setField29("setField29:" + i) + .setField30("setField30:" + i); + }; + + private static String generateFixedLenStr(int fixedLenByteArraySize) { + // generate some data for the fixed len byte array field + char[] chars = new char[fixedLenByteArraySize]; + Arrays.fill(chars, '*'); + return String.copyValueOf(chars); + } + + private static final Map, RecordGeneratorFactory> GENERATORS = new HashMap() { + { + put(Test1.class, TEST1); + put(Test30Int32.class, TEST_30_INT32); + put(Test30String.class, TEST_30_STRING); + put(Test100Int32.class, TEST_100_INT32); + } + }; + + public void generateData( + Path outFile, + Configuration configuration, + ParquetProperties.WriterVersion version, + int blockSize, + int pageSize, + int fixedLenByteArraySize, + CompressionCodecName codec, + int nRows) + throws IOException { + if (exists(configuration, outFile)) { + System.out.println("File already exists " + outFile); + return; + } + + System.out.println("Generating data @ " + outFile + " with codegenMode " + codegenMode); + + ProtoWriteSupport.setCodegenMode(configuration, codegenMode); + ProtoWriteSupport.setSchema(configuration, protoClass); + + ParquetWriter writer = ProtoParquetWriter.builder(outFile) + .withMessage(protoClass) + .withConf(configuration) + .withCompressionCodec(codec) + .withRowGroupSize((long) blockSize) + .withPageSize(pageSize) + .enableDictionaryEncoding() + .withDictionaryPageSize(DICT_PAGE_SIZE) + .withValidation(false) + .withWriterVersion(version) + .build(); + + RecordGenerator recordGenerator = recordGeneratorFactory.newRecordGenerator(fixedLenByteArraySize); + + for (int i = 0; i < nRows; i++) { + writer.write(recordGenerator.apply(i)); + } + writer.close(); + } +} diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoWriteBenchmarks.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoWriteBenchmarks.java new file mode 100644 index 0000000000..0e29c11cff --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoWriteBenchmarks.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.benchmarks; + +import static org.openjdk.jmh.annotations.Scope.Thread; + +import com.google.protobuf.Message; +import org.apache.parquet.proto.ProtoWriteSupport; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; + +@State(Thread) +public class ProtoWriteBenchmarks extends WriteBenchmarks { + @Param({"OFF", "REQUIRED_ALL"}) + public ProtoWriteSupport.CodegenMode codegenMode; + + @Param({"Test30Int32", "Test100Int32", "Test30String", "Test1"}) + public String protoClass; + + @Setup(Level.Iteration) + public void setup() { + Class messageClass; + try { + messageClass = (Class) Class.forName("org.apache.parquet.benchmarks.Messages$" + protoClass); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + dataGenerator = new ProtoDataGenerator<>(messageClass, codegenMode); + // clean existing test data at the beginning of each iteration + dataGenerator.cleanup(); + } +} diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/WriteBenchmarks.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/WriteBenchmarks.java index 41f961de44..ff53ff45d2 100644 --- a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/WriteBenchmarks.java +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/WriteBenchmarks.java @@ -50,10 +50,11 @@ @State(Thread) public class WriteBenchmarks { - private DataGenerator dataGenerator = new DataGenerator(); + protected DataGenerator dataGenerator; @Setup(Level.Iteration) public void setup() { + dataGenerator = new DataGenerator(); // clean existing test data at the beginning of each iteration dataGenerator.cleanup(); } diff --git a/parquet-benchmarks/src/main/protobuf/messages.proto b/parquet-benchmarks/src/main/protobuf/messages.proto new file mode 100644 index 0000000000..fb9ad19bd4 --- /dev/null +++ b/parquet-benchmarks/src/main/protobuf/messages.proto @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +syntax = "proto3"; + +package Benchmarks; + +option java_package = "org.apache.parquet.benchmarks"; + +// more or less mimics the data structure defined in the original DataGenerator +message Test1 { + bytes binary_field = 1; + int32 int32_field = 2; + int64 int64_field = 3; + bool boolean_field = 4; + float float_field = 5; + double double_field = 6; + string string_field = 7; +} + +message Test30Int32 { + int32 field1 = 1; + int32 field2 = 2; + int32 field3 = 3; + int32 field4 = 4; + int32 field5 = 5; + int32 field6 = 6; + int32 field7 = 7; + int32 field8 = 8; + int32 field9 = 9; + int32 field10 = 10; + + int32 field11 = 11; + int32 field12 = 12; + int32 field13 = 13; + int32 field14 = 14; + int32 field15 = 15; + int32 field16 = 16; + int32 field17 = 17; + int32 field18 = 18; + int32 field19 = 19; + int32 field20 = 20; + + int32 field21 = 21; + int32 field22 = 22; + int32 field23 = 23; + int32 field24 = 24; + int32 field25 = 25; + int32 field26 = 26; + int32 field27 = 27; + int32 field28 = 28; + int32 field29 = 29; + int32 field30 = 30; +} + +message Test30String { + string field1 = 1; + string field2 = 2; + string field3 = 3; + string field4 = 4; + string field5 = 5; + string field6 = 6; + string field7 = 7; + string field8 = 8; + string field9 = 9; + string field10 = 10; + + string field11 = 11; + string field12 = 12; + string field13 = 13; + string field14 = 14; + string field15 = 15; + string field16 = 16; + string field17 = 17; + string field18 = 18; + string field19 = 19; + string field20 = 20; + + string field21 = 21; + string field22 = 22; + string field23 = 23; + string field24 = 24; + string field25 = 25; + string field26 = 26; + string field27 = 27; + string field28 = 28; + string field29 = 29; + string field30 = 30; +} + +message Test100Int32 { + int32 f1 = 1; + int32 f2 = 2; + int32 f3 = 3; + int32 f4 = 4; + int32 f5 = 5; + int32 f6 = 6; + int32 f7 = 7; + int32 f8 = 8; + int32 f9 = 9; + int32 f10 = 10; + + int32 f11 = 11; + int32 f12 = 12; + int32 f13 = 13; + int32 f14 = 14; + int32 f15 = 15; + int32 f16 = 16; + int32 f17 = 17; + int32 f18 = 18; + int32 f19 = 19; + int32 f20 = 20; + + int32 f21 = 21; + int32 f22 = 22; + int32 f23 = 23; + int32 f24 = 24; + int32 f25 = 25; + int32 f26 = 26; + int32 f27 = 27; + int32 f28 = 28; + int32 f29 = 29; + int32 f30 = 30; + + int32 f31 = 31; + int32 f32 = 32; + int32 f33 = 33; + int32 f34 = 34; + int32 f35 = 35; + int32 f36 = 36; + int32 f37 = 37; + int32 f38 = 38; + int32 f39 = 39; + int32 f40 = 40; + + int32 f41 = 41; + int32 f42 = 42; + int32 f43 = 43; + int32 f44 = 44; + int32 f45 = 45; + int32 f46 = 46; + int32 f47 = 47; + int32 f48 = 48; + int32 f49 = 49; + int32 f50 = 50; + + int32 f51 = 51; + int32 f52 = 52; + int32 f53 = 53; + int32 f54 = 54; + int32 f55 = 55; + int32 f56 = 56; + int32 f57 = 57; + int32 f58 = 58; + int32 f59 = 59; + int32 f60 = 60; + + int32 f61 = 61; + int32 f62 = 62; + int32 f63 = 63; + int32 f64 = 64; + int32 f65 = 65; + int32 f66 = 66; + int32 f67 = 67; + int32 f68 = 68; + int32 f69 = 69; + int32 f70 = 70; + + int32 f71 = 71; + int32 f72 = 72; + int32 f73 = 73; + int32 f74 = 74; + int32 f75 = 75; + int32 f76 = 76; + int32 f77 = 77; + int32 f78 = 78; + int32 f79 = 79; + int32 f80 = 80; + + int32 f81 = 81; + int32 f82 = 82; + int32 f83 = 83; + int32 f84 = 84; + int32 f85 = 85; + int32 f86 = 86; + int32 f87 = 87; + int32 f88 = 88; + int32 f89 = 89; + int32 f90 = 90; + + int32 f91 = 91; + int32 f92 = 92; + int32 f93 = 93; + int32 f94 = 94; + int32 f95 = 95; + int32 f96 = 96; + int32 f97 = 97; + int32 f98 = 98; + int32 f99 = 99; + int32 f100 = 100; +} From 4db011b1ac293633cf607235e9575ad29ffb0474 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Sun, 13 Apr 2025 20:23:00 +0100 Subject: [PATCH 05/15] adding test for map where v is Enum (proto2 and proto3) --- .../proto/ProtoInputOutputFormatTest.java | 3 ++ .../proto/ProtoRecordConverterTest.java | 9 ++++++ .../proto/ProtoSchemaConverterTest.java | 19 ++++++++++++ .../src/test/resources/EnumProto3.proto | 29 +++++++++++++++++++ .../src/test/resources/TestProto3.proto | 1 + .../src/test/resources/TestProtobuf.proto | 5 ++++ 6 files changed, 66 insertions(+) create mode 100644 parquet-protobuf/src/test/resources/EnumProto3.proto diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java index f3acd00adf..35e753d8bf 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoInputOutputFormatTest.java @@ -532,6 +532,7 @@ public void testProto3AllTypes() throws Exception { data.setOptionalUInt32(1000 * 1000 * 8); data.setOptionalUInt64(1000L * 1000 * 1000 * 9); data.getOptionalMessageBuilder().setSomeId(1984); + data.putOptionalMapEnum(1000L, TestProto3.SchemaConverterAllDatatypes.TestEnum.SECOND); TestProto3.SchemaConverterAllDatatypes dataBuilt = data.build(); @@ -571,6 +572,8 @@ public void testProto3AllTypes() throws Exception { assertEquals(1000 * 1000 * 8, o.getOptionalUInt32()); assertEquals(1000L * 1000 * 1000 * 9, o.getOptionalUInt64()); assertEquals(1984, o.getOptionalMessage().getSomeId()); + assertEquals(1, o.getOptionalMapEnumCount()); + assertEquals(TestProto3.SchemaConverterAllDatatypes.TestEnum.SECOND, o.getOptionalMapEnumOrThrow(1000)); } @Test diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoRecordConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoRecordConverterTest.java index 65b91da688..4ad4102c1b 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoRecordConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoRecordConverterTest.java @@ -27,6 +27,7 @@ import com.google.protobuf.ByteString; import java.util.List; +import org.apache.parquet.proto.test.EnumProto3OuterClass; import org.apache.parquet.proto.test.TestProto3; import org.apache.parquet.proto.test.TestProtobuf; import org.junit.Test; @@ -56,6 +57,9 @@ public void testAllTypes() throws Exception { data.setOptionalUInt64(1000L * 1000 * 1000 * 9); data.getOptionalMessageBuilder().setSomeId(1984); data.getPbGroupBuilder().setGroupInt(1492); + data.setOptionalEnumProto3(EnumProto3OuterClass.EnumProto3.SECOND); + data.putOptionalMapEnumProto2(1000, SchemaConverterAllDatatypes.TestEnum.SECOND); + data.putOptionalMapEnumProto3(1000, EnumProto3OuterClass.EnumProto3.SECOND); SchemaConverterAllDatatypes dataBuilt = data.build(); data.clear(); @@ -84,6 +88,11 @@ public void testAllTypes() throws Exception { assertEquals(1000L * 1000 * 1000 * 9, o.getOptionalUInt64()); assertEquals(1984, o.getOptionalMessage().getSomeId()); assertEquals(1492, o.getPbGroup().getGroupInt()); + assertEquals(EnumProto3OuterClass.EnumProto3.SECOND, o.getOptionalEnumProto3()); + assertEquals(1, o.getOptionalMapEnumProto2Count()); + assertEquals(SchemaConverterAllDatatypes.TestEnum.SECOND, o.getOptionalMapEnumProto2OrThrow(1000)); + assertEquals(1, o.getOptionalMapEnumProto3Count()); + assertEquals(EnumProto3OuterClass.EnumProto3.SECOND, o.getOptionalMapEnumProto3OrThrow(1000)); } @Test diff --git a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java index 5240be5a36..1efdac1727 100644 --- a/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java +++ b/parquet-protobuf/src/test/java/org/apache/parquet/proto/ProtoSchemaConverterTest.java @@ -94,6 +94,19 @@ public void testConvertAllDatatypes() { " optional int32 groupInt = 2;", " }", " optional binary optionalEnum (ENUM) = 18;", + " optional binary optionalEnumProto3 (ENUM) = 19;", + " optional group optionalMapEnumProto2 (MAP) = 20 {", + " repeated group key_value {", + " required int64 key;", + " optional binary value (ENUM);", + " }", + " }", + " optional group optionalMapEnumProto3 (MAP) = 21 {", + " repeated group key_value {", + " required int64 key;", + " optional binary value (ENUM);", + " }", + " }", "}"); testConversion(TestProtobuf.SchemaConverterAllDatatypes.class, expectedSchema); @@ -135,6 +148,12 @@ public void testProto3ConvertAllDatatypes() { " }", " }", " }", + " optional group optionalMapEnum (MAP) = 22 {", + " repeated group key_value {", + " required int64 key;", + " optional binary value (ENUM);", + " }", + " }", "}"); testConversion(TestProto3.SchemaConverterAllDatatypes.class, expectedSchema); diff --git a/parquet-protobuf/src/test/resources/EnumProto3.proto b/parquet-protobuf/src/test/resources/EnumProto3.proto new file mode 100644 index 0000000000..2946503b8f --- /dev/null +++ b/parquet-protobuf/src/test/resources/EnumProto3.proto @@ -0,0 +1,29 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +syntax = "proto3"; + +package EnumProto3; + +option java_package = "org.apache.parquet.proto.test"; + +enum EnumProto3 { + FIRST = 0; + SECOND = 1; +} diff --git a/parquet-protobuf/src/test/resources/TestProto3.proto b/parquet-protobuf/src/test/resources/TestProto3.proto index c303fd1f5d..cde48dd1c3 100644 --- a/parquet-protobuf/src/test/resources/TestProto3.proto +++ b/parquet-protobuf/src/test/resources/TestProto3.proto @@ -86,6 +86,7 @@ message SchemaConverterAllDatatypes { string someString = 20; } map optionalMap = 21; + map optionalMapEnum = 22; } message SchemaConverterRepetition { diff --git a/parquet-protobuf/src/test/resources/TestProtobuf.proto b/parquet-protobuf/src/test/resources/TestProtobuf.proto index fe0cbe8327..c634432b8e 100644 --- a/parquet-protobuf/src/test/resources/TestProtobuf.proto +++ b/parquet-protobuf/src/test/resources/TestProtobuf.proto @@ -21,6 +21,8 @@ syntax = "proto2"; package TestProtobuf; +import "EnumProto3.proto"; + option java_package = "org.apache.parquet.proto.test"; // original Dremel paper structures: Original paper used groups, not internal @@ -77,6 +79,9 @@ message Links { SECOND = 1; } optional TestEnum optionalEnum = 18; + optional EnumProto3.EnumProto3 optionalEnumProto3 = 19; + map optionalMapEnumProto2 = 20; + map optionalMapEnumProto3 = 21; } message SchemaConverterRepetition { From 9701b007f803f9c4752f7e7a0155d9f1f8ca1541 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Fri, 25 Apr 2025 22:36:44 +0100 Subject: [PATCH 06/15] begin read support --- parquet-protobuf/pom.xml | 1 + .../parquet/proto/ByteBuddyCodeGen.java | 15 +++++ .../parquet/proto/ProtoReadSupport.java | 60 ++++++++++++++++++- .../parquet/proto/ProtoWriteSupport.java | 2 +- 4 files changed, 76 insertions(+), 2 deletions(-) diff --git a/parquet-protobuf/pom.xml b/parquet-protobuf/pom.xml index e547cf6946..509925cebe 100644 --- a/parquet-protobuf/pom.xml +++ b/parquet-protobuf/pom.xml @@ -35,6 +35,7 @@ 2.50.0 1.4.3 1.14.18 + diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index db19f60b69..0b26d544fa 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -64,6 +64,7 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.stream.Stream; +import com.twitter.elephantbird.util.Protobufs; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; import net.bytebuddy.description.method.MethodDescription; @@ -95,8 +96,10 @@ import net.bytebuddy.jar.asm.Opcodes; import net.bytebuddy.matcher.ElementMatchers; import net.bytebuddy.utility.JavaConstant; +import org.apache.parquet.conf.ParquetConfiguration; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.RecordConsumer; +import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Codegen; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Implementations; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; @@ -2744,4 +2747,16 @@ public boolean writeAllFields(MessageOrBuilder messageOrBuilder) { protected abstract MethodHandles.Lookup getLookup(); } } + + public static class ReadSupport { + + static RecordMaterializer tryEnhanceRecordMaterializer( + org.apache.parquet.proto.ProtoRecordMaterializer protoRecordMaterializer, + ProtoReadSupport.CodegenMode codegenMode, + ParquetConfiguration configuration) { + + return protoRecordMaterializer; + } + + } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java index 484c4932f1..7ac40cd9a8 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java @@ -39,6 +39,59 @@ public class ProtoReadSupport extends ReadSupport { public static final String PB_CLASS = "parquet.proto.class"; public static final String PB_DESCRIPTOR = "parquet.proto.descriptor"; + public static final String PB_CODEGEN = "parquet.proto.readCodegen"; + + public enum CodegenMode { + OFF { + @Override + public boolean ignoreCodeGenException() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tryCodeGen(Class protoClass) { + return false; + } + }, + + SUPPORT { + @Override + public boolean ignoreCodeGenException() { + return true; + } + + @Override + public boolean tryCodeGen(Class protoClass) { + return ByteBuddyCodeGen.isGeneratedMessage(protoClass) && ByteBuddyCodeGen.isByteBuddyAvailable(false); + } + }, + + REQUIRED { + @Override + public boolean ignoreCodeGenException() { + return false; + } + + @Override + public boolean tryCodeGen(Class protoClass) { + if (!ByteBuddyCodeGen.isGeneratedMessage(protoClass)) { + throw new UnsupportedOperationException("protoClass is not a GeneratedMessage: " + protoClass); + } + return ByteBuddyCodeGen.isByteBuddyAvailable(true); + } + }; + + public static final ProtoReadSupport.CodegenMode DEFAULT = CodegenMode.SUPPORT; + + public static ProtoReadSupport.CodegenMode orDefault(ProtoReadSupport.CodegenMode codegenMode) { + return codegenMode == null ? DEFAULT : codegenMode; + } + + public abstract boolean ignoreCodeGenException(); + + public abstract boolean tryCodeGen(Class protoClass); + } + public static void setRequestedProjection(Configuration configuration, String requestedProjection) { configuration.set(PB_REQUESTED_PROJECTION, requestedProjection); @@ -104,6 +157,11 @@ public RecordMaterializer prepareForRead( MessageType requestedSchema = readContext.getRequestedSchema(); Class protobufClass = Protobufs.getProtobufClass(headerProtoClass); - return new ProtoRecordMaterializer(configuration, requestedSchema, protobufClass, keyValueMetaData); + ProtoRecordMaterializer protoRecordMaterializer = new ProtoRecordMaterializer(configuration, requestedSchema, protobufClass, keyValueMetaData); + + CodegenMode codegenMode = ProtoReadSupport.CodegenMode.valueOf(configuration.get(PB_CODEGEN, CodegenMode.DEFAULT.name())); + return codegenMode.tryCodeGen(protobufClass) + ? ByteBuddyCodeGen.ReadSupport.tryEnhanceRecordMaterializer(protoRecordMaterializer, codegenMode, configuration) + : protoRecordMaterializer; } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java index c9efe5e6b1..e5b2f59f30 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoWriteSupport.java @@ -83,7 +83,7 @@ public class ProtoWriteSupport extends WriteSupport< public static final String PB_UNWRAP_PROTO_WRAPPERS = "parquet.proto.unwrapProtoWrappers"; - public static final String PB_CODEGEN = "parquet.proto.codegen"; + public static final String PB_CODEGEN = "parquet.proto.writeCodegen"; private boolean writeSpecsCompliant = false; private boolean unwrapProtoWrappers = false; From 2313c92b6227049ee43f0d053dfeb3c46a61ec2b Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Fri, 25 Apr 2025 22:57:09 +0100 Subject: [PATCH 07/15] ParentValueContainer and ProtoGroupConverter for introspection and patching --- .../parquet/proto/ProtoMessageConverter.java | 127 ++++++++++++++---- .../parquet/proto/ProtoReadSupport.java | 5 + .../parquet/proto/ProtoRecordConverter.java | 2 +- 3 files changed, 104 insertions(+), 30 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index d446598f06..73696dc268 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -61,6 +61,7 @@ import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; +import org.apache.parquet.proto.ProtoReadSupport.ProtoGroupConverter; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -73,13 +74,10 @@ * Converts Protocol Buffer message (both top level and inner) to parquet. * This is internal class, use {@link ProtoRecordConverter}. */ -class ProtoMessageConverter extends GroupConverter { +class ProtoMessageConverter extends ProtoGroupConverter { private static final Logger LOG = LoggerFactory.getLogger(ProtoMessageConverter.class); - private static final ParentValueContainer DUMMY_PVC = new ParentValueContainer() { - @Override - public void add(Object value) {} - }; + private static final ParentValueContainer DUMMY_PVC = new DummyParentValueContainer(); protected final ParquetConfiguration conf; protected final Converter[] converters; @@ -240,19 +238,9 @@ protected Converter newMessageConverter( ParentValueContainer parent; if (isRepeated) { - parent = new ParentValueContainer() { - @Override - public void add(Object value) { - parentBuilder.addRepeatedField(fieldDescriptor, value); - } - }; + parent = new AddRepeatedFieldParentValueContainer(parentBuilder, fieldDescriptor); } else { - parent = new ParentValueContainer() { - @Override - public void add(Object value) { - parentBuilder.setField(fieldDescriptor, value); - } - }; + parent = new SetFieldParentValueContainer(parentBuilder, fieldDescriptor); } LogicalTypeAnnotation logicalTypeAnnotation = parquetType.getLogicalTypeAnnotation(); @@ -361,12 +349,75 @@ public Message.Builder getBuilder() { return myBuilder; } + @Override + int getFieldCount() { + return converters.length; + } + abstract static class ParentValueContainer { /** * Adds the value to the parent. */ public abstract void add(Object value); + + public void addInt(int value) { + add(value); + } + + public void addLong(long value) { + add(value); + } + + public void addDouble(double value) { + add(value); + } + + public void addFloat(float value) { + add(value); + } + + public void addBoolean(boolean value) { + add(value); + } + } + + static class DummyParentValueContainer extends ParentValueContainer { + @Override + public void add(Object value) { + } + } + + static class SetFieldParentValueContainer extends ParentValueContainer { + + private final Message.Builder parent; + private final Descriptors.FieldDescriptor fieldDescriptor; + + public SetFieldParentValueContainer(Message.Builder parent, Descriptors.FieldDescriptor fieldDescriptor) { + this.parent = parent; + this.fieldDescriptor = fieldDescriptor; + } + + @Override + public void add(Object value) { + parent.setField(fieldDescriptor, value); + } + } + + static class AddRepeatedFieldParentValueContainer extends ParentValueContainer { + + private final Message.Builder parent; + private final Descriptors.FieldDescriptor fieldDescriptor; + + public AddRepeatedFieldParentValueContainer(Message.Builder parent, Descriptors.FieldDescriptor fieldDescriptor) { + this.parent = parent; + this.fieldDescriptor = fieldDescriptor; + } + + @Override + public void add(Object value) { + parent.addRepeatedField(fieldDescriptor, value); + } } final class ProtoEnumConverter extends PrimitiveConverter { @@ -835,8 +886,9 @@ public void addBinary(Binary binary) { * a repeated group named 'list', itself containing only one field called 'element' of the type of the repeated * object (can be a primitive as in this example or a group in case of a repeated message in protobuf). */ - final class ListConverter extends GroupConverter { + final class ListConverter extends ProtoGroupConverter { private final Converter converter; + private final Converter wrapperConverter; public ListConverter( Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { @@ -862,36 +914,48 @@ public ListConverter( Type elementType = listType.getType("element"); converter = newMessageConverter(parentBuilder, fieldDescriptor, elementType); - } - - @Override - public Converter getConverter(int fieldIndex) { - if (fieldIndex > 0) { - throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); - } + wrapperConverter = new ProtoGroupConverter() { + @Override + int getFieldCount() { + return 1; + } - return new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { return converter; } @Override - public void start() {} + public void start() { + } @Override - public void end() {} + public void end() { + } }; } + @Override + public Converter getConverter(int fieldIndex) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); + } + return wrapperConverter; + } + @Override public void start() {} @Override public void end() {} + + @Override + int getFieldCount() { + return 1; + } } - final class MapConverter extends GroupConverter { + final class MapConverter extends ProtoGroupConverter { private final Converter converter; public MapConverter( @@ -925,5 +989,10 @@ public void start() {} @Override public void end() {} + + @Override + int getFieldCount() { + return 1; + } } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java index 7ac40cd9a8..5d177e2526 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java @@ -26,6 +26,7 @@ import org.apache.parquet.conf.ParquetConfiguration; import org.apache.parquet.hadoop.api.InitContext; import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.schema.MessageType; import org.slf4j.Logger; @@ -41,6 +42,10 @@ public class ProtoReadSupport extends ReadSupport { public static final String PB_DESCRIPTOR = "parquet.proto.descriptor"; public static final String PB_CODEGEN = "parquet.proto.readCodegen"; + abstract static class ProtoGroupConverter extends GroupConverter { + abstract int getFieldCount(); + } + public enum CodegenMode { OFF { @Override diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordConverter.java index d16d085270..428a519e74 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordConverter.java @@ -41,7 +41,7 @@ public class ProtoRecordConverter extends ProtoMessa /** * We dont need to write message value at top level. */ - private static class SkipParentValueContainer extends ParentValueContainer { + static class SkipParentValueContainer extends ParentValueContainer { @Override public void add(Object a) { throw new RuntimeException("Should never happen"); From ee54e895f73dc846e6ed15b0bd4a0f7d314acf95 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Fri, 25 Apr 2025 23:52:02 +0100 Subject: [PATCH 08/15] print converters structure --- .../parquet/proto/ByteBuddyCodeGen.java | 53 ++- .../parquet/proto/ProtoMessageConverter.java | 324 +++++++++++++++--- .../parquet/proto/ProtoReadSupport.java | 5 - .../proto/ProtoRecordMaterializer.java | 9 + 4 files changed, 336 insertions(+), 55 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 0b26d544fa..c84d46e714 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -59,12 +59,12 @@ import java.util.Objects; import java.util.Optional; import java.util.Queue; +import java.util.Stack; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.stream.Stream; -import com.twitter.elephantbird.util.Protobufs; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; import net.bytebuddy.description.method.MethodDescription; @@ -98,11 +98,14 @@ import net.bytebuddy.utility.JavaConstant; import org.apache.parquet.conf.ParquetConfiguration; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.io.api.Converter; +import org.apache.parquet.io.api.PrimitiveConverter; import org.apache.parquet.io.api.RecordConsumer; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Codegen; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Implementations; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; +import org.apache.parquet.proto.ProtoRecordMaterializer.ProtoGroupConverter; import org.apache.parquet.schema.MessageType; public class ByteBuddyCodeGen { @@ -2755,8 +2758,56 @@ static RecordMaterializer tryEnhanceRecordMaterializer( ProtoReadSupport.CodegenMode codegenMode, ParquetConfiguration configuration) { + visitConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); + return protoRecordMaterializer; } + static void visitConverters(Converter converter, Stack parentConverters) { + StringBuilder indentBuilder = new StringBuilder(); + for (int i = 0; i < parentConverters.size(); i++) { + indentBuilder.append(" "); + } + String indent = indentBuilder.toString(); + + String parentValueContainerInfo = ""; + if (converter instanceof ProtoRecordMaterializer.ParentValueContainerHolder) { + ProtoRecordMaterializer.ParentValueContainerHolder holder = + (ProtoRecordMaterializer.ParentValueContainerHolder) converter; + ProtoMessageConverter.ParentValueContainer parentValueContainer = holder.getParentValueContainer(); + if (parentValueContainer instanceof ProtoMessageConverter.SetFieldParentValueContainer) { + ProtoMessageConverter.SetFieldParentValueContainer pvc = (ProtoMessageConverter.SetFieldParentValueContainer) parentValueContainer; + Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); + String containingTypeName = fieldDescriptor.getContainingType().getName(); + String fieldName = fieldDescriptor.getName(); + Message.Builder parent = pvc.getParent(); + parentValueContainerInfo = " : single : " + containingTypeName + "." + fieldName + " : " + (parent != null ? parent.getClass() : "null"); + } else if (parentValueContainer instanceof ProtoMessageConverter.AddRepeatedFieldParentValueContainer) { + ProtoMessageConverter.AddRepeatedFieldParentValueContainer pvc = (ProtoMessageConverter.AddRepeatedFieldParentValueContainer) parentValueContainer; + Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); + String containingTypeName = fieldDescriptor.getContainingType().getName(); + String fieldName = fieldDescriptor.getName(); + Message.Builder parent = pvc.getParent(); + parentValueContainerInfo = " : repeated : " + containingTypeName + "." + fieldName + " : " + (parent != null ? parent.getClass() : "null"); + } + } + + if (converter instanceof ProtoGroupConverter) { + System.out.println(indent + "ProtoGroupConverter: " + converter.getClass() + parentValueContainerInfo); + ProtoGroupConverter groupConverter = + (ProtoGroupConverter) converter; + for (int i = 0; i < groupConverter.getFieldCount(); i++) { + Converter fieldConverter = groupConverter.getConverter(i); + parentConverters.push(groupConverter); + visitConverters(fieldConverter, parentConverters); + parentConverters.pop(); + } + } else if (converter instanceof PrimitiveConverter) { + System.out.println(indent + "PrimitiveConverter: " + converter.getClass() + parentValueContainerInfo); + } else { + System.out.println(indent + "GroupConverter: " + converter.getClass() + parentValueContainerInfo); + } + } + } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 73696dc268..4b52b8b6a3 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -59,9 +59,9 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.Converter; -import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; -import org.apache.parquet.proto.ProtoReadSupport.ProtoGroupConverter; +import org.apache.parquet.proto.ProtoRecordMaterializer.ParentValueContainerHolder; +import org.apache.parquet.proto.ProtoRecordMaterializer.ProtoGroupConverter; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -74,14 +74,14 @@ * Converts Protocol Buffer message (both top level and inner) to parquet. * This is internal class, use {@link ProtoRecordConverter}. */ -class ProtoMessageConverter extends ProtoGroupConverter { +class ProtoMessageConverter extends ProtoGroupConverter implements ParentValueContainerHolder { private static final Logger LOG = LoggerFactory.getLogger(ProtoMessageConverter.class); private static final ParentValueContainer DUMMY_PVC = new DummyParentValueContainer(); protected final ParquetConfiguration conf; protected final Converter[] converters; - protected final ParentValueContainer parent; + protected ParentValueContainer parent; protected final Message.Builder myBuilder; protected final Map extraMetadata; @@ -354,6 +354,16 @@ int getFieldCount() { return converters.length; } + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } + abstract static class ParentValueContainer { /** @@ -402,6 +412,14 @@ public SetFieldParentValueContainer(Message.Builder parent, Descriptors.FieldDes public void add(Object value) { parent.setField(fieldDescriptor, value); } + + public Message.Builder getParent() { + return parent; + } + + public Descriptors.FieldDescriptor getFieldDescriptor() { + return fieldDescriptor; + } } static class AddRepeatedFieldParentValueContainer extends ParentValueContainer { @@ -418,14 +436,22 @@ public AddRepeatedFieldParentValueContainer(Message.Builder parent, Descriptors. public void add(Object value) { parent.addRepeatedField(fieldDescriptor, value); } + + public Message.Builder getParent() { + return parent; + } + + public Descriptors.FieldDescriptor getFieldDescriptor() { + return fieldDescriptor; + } } - final class ProtoEnumConverter extends PrimitiveConverter { + final class ProtoEnumConverter extends PrimitiveConverter implements ParentValueContainerHolder { private final Descriptors.FieldDescriptor fieldType; private final Map enumLookup; private Descriptors.EnumValueDescriptor[] dict; - private final ParentValueContainer parent; + private ParentValueContainer parent; private final Descriptors.EnumDescriptor enumType; private final String unknownEnumPrefix; private final boolean acceptUnknownEnum; @@ -544,11 +570,21 @@ public void setDictionary(Dictionary dictionary) { dict[i] = translateEnumValue(binaryValue); } } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoBinaryConverter extends PrimitiveConverter { + static final class ProtoBinaryConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoBinaryConverter(ParentValueContainer parent) { this.parent = parent; @@ -559,11 +595,21 @@ public void addBinary(Binary binary) { ByteString byteString = ByteString.copyFrom(binary.toByteBuffer()); parent.add(byteString); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoBooleanConverter extends PrimitiveConverter { + static final class ProtoBooleanConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoBooleanConverter(ParentValueContainer parent) { this.parent = parent; @@ -571,13 +617,23 @@ public ProtoBooleanConverter(ParentValueContainer parent) { @Override public void addBoolean(boolean value) { - parent.add(value); + parent.addBoolean(value); + } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; } } - static final class ProtoDoubleConverter extends PrimitiveConverter { + static final class ProtoDoubleConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoDoubleConverter(ParentValueContainer parent) { this.parent = parent; @@ -585,13 +641,23 @@ public ProtoDoubleConverter(ParentValueContainer parent) { @Override public void addDouble(double value) { - parent.add(value); + parent.addDouble(value); + } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; } } - static final class ProtoFloatConverter extends PrimitiveConverter { + static final class ProtoFloatConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoFloatConverter(ParentValueContainer parent) { this.parent = parent; @@ -599,13 +665,23 @@ public ProtoFloatConverter(ParentValueContainer parent) { @Override public void addFloat(float value) { - parent.add(value); + parent.addFloat(value); + } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; } } - static final class ProtoIntConverter extends PrimitiveConverter { + static final class ProtoIntConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoIntConverter(ParentValueContainer parent) { this.parent = parent; @@ -613,13 +689,23 @@ public ProtoIntConverter(ParentValueContainer parent) { @Override public void addInt(int value) { - parent.add(value); + parent.addInt(value); + } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; } } - static final class ProtoLongConverter extends PrimitiveConverter { + static final class ProtoLongConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoLongConverter(ParentValueContainer parent) { this.parent = parent; @@ -627,13 +713,23 @@ public ProtoLongConverter(ParentValueContainer parent) { @Override public void addLong(long value) { - parent.add(value); + parent.addLong(value); + } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; } } - static final class ProtoStringConverter extends PrimitiveConverter { + static final class ProtoStringConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoStringConverter(ParentValueContainer parent) { this.parent = parent; @@ -644,11 +740,21 @@ public void addBinary(Binary binary) { String str = binary.toStringUsingUTF8(); parent.add(str); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoTimestampConverter extends PrimitiveConverter { + static final class ProtoTimestampConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation; public ProtoTimestampConverter( @@ -672,11 +778,21 @@ public void addLong(long value) { break; } } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoDateConverter extends PrimitiveConverter { + static final class ProtoDateConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoDateConverter(ParentValueContainer parent) { this.parent = parent; @@ -692,11 +808,21 @@ public void addInt(int value) { .build(); parent.add(date); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoTimeConverter extends PrimitiveConverter { + static final class ProtoTimeConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; final LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation; public ProtoTimeConverter( @@ -734,11 +860,21 @@ public void addLong(long value) { .build(); parent.add(timeOfDay); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoDoubleValueConverter extends PrimitiveConverter { + static final class ProtoDoubleValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoDoubleValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -748,11 +884,21 @@ public ProtoDoubleValueConverter(ParentValueContainer parent) { public void addDouble(double value) { parent.add(DoubleValue.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoFloatValueConverter extends PrimitiveConverter { + static final class ProtoFloatValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoFloatValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -762,11 +908,21 @@ public ProtoFloatValueConverter(ParentValueContainer parent) { public void addFloat(float value) { parent.add(FloatValue.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoInt64ValueConverter extends PrimitiveConverter { + static final class ProtoInt64ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoInt64ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -776,11 +932,21 @@ public ProtoInt64ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(Int64Value.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoUInt64ValueConverter extends PrimitiveConverter { + static final class ProtoUInt64ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoUInt64ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -790,11 +956,21 @@ public ProtoUInt64ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(UInt64Value.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoInt32ValueConverter extends PrimitiveConverter { + static final class ProtoInt32ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoInt32ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -804,11 +980,21 @@ public ProtoInt32ValueConverter(ParentValueContainer parent) { public void addInt(int value) { parent.add(Int32Value.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoUInt32ValueConverter extends PrimitiveConverter { + static final class ProtoUInt32ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoUInt32ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -818,11 +1004,21 @@ public ProtoUInt32ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(UInt32Value.of(Math.toIntExact(value))); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoBoolValueConverter extends PrimitiveConverter { + static final class ProtoBoolValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoBoolValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -832,11 +1028,21 @@ public ProtoBoolValueConverter(ParentValueContainer parent) { public void addBoolean(boolean value) { parent.add(BoolValue.of(value)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoStringValueConverter extends PrimitiveConverter { + static final class ProtoStringValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoStringValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -847,11 +1053,21 @@ public void addBinary(Binary binary) { String str = binary.toStringUsingUTF8(); parent.add(StringValue.of(str)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } - static final class ProtoBytesValueConverter extends PrimitiveConverter { + static final class ProtoBytesValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { - final ParentValueContainer parent; + ParentValueContainer parent; public ProtoBytesValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -862,6 +1078,16 @@ public void addBinary(Binary binary) { ByteString byteString = ByteString.copyFrom(binary.toByteBuffer()); parent.add(BytesValue.of(byteString)); } + + @Override + public ParentValueContainer getParentValueContainer() { + return parent; + } + + @Override + public void setParentValueContainer(ParentValueContainer parentValueContainer) { + parent = parentValueContainer; + } } /** diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java index 5d177e2526..7ac40cd9a8 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java @@ -26,7 +26,6 @@ import org.apache.parquet.conf.ParquetConfiguration; import org.apache.parquet.hadoop.api.InitContext; import org.apache.parquet.hadoop.api.ReadSupport; -import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.schema.MessageType; import org.slf4j.Logger; @@ -42,10 +41,6 @@ public class ProtoReadSupport extends ReadSupport { public static final String PB_DESCRIPTOR = "parquet.proto.descriptor"; public static final String PB_CODEGEN = "parquet.proto.readCodegen"; - abstract static class ProtoGroupConverter extends GroupConverter { - abstract int getFieldCount(); - } - public enum CodegenMode { OFF { @Override diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java index 12386a4669..704f644fc4 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java @@ -57,4 +57,13 @@ public T getCurrentRecord() { public GroupConverter getRootConverter() { return root; } + + interface ParentValueContainerHolder { + ProtoMessageConverter.ParentValueContainer getParentValueContainer(); + void setParentValueContainer(ProtoMessageConverter.ParentValueContainer parentValueContainer); + } + + abstract static class ProtoGroupConverter extends GroupConverter { + abstract int getFieldCount(); + } } From 6172b68b2ddb145ade4ac5350270d6dc89664482 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Sun, 27 Apr 2025 09:42:37 +0100 Subject: [PATCH 09/15] add benchmarks for proto read --- .../benchmarks/ProtoDataGenerator.java | 3 + .../benchmarks/ProtoReadBenchmarks.java | 74 +++++++++++++++++++ .../parquet/benchmarks/ReadBenchmarks.java | 2 +- .../parquet/proto/ByteBuddyCodeGen.java | 24 +++--- .../parquet/proto/ProtoMessageConverter.java | 12 ++- .../parquet/proto/ProtoReadSupport.java | 14 +++- .../proto/ProtoRecordMaterializer.java | 1 + 7 files changed, 108 insertions(+), 22 deletions(-) create mode 100644 parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java index 23eed270df..c06a07fd8d 100644 --- a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java @@ -186,6 +186,9 @@ public void generateData( CompressionCodecName codec, int nRows) throws IOException { + + outFile = outFile.suffix(protoClass.getName()); + if (exists(configuration, outFile)) { System.out.println("File already exists " + outFile); return; diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java new file mode 100644 index 0000000000..17ea033fd3 --- /dev/null +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.parquet.benchmarks; + +import com.google.protobuf.Message; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.hadoop.ParquetReader; +import org.apache.parquet.proto.ProtoParquetReader; +import org.apache.parquet.proto.ProtoReadSupport; +import org.apache.parquet.proto.ProtoWriteSupport; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.infra.Blackhole; +import java.io.IOException; + +import static org.apache.parquet.benchmarks.BenchmarkFiles.configuration; +import static org.openjdk.jmh.annotations.Scope.Thread; + +@State(Thread) +public class ProtoReadBenchmarks extends ReadBenchmarks { + + @Param({"OFF", "REQUIRED"}) + public ProtoReadSupport.CodegenMode codegenMode; + + @Param({"Test30Int32", "Test100Int32", "Test30String", "Test1"}) + public String protoClass; + + private Class messageClass; + private ProtoDataGenerator protoDataGenerator; + + @Setup(Level.Trial) + public void generateFilesForRead() { + try { + messageClass = (Class) Class.forName("org.apache.parquet.benchmarks.Messages$" + protoClass); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + protoDataGenerator = new ProtoDataGenerator<>(messageClass, ProtoWriteSupport.CodegenMode.OFF); + protoDataGenerator.generateAll(); + } + + protected void read(Path parquetFile, int nRows, Blackhole blackhole) throws IOException { + ProtoReadSupport.setCodegenMode(configuration, codegenMode); + ParquetReader reader = ProtoParquetReader.builder(parquetFile.suffix(messageClass.getName())) + .withConf(configuration) + .build(); + for (int i = 0; i < nRows; i++) { + Message.Builder builder = (Message.Builder) reader.read(); + Message message = builder.build(); + blackhole.consume(message); + } + reader.close(); + } +} diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ReadBenchmarks.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ReadBenchmarks.java index 2d6e3a52e3..484d12f2cf 100644 --- a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ReadBenchmarks.java +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ReadBenchmarks.java @@ -45,7 +45,7 @@ @State(Scope.Benchmark) public class ReadBenchmarks { - private void read(Path parquetFile, int nRows, Blackhole blackhole) throws IOException { + protected void read(Path parquetFile, int nRows, Blackhole blackhole) throws IOException { ParquetReader reader = ParquetReader.builder(new GroupReadSupport(), parquetFile) .withConf(configuration) .build(); diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index c84d46e714..a578161333 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -2758,7 +2758,7 @@ static RecordMaterializer tryEnhanceRecordMaterializer( ProtoReadSupport.CodegenMode codegenMode, ParquetConfiguration configuration) { - visitConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); + // visitConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); return protoRecordMaterializer; } @@ -2776,26 +2776,31 @@ static void visitConverters(Converter converter, Stack pare (ProtoRecordMaterializer.ParentValueContainerHolder) converter; ProtoMessageConverter.ParentValueContainer parentValueContainer = holder.getParentValueContainer(); if (parentValueContainer instanceof ProtoMessageConverter.SetFieldParentValueContainer) { - ProtoMessageConverter.SetFieldParentValueContainer pvc = (ProtoMessageConverter.SetFieldParentValueContainer) parentValueContainer; + ProtoMessageConverter.SetFieldParentValueContainer pvc = + (ProtoMessageConverter.SetFieldParentValueContainer) parentValueContainer; Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); - String containingTypeName = fieldDescriptor.getContainingType().getName(); + String containingTypeName = + fieldDescriptor.getContainingType().getName(); String fieldName = fieldDescriptor.getName(); Message.Builder parent = pvc.getParent(); - parentValueContainerInfo = " : single : " + containingTypeName + "." + fieldName + " : " + (parent != null ? parent.getClass() : "null"); + parentValueContainerInfo = " : single : " + containingTypeName + "." + fieldName + " : " + + (parent != null ? parent.getClass() : "null"); } else if (parentValueContainer instanceof ProtoMessageConverter.AddRepeatedFieldParentValueContainer) { - ProtoMessageConverter.AddRepeatedFieldParentValueContainer pvc = (ProtoMessageConverter.AddRepeatedFieldParentValueContainer) parentValueContainer; + ProtoMessageConverter.AddRepeatedFieldParentValueContainer pvc = + (ProtoMessageConverter.AddRepeatedFieldParentValueContainer) parentValueContainer; Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); - String containingTypeName = fieldDescriptor.getContainingType().getName(); + String containingTypeName = + fieldDescriptor.getContainingType().getName(); String fieldName = fieldDescriptor.getName(); Message.Builder parent = pvc.getParent(); - parentValueContainerInfo = " : repeated : " + containingTypeName + "." + fieldName + " : " + (parent != null ? parent.getClass() : "null"); + parentValueContainerInfo = " : repeated : " + containingTypeName + "." + fieldName + " : " + + (parent != null ? parent.getClass() : "null"); } } if (converter instanceof ProtoGroupConverter) { System.out.println(indent + "ProtoGroupConverter: " + converter.getClass() + parentValueContainerInfo); - ProtoGroupConverter groupConverter = - (ProtoGroupConverter) converter; + ProtoGroupConverter groupConverter = (ProtoGroupConverter) converter; for (int i = 0; i < groupConverter.getFieldCount(); i++) { Converter fieldConverter = groupConverter.getConverter(i); parentConverters.push(groupConverter); @@ -2808,6 +2813,5 @@ static void visitConverters(Converter converter, Stack pare System.out.println(indent + "GroupConverter: " + converter.getClass() + parentValueContainerInfo); } } - } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 4b52b8b6a3..c688185581 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -394,8 +394,7 @@ public void addBoolean(boolean value) { static class DummyParentValueContainer extends ParentValueContainer { @Override - public void add(Object value) { - } + public void add(Object value) {} } static class SetFieldParentValueContainer extends ParentValueContainer { @@ -427,7 +426,8 @@ static class AddRepeatedFieldParentValueContainer extends ParentValueContainer { private final Message.Builder parent; private final Descriptors.FieldDescriptor fieldDescriptor; - public AddRepeatedFieldParentValueContainer(Message.Builder parent, Descriptors.FieldDescriptor fieldDescriptor) { + public AddRepeatedFieldParentValueContainer( + Message.Builder parent, Descriptors.FieldDescriptor fieldDescriptor) { this.parent = parent; this.fieldDescriptor = fieldDescriptor; } @@ -1152,12 +1152,10 @@ public Converter getConverter(int fieldIndex) { } @Override - public void start() { - } + public void start() {} @Override - public void end() { - } + public void end() {} }; } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java index 7ac40cd9a8..b4f118c4f0 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoReadSupport.java @@ -92,7 +92,6 @@ public static ProtoReadSupport.CodegenMode orDefault(ProtoReadSupport.CodegenMod public abstract boolean tryCodeGen(Class protoClass); } - public static void setRequestedProjection(Configuration configuration, String requestedProjection) { configuration.set(PB_REQUESTED_PROJECTION, requestedProjection); } @@ -110,6 +109,10 @@ public static void setProtobufClass(Configuration configuration, String protobuf configuration.set(PB_CLASS, protobufClass); } + public static void setCodegenMode(Configuration configuration, ProtoReadSupport.CodegenMode codegenMode) { + configuration.setEnum(PB_CODEGEN, codegenMode); + } + @Override public ReadContext init(InitContext context) { String requestedProjectionString = context.getParquetConfiguration().get(PB_REQUESTED_PROJECTION); @@ -157,11 +160,14 @@ public RecordMaterializer prepareForRead( MessageType requestedSchema = readContext.getRequestedSchema(); Class protobufClass = Protobufs.getProtobufClass(headerProtoClass); - ProtoRecordMaterializer protoRecordMaterializer = new ProtoRecordMaterializer(configuration, requestedSchema, protobufClass, keyValueMetaData); + ProtoRecordMaterializer protoRecordMaterializer = + new ProtoRecordMaterializer(configuration, requestedSchema, protobufClass, keyValueMetaData); - CodegenMode codegenMode = ProtoReadSupport.CodegenMode.valueOf(configuration.get(PB_CODEGEN, CodegenMode.DEFAULT.name())); + CodegenMode codegenMode = + ProtoReadSupport.CodegenMode.valueOf(configuration.get(PB_CODEGEN, CodegenMode.DEFAULT.name())); return codegenMode.tryCodeGen(protobufClass) - ? ByteBuddyCodeGen.ReadSupport.tryEnhanceRecordMaterializer(protoRecordMaterializer, codegenMode, configuration) + ? ByteBuddyCodeGen.ReadSupport.tryEnhanceRecordMaterializer( + protoRecordMaterializer, codegenMode, configuration) : protoRecordMaterializer; } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java index 704f644fc4..3ad532fd4a 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java @@ -60,6 +60,7 @@ public GroupConverter getRootConverter() { interface ParentValueContainerHolder { ProtoMessageConverter.ParentValueContainer getParentValueContainer(); + void setParentValueContainer(ProtoMessageConverter.ParentValueContainer parentValueContainer); } From 6ca062ab17cad5fec482e974c96403f425f16a69 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Sun, 27 Apr 2025 10:09:30 +0100 Subject: [PATCH 10/15] Modifiable PVC Holder and ModifiableGroupConverter --- .../parquet/proto/ByteBuddyCodeGen.java | 27 ++-- .../parquet/proto/ProtoMessageConverter.java | 146 ++++++++++++------ .../proto/ProtoRecordMaterializer.java | 7 +- 3 files changed, 116 insertions(+), 64 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index a578161333..885b0654ca 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -105,7 +105,10 @@ import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Codegen; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Implementations; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; -import org.apache.parquet.proto.ProtoRecordMaterializer.ProtoGroupConverter; +import org.apache.parquet.proto.ProtoMessageConverter.AddRepeatedFieldParentValueContainer; +import org.apache.parquet.proto.ProtoMessageConverter.SetFieldParentValueContainer; +import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableGroupConverter; +import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableParentValueContainerHolder; import org.apache.parquet.schema.MessageType; public class ByteBuddyCodeGen { @@ -2763,7 +2766,7 @@ static RecordMaterializer tryEnhanceRecordMaterializer( return protoRecordMaterializer; } - static void visitConverters(Converter converter, Stack parentConverters) { + static void visitConverters(Converter converter, Stack parentConverters) { StringBuilder indentBuilder = new StringBuilder(); for (int i = 0; i < parentConverters.size(); i++) { indentBuilder.append(" "); @@ -2771,13 +2774,11 @@ static void visitConverters(Converter converter, Stack pare String indent = indentBuilder.toString(); String parentValueContainerInfo = ""; - if (converter instanceof ProtoRecordMaterializer.ParentValueContainerHolder) { - ProtoRecordMaterializer.ParentValueContainerHolder holder = - (ProtoRecordMaterializer.ParentValueContainerHolder) converter; + if (converter instanceof ModifiableParentValueContainerHolder) { + ModifiableParentValueContainerHolder holder = (ModifiableParentValueContainerHolder) converter; ProtoMessageConverter.ParentValueContainer parentValueContainer = holder.getParentValueContainer(); - if (parentValueContainer instanceof ProtoMessageConverter.SetFieldParentValueContainer) { - ProtoMessageConverter.SetFieldParentValueContainer pvc = - (ProtoMessageConverter.SetFieldParentValueContainer) parentValueContainer; + if (parentValueContainer instanceof SetFieldParentValueContainer) { + SetFieldParentValueContainer pvc = (SetFieldParentValueContainer) parentValueContainer; Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); String containingTypeName = fieldDescriptor.getContainingType().getName(); @@ -2785,9 +2786,9 @@ static void visitConverters(Converter converter, Stack pare Message.Builder parent = pvc.getParent(); parentValueContainerInfo = " : single : " + containingTypeName + "." + fieldName + " : " + (parent != null ? parent.getClass() : "null"); - } else if (parentValueContainer instanceof ProtoMessageConverter.AddRepeatedFieldParentValueContainer) { - ProtoMessageConverter.AddRepeatedFieldParentValueContainer pvc = - (ProtoMessageConverter.AddRepeatedFieldParentValueContainer) parentValueContainer; + } else if (parentValueContainer instanceof AddRepeatedFieldParentValueContainer) { + AddRepeatedFieldParentValueContainer pvc = + (AddRepeatedFieldParentValueContainer) parentValueContainer; Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); String containingTypeName = fieldDescriptor.getContainingType().getName(); @@ -2798,9 +2799,9 @@ static void visitConverters(Converter converter, Stack pare } } - if (converter instanceof ProtoGroupConverter) { + if (converter instanceof ModifiableGroupConverter) { System.out.println(indent + "ProtoGroupConverter: " + converter.getClass() + parentValueContainerInfo); - ProtoGroupConverter groupConverter = (ProtoGroupConverter) converter; + ModifiableGroupConverter groupConverter = (ModifiableGroupConverter) converter; for (int i = 0; i < groupConverter.getFieldCount(); i++) { Converter fieldConverter = groupConverter.getConverter(i); parentConverters.push(groupConverter); diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index c688185581..fcf568b884 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -60,8 +60,8 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.PrimitiveConverter; -import org.apache.parquet.proto.ProtoRecordMaterializer.ParentValueContainerHolder; -import org.apache.parquet.proto.ProtoRecordMaterializer.ProtoGroupConverter; +import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableGroupConverter; +import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableParentValueContainerHolder; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -74,7 +74,7 @@ * Converts Protocol Buffer message (both top level and inner) to parquet. * This is internal class, use {@link ProtoRecordConverter}. */ -class ProtoMessageConverter extends ProtoGroupConverter implements ParentValueContainerHolder { +class ProtoMessageConverter extends ModifiableGroupConverter implements ModifiableParentValueContainerHolder { private static final Logger LOG = LoggerFactory.getLogger(ProtoMessageConverter.class); private static final ParentValueContainer DUMMY_PVC = new DummyParentValueContainer(); @@ -354,6 +354,11 @@ int getFieldCount() { return converters.length; } + @Override + void setFieldConverter(int fieldIndex, Converter converter) { + converters[fieldIndex] = converter; + } + @Override public ParentValueContainer getParentValueContainer() { return parent; @@ -364,12 +369,14 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { parent = parentValueContainer; } - abstract static class ParentValueContainer { + static class ParentValueContainer { /** * Adds the value to the parent. */ - public abstract void add(Object value); + public void add(Object value) { + throw new UnsupportedOperationException(); + } public void addInt(int value) { add(value); @@ -446,7 +453,7 @@ public Descriptors.FieldDescriptor getFieldDescriptor() { } } - final class ProtoEnumConverter extends PrimitiveConverter implements ParentValueContainerHolder { + final class ProtoEnumConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { private final Descriptors.FieldDescriptor fieldType; private final Map enumLookup; @@ -582,7 +589,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoBinaryConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoBinaryConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -607,7 +614,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoBooleanConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoBooleanConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -631,7 +639,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoDoubleConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoDoubleConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -655,7 +663,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoFloatConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoFloatConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -679,7 +687,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoIntConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoIntConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -703,7 +711,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoLongConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoLongConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -727,7 +735,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoStringConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoStringConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -752,7 +760,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoTimestampConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoTimestampConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation; @@ -790,7 +799,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoDateConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoDateConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -820,7 +829,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoTimeConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoTimeConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { ParentValueContainer parent; final LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation; @@ -872,7 +881,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoDoubleValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoDoubleValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -896,7 +906,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoFloatValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoFloatValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -920,7 +931,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoInt64ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoInt64ValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -944,7 +956,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoUInt64ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoUInt64ValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -968,7 +981,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoInt32ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoInt32ValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -992,7 +1006,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoUInt32ValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoUInt32ValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -1016,7 +1031,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoBoolValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoBoolValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -1040,7 +1056,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoStringValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoStringValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -1065,7 +1082,8 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { } } - static final class ProtoBytesValueConverter extends PrimitiveConverter implements ParentValueContainerHolder { + static final class ProtoBytesValueConverter extends PrimitiveConverter + implements ModifiableParentValueContainerHolder { ParentValueContainer parent; @@ -1112,9 +1130,40 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { * a repeated group named 'list', itself containing only one field called 'element' of the type of the repeated * object (can be a primitive as in this example or a group in case of a repeated message in protobuf). */ - final class ListConverter extends ProtoGroupConverter { - private final Converter converter; - private final Converter wrapperConverter; + final class ListConverter extends ModifiableGroupConverter { + private Converter converter; + + final class ListWrapperConverter extends ModifiableGroupConverter { + private Converter converter; + + ListWrapperConverter(Converter converter) { + this.converter = converter; + } + + @Override + int getFieldCount() { + return 1; + } + + @Override + void setFieldConverter(int fieldIndex, Converter converter) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); + } + this.converter = converter; + } + + @Override + public Converter getConverter(int fieldIndex) { + return converter; + } + + @Override + public void start() {} + + @Override + public void end() {} + } public ListConverter( Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { @@ -1139,24 +1188,7 @@ public ListConverter( } Type elementType = listType.getType("element"); - converter = newMessageConverter(parentBuilder, fieldDescriptor, elementType); - wrapperConverter = new ProtoGroupConverter() { - @Override - int getFieldCount() { - return 1; - } - - @Override - public Converter getConverter(int fieldIndex) { - return converter; - } - - @Override - public void start() {} - - @Override - public void end() {} - }; + converter = new ListWrapperConverter(newMessageConverter(parentBuilder, fieldDescriptor, elementType)); } @Override @@ -1164,7 +1196,7 @@ public Converter getConverter(int fieldIndex) { if (fieldIndex > 0) { throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); } - return wrapperConverter; + return converter; } @Override @@ -1177,10 +1209,18 @@ public void end() {} int getFieldCount() { return 1; } + + @Override + void setFieldConverter(int fieldIndex, Converter converter) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); + } + this.converter = converter; + } } - final class MapConverter extends ProtoGroupConverter { - private final Converter converter; + final class MapConverter extends ModifiableGroupConverter { + private Converter converter; public MapConverter( Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { @@ -1218,5 +1258,13 @@ public void end() {} int getFieldCount() { return 1; } + + @Override + void setFieldConverter(int fieldIndex, Converter converter) { + if (fieldIndex > 0) { + throw new ParquetDecodingException("Unexpected multiple fields in the MAP wrapper"); + } + MapConverter.this.converter = converter; + } } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java index 3ad532fd4a..cb6f89b056 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java @@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration; import org.apache.parquet.conf.HadoopParquetConfiguration; import org.apache.parquet.conf.ParquetConfiguration; +import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.schema.MessageType; @@ -58,13 +59,15 @@ public GroupConverter getRootConverter() { return root; } - interface ParentValueContainerHolder { + interface ModifiableParentValueContainerHolder { ProtoMessageConverter.ParentValueContainer getParentValueContainer(); void setParentValueContainer(ProtoMessageConverter.ParentValueContainer parentValueContainer); } - abstract static class ProtoGroupConverter extends GroupConverter { + abstract static class ModifiableGroupConverter extends GroupConverter { abstract int getFieldCount(); + + abstract void setFieldConverter(int fieldIndex, Converter converter); } } From fdf5a44dcccbc020e5d118bb556b489557285b47 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Wed, 30 Apr 2025 21:51:17 +0100 Subject: [PATCH 11/15] int32 draft optimization --- .../parquet/proto/ByteBuddyCodeGen.java | 162 +++++++++++++++++- .../parquet/proto/ProtoMessageConverter.java | 2 +- 2 files changed, 161 insertions(+), 3 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 885b0654ca..8ba6bfd4ce 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -29,6 +29,7 @@ import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; +import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.StringValue; @@ -64,6 +65,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Stream; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; @@ -2761,12 +2763,168 @@ static RecordMaterializer tryEnhanceRecordMaterializer( ProtoReadSupport.CodegenMode codegenMode, ParquetConfiguration configuration) { + updateConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); // visitConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); return protoRecordMaterializer; } - static void visitConverters(Converter converter, Stack parentConverters) { + private static Converter updateConverters( + Converter converter, Stack parentConverters) { + if (converter instanceof ModifiableGroupConverter) { + ModifiableGroupConverter groupConverter = (ModifiableGroupConverter) converter; + for (int i = 0; i < groupConverter.getFieldCount(); i++) { + Converter fieldConverter = groupConverter.getConverter(i); + parentConverters.push(groupConverter); + Converter newFieldConverter = updateConverters(fieldConverter, parentConverters); + parentConverters.pop(); + groupConverter.setFieldConverter(i, newFieldConverter); + } + } + return updateConverter(converter, parentConverters); + } + + private static Converter updateConverter( + Converter converter, Stack parentConverters) { + if (converter instanceof ProtoRecordConverter) { + return converter; + } + + if (converter instanceof ModifiableParentValueContainerHolder) { + ModifiableParentValueContainerHolder parentValueContainerHolder = + (ModifiableParentValueContainerHolder) converter; + ProtoMessageConverter.ParentValueContainer parentValueContainer = + parentValueContainerHolder.getParentValueContainer(); + + Descriptors.FieldDescriptor fieldDescriptor; + Message.Builder parentBuilder; + boolean repeatedField; + if (parentValueContainer instanceof SetFieldParentValueContainer) { + SetFieldParentValueContainer pvc = (SetFieldParentValueContainer) parentValueContainer; + fieldDescriptor = pvc.getFieldDescriptor(); + parentBuilder = pvc.getParent(); + repeatedField = false; + } else if (parentValueContainer instanceof AddRepeatedFieldParentValueContainer) { + AddRepeatedFieldParentValueContainer pvc = + (AddRepeatedFieldParentValueContainer) parentValueContainer; + fieldDescriptor = pvc.getFieldDescriptor(); + parentBuilder = pvc.getParent(); + repeatedField = true; + } else if (parentValueContainer instanceof ProtoMessageConverter.DummyParentValueContainer) { + return converter; + } else { + throw new IllegalStateException("Unknown parent value container: " + parentValueContainer); + } + + if (converter instanceof PrimitiveConverter) { + return updatePrimitiveConverter( + (PrimitiveConverter) converter, + parentValueContainerHolder, + parentConverters, + fieldDescriptor, + repeatedField, + parentBuilder); + } + } + return converter; + } + + private static Converter updatePrimitiveConverter( + PrimitiveConverter converter, + ModifiableParentValueContainerHolder containerHolder, + Stack parentConverters, + Descriptors.FieldDescriptor fieldDescriptor, + boolean repeatedField, + Message.Builder parentBuilder) { + if (converter instanceof ProtoMessageConverter.ProtoIntConverter + && !repeatedField + && !(parentBuilder instanceof MapEntry.Builder)) { + ProtoMessageConverter.ParentValueContainer newPvc = new Function< + Message.Builder, ProtoMessageConverter.ParentValueContainer>() { + private DynamicType.Builder classBuilder; + + @Override + public ProtoMessageConverter.ParentValueContainer apply(Message.Builder builder) { + Class parentBuilderClass = parentBuilder.getClass(); + Method setterMethod = ReflectionUtil.getDeclaredMethod( + parentBuilderClass, fieldDescriptor, "set{}", int.class); + + classBuilder = new ByteBuddy() + .subclass(ProtoMessageConverter.ParentValueContainer.class) + .modifiers(Visibility.PUBLIC) + .name(ProtoMessageConverter.ParentValueContainer.class.getName() + "$Generated$" + + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); + + TypeDescription.Generic parentBuilderType = + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parentBuilderClass); + FieldDescription.Latent parentBuilderFieldDesc = new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "parent", Modifier.PRIVATE | Modifier.FINAL, parentBuilderType)); + classBuilder = classBuilder.define(parentBuilderFieldDesc); + + classBuilder = classBuilder + .define(new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + MethodDescription.CONSTRUCTOR_INTERNAL_NAME, + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(parentBuilderType)))) + .intercept(MethodCall.invoke(ReflectionUtil.getConstructor( + ProtoMessageConverter.ParentValueContainer.class)) + .andThen(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar parentVar = localVars.register(parentBuilderClass)) { + add( + MethodVariableAccess.loadThis(), + parentVar.load(), + FieldAccess.forField(parentBuilderFieldDesc) + .write()); + } + } + add(Codegen.returnVoid()); + } + })); + + classBuilder = classBuilder + .method(ElementMatchers.named("addInt")) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar valueVar = localVars.register(int.class)) { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(parentBuilderFieldDesc) + .read(), + valueVar.load(), + Codegen.invokeMethod(setterMethod)); + add(Codegen.returnVoid()); + } + } + } + }); + + DynamicType.Unloaded unloaded = classBuilder.make(); + Class pvcClass = unloaded.load( + this.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + .getLoaded(); + return ReflectionUtil.newInstance( + ReflectionUtil.getConstructor(pvcClass, parentBuilderClass), parentBuilder); + } + }.apply(parentBuilder); + + containerHolder.setParentValueContainer(newPvc); + } + return converter; + } + + static void printConvertersTree(Converter converter, Stack parentConverters) { StringBuilder indentBuilder = new StringBuilder(); for (int i = 0; i < parentConverters.size(); i++) { indentBuilder.append(" "); @@ -2805,7 +2963,7 @@ static void visitConverters(Converter converter, Stack for (int i = 0; i < groupConverter.getFieldCount(); i++) { Converter fieldConverter = groupConverter.getConverter(i); parentConverters.push(groupConverter); - visitConverters(fieldConverter, parentConverters); + printConvertersTree(fieldConverter, parentConverters); parentConverters.pop(); } } else if (converter instanceof PrimitiveConverter) { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index fcf568b884..6a3a892252 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -369,7 +369,7 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { parent = parentValueContainer; } - static class ParentValueContainer { + public static class ParentValueContainer { /** * Adds the value to the parent. From 3526bccc73de4b9f662794893a5d0ed8597456dd Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Fri, 2 May 2025 22:51:06 +0100 Subject: [PATCH 12/15] before generating pvc --- .../parquet/proto/ByteBuddyCodeGen.java | 650 +++++++++++++----- .../parquet/proto/ProtoMessageConverter.java | 379 ++-------- .../proto/ProtoRecordMaterializer.java | 12 - 3 files changed, 514 insertions(+), 527 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 8ba6bfd4ce..2da5f618e8 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -60,12 +60,10 @@ import java.util.Objects; import java.util.Optional; import java.util.Queue; -import java.util.Stack; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.function.Function; import java.util.stream.Stream; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; @@ -101,16 +99,37 @@ import org.apache.parquet.conf.ParquetConfiguration; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.Converter; -import org.apache.parquet.io.api.PrimitiveConverter; +import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.RecordConsumer; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Codegen; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.Implementations; import org.apache.parquet.proto.ByteBuddyCodeGen.CodeGenUtils.LocalVar; import org.apache.parquet.proto.ProtoMessageConverter.AddRepeatedFieldParentValueContainer; +import org.apache.parquet.proto.ProtoMessageConverter.ListConverter; +import org.apache.parquet.proto.ProtoMessageConverter.MapConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ParentValueContainer; import org.apache.parquet.proto.ProtoMessageConverter.SetFieldParentValueContainer; -import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableGroupConverter; -import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableParentValueContainerHolder; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoEnumConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBinaryConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBooleanConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoDoubleConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoFloatConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoIntConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoLongConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoStringConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoTimestampConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoDateConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoTimeConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoDoubleValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoFloatValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt64ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoUInt64ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt32ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoUInt32ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBoolValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoStringValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBytesValueConverter; import org.apache.parquet.schema.MessageType; public class ByteBuddyCodeGen { @@ -2763,213 +2782,464 @@ static RecordMaterializer tryEnhanceRecordMaterializer( ProtoReadSupport.CodegenMode codegenMode, ParquetConfiguration configuration) { - updateConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); - // visitConverters(protoRecordMaterializer.getRootConverter(), new Stack<>()); + protoRecordMaterializer = new ProtoRecordMaterializerTransformer(codegenMode, configuration) + .transform(protoRecordMaterializer); return protoRecordMaterializer; } - private static Converter updateConverters( - Converter converter, Stack parentConverters) { - if (converter instanceof ModifiableGroupConverter) { - ModifiableGroupConverter groupConverter = (ModifiableGroupConverter) converter; - for (int i = 0; i < groupConverter.getFieldCount(); i++) { - Converter fieldConverter = groupConverter.getConverter(i); - parentConverters.push(groupConverter); - Converter newFieldConverter = updateConverters(fieldConverter, parentConverters); - parentConverters.pop(); - groupConverter.setFieldConverter(i, newFieldConverter); + private static class ProtoRecordMaterializerTransformer { + + interface MapEntryBuilder { + void clear(); + } + + private ProtoReadSupport.CodegenMode codegenMode; + private ParquetConfiguration configuration; + + public ProtoRecordMaterializerTransformer(ProtoReadSupport.CodegenMode codegenMode, + ParquetConfiguration configuration) { + this.codegenMode = codegenMode; + this.configuration = configuration; + } + + public ProtoRecordMaterializer transform(ProtoRecordMaterializer protoRecordMaterializer) { + GroupConverter rootConverter = protoRecordMaterializer.getRootConverter(); + if (rootConverter instanceof ProtoMessageConverter) { + ProtoMessageConverter protoMessageConverter = (ProtoMessageConverter) rootConverter; + transformRootConverter(protoMessageConverter); } + return protoRecordMaterializer; } - return updateConverter(converter, parentConverters); - } - private static Converter updateConverter( - Converter converter, Stack parentConverters) { - if (converter instanceof ProtoRecordConverter) { + private void transformRootConverter(ProtoMessageConverter messageConverter) { + Converter[] converters = messageConverter.converters; + for (int i = 0; i < converters.length; i++) { + Converter converter = converters[i]; + converters[i] = transformChildConverter(messageConverter.myBuilder, converter); + } + } + + private Converter transformChildConverter(Object parentBuilder, Converter converter) { + if (converter instanceof ProtoMessageConverter) { + return transformChildConverterProtoMessageConverter( + parentBuilder, (ProtoMessageConverter) converter); + } + if (converter instanceof ProtoEnumConverter) { + return transformChildConverterProtoEnumConverter(parentBuilder, (ProtoEnumConverter) converter); + } + if (converter instanceof ProtoBinaryConverter) { + return transformChildConverterProtoBinaryConverter(parentBuilder, (ProtoBinaryConverter) converter); + } + if (converter instanceof ProtoBooleanConverter) { + return transformChildConverterProtoBooleanConverter(parentBuilder, (ProtoBooleanConverter) converter); + } + if (converter instanceof ProtoDoubleConverter) { + return transformChildConverterProtoDoubleConverter(parentBuilder, (ProtoDoubleConverter) converter); + } + if (converter instanceof ProtoFloatConverter) { + return transformChildConverterProtoFloatConverter(parentBuilder, (ProtoFloatConverter) converter); + } + if (converter instanceof ProtoIntConverter) { + return transformChildConverterProtoIntConverter(parentBuilder, (ProtoIntConverter) converter); + } + if (converter instanceof ProtoLongConverter) { + return transformChildConverterProtoLongConverter(parentBuilder, (ProtoLongConverter) converter); + } + if (converter instanceof ProtoStringConverter) { + return transformChildConverterProtoStringConverter(parentBuilder, (ProtoStringConverter) converter); + } + if (converter instanceof ProtoTimestampConverter) { + return transformChildConverterProtoTimestampConverter(parentBuilder, (ProtoTimestampConverter) converter); + } + if (converter instanceof ProtoDateConverter) { + return transformChildConverterProtoDateConverter(parentBuilder, (ProtoDateConverter) converter); + } + if (converter instanceof ProtoTimeConverter) { + return transformChildConverterProtoTimeConverter(parentBuilder, (ProtoTimeConverter) converter); + } + if (converter instanceof ProtoDoubleValueConverter) { + return transformChildConverterProtoDoubleValueConverter(parentBuilder, (ProtoDoubleValueConverter) converter); + } + if (converter instanceof ProtoFloatValueConverter) { + return transformChildConverterProtoFloatValueConverter(parentBuilder, (ProtoFloatValueConverter) converter); + } + if (converter instanceof ProtoInt64ValueConverter) { + return transformChildConverterProtoInt64ValueConverter(parentBuilder, (ProtoInt64ValueConverter) converter); + } + if (converter instanceof ProtoUInt64ValueConverter) { + return transformChildConverterProtoUInt64ValueConverter(parentBuilder, (ProtoUInt64ValueConverter) converter); + } + if (converter instanceof ProtoInt32ValueConverter) { + return transformChildConverterProtoInt32ValueConverter(parentBuilder, (ProtoInt32ValueConverter) converter); + } + if (converter instanceof ProtoUInt32ValueConverter) { + return transformChildConverterProtoUInt32ValueConverter(parentBuilder, (ProtoUInt32ValueConverter) converter); + } + if (converter instanceof ProtoBoolValueConverter) { + return transformChildConverterProtoBoolValueConverter(parentBuilder, (ProtoBoolValueConverter) converter); + } + if (converter instanceof ProtoStringValueConverter) { + return transformChildConverterProtoStringValueConverter(parentBuilder, (ProtoStringValueConverter) converter); + } + if (converter instanceof ProtoBytesValueConverter) { + return transformChildConverterProtoBytesValueConverter(parentBuilder, (ProtoBytesValueConverter) converter); + } + if (converter instanceof MapConverter) { + return transformChildConverterMapConverter(parentBuilder, (MapConverter) converter); + } + if (converter instanceof ListConverter) { + return transformChildConverterListConverter(parentBuilder, (ListConverter) converter); + } return converter; } - if (converter instanceof ModifiableParentValueContainerHolder) { - ModifiableParentValueContainerHolder parentValueContainerHolder = - (ModifiableParentValueContainerHolder) converter; - ProtoMessageConverter.ParentValueContainer parentValueContainer = - parentValueContainerHolder.getParentValueContainer(); - - Descriptors.FieldDescriptor fieldDescriptor; - Message.Builder parentBuilder; - boolean repeatedField; - if (parentValueContainer instanceof SetFieldParentValueContainer) { - SetFieldParentValueContainer pvc = (SetFieldParentValueContainer) parentValueContainer; - fieldDescriptor = pvc.getFieldDescriptor(); - parentBuilder = pvc.getParent(); - repeatedField = false; - } else if (parentValueContainer instanceof AddRepeatedFieldParentValueContainer) { - AddRepeatedFieldParentValueContainer pvc = - (AddRepeatedFieldParentValueContainer) parentValueContainer; - fieldDescriptor = pvc.getFieldDescriptor(); - parentBuilder = pvc.getParent(); - repeatedField = true; - } else if (parentValueContainer instanceof ProtoMessageConverter.DummyParentValueContainer) { + private Converter transformChildConverterProtoEnumConverter(Object parentBuilder, ProtoEnumConverter converter) { + return converter; + } + + private Converter transformChildConverterListConverter(Object parentBuilder, ListConverter converter) { + GroupConverter wrapperConverter = new GroupConverter() { + @Override + public Converter getConverter(int fieldIndex) { + return transformChildConverter(parentBuilder, converter.converter.converter); + } + + @Override + public void start() { + } + + @Override + public void end() { + } + }; + + return new GroupConverter() { + @Override + public Converter getConverter(int fieldIndex) { + return wrapperConverter; + } + + @Override + public void start() { + } + + @Override + public void end() { + } + }; + } + + private Converter transformChildConverterMapConverter(Object parentBuilder, MapConverter converter) { + return new GroupConverter() { + @Override + public Converter getConverter(int fieldIndex) { + return transformChildConverter(parentBuilder, converter.converter); + } + + @Override + public void start() { + } + + @Override + public void end() { + } + }; + } + + private Converter transformChildConverterProtoBytesValueConverter(Object parentBuilder, + ProtoBytesValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { return converter; - } else { - throw new IllegalStateException("Unknown parent value container: " + parentValueContainer); } - if (converter instanceof PrimitiveConverter) { - return updatePrimitiveConverter( - (PrimitiveConverter) converter, - parentValueContainerHolder, - parentConverters, - fieldDescriptor, - repeatedField, - parentBuilder); + return new ProtoBytesValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoStringValueConverter(Object parentBuilder, + ProtoStringValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; } + + return new ProtoStringValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); } - return converter; - } - private static Converter updatePrimitiveConverter( - PrimitiveConverter converter, - ModifiableParentValueContainerHolder containerHolder, - Stack parentConverters, - Descriptors.FieldDescriptor fieldDescriptor, - boolean repeatedField, - Message.Builder parentBuilder) { - if (converter instanceof ProtoMessageConverter.ProtoIntConverter - && !repeatedField - && !(parentBuilder instanceof MapEntry.Builder)) { - ProtoMessageConverter.ParentValueContainer newPvc = new Function< - Message.Builder, ProtoMessageConverter.ParentValueContainer>() { - private DynamicType.Builder classBuilder; + private Converter transformChildConverterProtoBoolValueConverter(Object parentBuilder, + ProtoBoolValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); - @Override - public ProtoMessageConverter.ParentValueContainer apply(Message.Builder builder) { - Class parentBuilderClass = parentBuilder.getClass(); - Method setterMethod = ReflectionUtil.getDeclaredMethod( - parentBuilderClass, fieldDescriptor, "set{}", int.class); - - classBuilder = new ByteBuddy() - .subclass(ProtoMessageConverter.ParentValueContainer.class) - .modifiers(Visibility.PUBLIC) - .name(ProtoMessageConverter.ParentValueContainer.class.getName() + "$Generated$" - + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); - - TypeDescription.Generic parentBuilderType = - TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parentBuilderClass); - FieldDescription.Latent parentBuilderFieldDesc = new FieldDescription.Latent( - classBuilder.toTypeDescription(), - new FieldDescription.Token( - "parent", Modifier.PRIVATE | Modifier.FINAL, parentBuilderType)); - classBuilder = classBuilder.define(parentBuilderFieldDesc); - - classBuilder = classBuilder - .define(new MethodDescription.Latent( - classBuilder.toTypeDescription(), - new MethodDescription.Token( - MethodDescription.CONSTRUCTOR_INTERNAL_NAME, - Visibility.PUBLIC.getMask(), - TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), - Collections.singletonList(parentBuilderType)))) - .intercept(MethodCall.invoke(ReflectionUtil.getConstructor( - ProtoMessageConverter.ParentValueContainer.class)) - .andThen(new Implementations() { - { - CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); - try (LocalVar thisLocalVar = - localVars.register(classBuilder.toTypeDescription())) { - try (LocalVar parentVar = localVars.register(parentBuilderClass)) { - add( - MethodVariableAccess.loadThis(), - parentVar.load(), - FieldAccess.forField(parentBuilderFieldDesc) - .write()); - } - } - add(Codegen.returnVoid()); - } - })); - - classBuilder = classBuilder - .method(ElementMatchers.named("addInt")) - .intercept(new Implementations() { - { - CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); - try (LocalVar thisLocalVar = - localVars.register(classBuilder.toTypeDescription())) { - try (LocalVar valueVar = localVars.register(int.class)) { - add( - MethodVariableAccess.loadThis(), - FieldAccess.forField(parentBuilderFieldDesc) - .read(), - valueVar.load(), - Codegen.invokeMethod(setterMethod)); - add(Codegen.returnVoid()); - } - } - } - }); + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoBoolValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoUInt32ValueConverter(Object parentBuilder, + ProtoUInt32ValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoUInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoInt32ValueConverter(Object parentBuilder, + ProtoInt32ValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoUInt64ValueConverter(Object parentBuilder, + ProtoUInt64ValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoUInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoInt64ValueConverter(Object parentBuilder, + ProtoInt64ValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } - DynamicType.Unloaded unloaded = classBuilder.make(); - Class pvcClass = unloaded.load( - this.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) - .getLoaded(); - return ReflectionUtil.newInstance( - ReflectionUtil.getConstructor(pvcClass, parentBuilderClass), parentBuilder); + return new ProtoInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoFloatValueConverter(Object parentBuilder, + ProtoFloatValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoFloatValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoDoubleValueConverter(Object parentBuilder, + ProtoDoubleValueConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoDoubleValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoTimeConverter(Object parentBuilder, + ProtoTimeConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoTimeConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent), converter.logicalTypeAnnotation); + } + + private Converter transformChildConverterProtoDateConverter(Object parentBuilder, + ProtoDateConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoDateConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoTimestampConverter(Object parentBuilder, + ProtoTimestampConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoTimestampConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent), converter.logicalTypeAnnotation); + } + + private Converter transformChildConverterProtoBinaryConverter(Object parentBuilder, + ProtoBinaryConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoBinaryConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoStringConverter(Object parentBuilder, + ProtoStringConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoStringConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoLongConverter(Object parentBuilder, + ProtoLongConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoLongConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoIntConverter(Object parentBuilder, + ProtoIntConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoIntConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoFloatConverter(Object parentBuilder, + ProtoFloatConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoFloatConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoDoubleConverter(Object parentBuilder, + ProtoDoubleConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoDoubleConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoBooleanConverter(Object parentBuilder, + ProtoBooleanConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoBooleanConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + } + + private Converter transformChildConverterProtoMessageConverter(Object parentBuilder, + ProtoMessageConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + Object myBuilder = fieldDescriptor.isMapField() + // TODO + ? converter.myBuilder.build().toBuilder()//newMapEntryBuilder(parentBuilder, fieldDescriptor) + : converter.myBuilder; + + Converter[] converters = converter.converters; + Converter[] newConverters = new Converter[converters.length]; + for (int i = 0; i < converters.length; i++) { + Converter childConverter = converters[i]; + newConverters[i] = transformChildConverter(converter, childConverter); + } + + ParentValueContainer newPvc = generatePvc(parentBuilder, fieldDescriptor, converter.parent); + + return new PreBuiltProtoMessageConverter(newConverters, newPvc, myBuilder); + } + + private Object newMapEntryBuilder(Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor) { + Message.Builder messageBuilder = (Message.Builder) parentBuilder; + return messageBuilder.newBuilderForField(fieldDescriptor); + } + + private ParentValueContainer generatePvc(Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, + ParentValueContainer fallbackPvc) { + return new ParentValueContainer() { + @Override + public void add(Object val) { + if (val instanceof Message.Builder) { + Message.Builder builder = (Message.Builder) val; + Message message = builder.build(); + fallbackPvc.add(message); + builder.clear(); + } else { + fallbackPvc.add(val); + } } - }.apply(parentBuilder); + }; + } - containerHolder.setParentValueContainer(newPvc); + private Descriptors.FieldDescriptor getFieldDescriptor(ParentValueContainer pvc) { + if (pvc instanceof SetFieldParentValueContainer) { + SetFieldParentValueContainer setFieldPvc = (SetFieldParentValueContainer) pvc; + return setFieldPvc.getFieldDescriptor(); + } + if (pvc instanceof AddRepeatedFieldParentValueContainer) { + AddRepeatedFieldParentValueContainer addRepeatedFieldPvc = (AddRepeatedFieldParentValueContainer) pvc; + return addRepeatedFieldPvc.getFieldDescriptor(); + } + return null; } - return converter; - } - static void printConvertersTree(Converter converter, Stack parentConverters) { - StringBuilder indentBuilder = new StringBuilder(); - for (int i = 0; i < parentConverters.size(); i++) { - indentBuilder.append(" "); - } - String indent = indentBuilder.toString(); - - String parentValueContainerInfo = ""; - if (converter instanceof ModifiableParentValueContainerHolder) { - ModifiableParentValueContainerHolder holder = (ModifiableParentValueContainerHolder) converter; - ProtoMessageConverter.ParentValueContainer parentValueContainer = holder.getParentValueContainer(); - if (parentValueContainer instanceof SetFieldParentValueContainer) { - SetFieldParentValueContainer pvc = (SetFieldParentValueContainer) parentValueContainer; - Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); - String containingTypeName = - fieldDescriptor.getContainingType().getName(); - String fieldName = fieldDescriptor.getName(); - Message.Builder parent = pvc.getParent(); - parentValueContainerInfo = " : single : " + containingTypeName + "." + fieldName + " : " - + (parent != null ? parent.getClass() : "null"); - } else if (parentValueContainer instanceof AddRepeatedFieldParentValueContainer) { - AddRepeatedFieldParentValueContainer pvc = - (AddRepeatedFieldParentValueContainer) parentValueContainer; - Descriptors.FieldDescriptor fieldDescriptor = pvc.getFieldDescriptor(); - String containingTypeName = - fieldDescriptor.getContainingType().getName(); - String fieldName = fieldDescriptor.getName(); - Message.Builder parent = pvc.getParent(); - parentValueContainerInfo = " : repeated : " + containingTypeName + "." + fieldName + " : " - + (parent != null ? parent.getClass() : "null"); - } - } - - if (converter instanceof ModifiableGroupConverter) { - System.out.println(indent + "ProtoGroupConverter: " + converter.getClass() + parentValueContainerInfo); - ModifiableGroupConverter groupConverter = (ModifiableGroupConverter) converter; - for (int i = 0; i < groupConverter.getFieldCount(); i++) { - Converter fieldConverter = groupConverter.getConverter(i); - parentConverters.push(groupConverter); - printConvertersTree(fieldConverter, parentConverters); - parentConverters.pop(); - } - } else if (converter instanceof PrimitiveConverter) { - System.out.println(indent + "PrimitiveConverter: " + converter.getClass() + parentValueContainerInfo); - } else { - System.out.println(indent + "GroupConverter: " + converter.getClass() + parentValueContainerInfo); + private static class PreBuiltProtoMessageConverter extends GroupConverter { + protected final Converter[] converters; + protected final ParentValueContainer parent; + protected final Object myBuilder; + + public PreBuiltProtoMessageConverter(Converter[] converters, ParentValueContainer parent, + Object myBuilder) { + this.converters = converters; + this.parent = parent; + this.myBuilder = myBuilder; + } + + @Override + public Converter getConverter(int fieldIndex) { + return converters[fieldIndex]; + } + + @Override + public void start() { + } + + @Override + public void end() { + parent.add(myBuilder); + } } } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 6a3a892252..230049a5c4 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -59,9 +59,8 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; import org.apache.parquet.io.api.Converter; +import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.PrimitiveConverter; -import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableGroupConverter; -import org.apache.parquet.proto.ProtoRecordMaterializer.ModifiableParentValueContainerHolder; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.IncompatibleSchemaModificationException; import org.apache.parquet.schema.LogicalTypeAnnotation; @@ -74,14 +73,14 @@ * Converts Protocol Buffer message (both top level and inner) to parquet. * This is internal class, use {@link ProtoRecordConverter}. */ -class ProtoMessageConverter extends ModifiableGroupConverter implements ModifiableParentValueContainerHolder { +class ProtoMessageConverter extends GroupConverter { private static final Logger LOG = LoggerFactory.getLogger(ProtoMessageConverter.class); private static final ParentValueContainer DUMMY_PVC = new DummyParentValueContainer(); protected final ParquetConfiguration conf; protected final Converter[] converters; - protected ParentValueContainer parent; + protected final ParentValueContainer parent; protected final Message.Builder myBuilder; protected final Map extraMetadata; @@ -349,26 +348,6 @@ public Message.Builder getBuilder() { return myBuilder; } - @Override - int getFieldCount() { - return converters.length; - } - - @Override - void setFieldConverter(int fieldIndex, Converter converter) { - converters[fieldIndex] = converter; - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } - public static class ParentValueContainer { /** @@ -453,12 +432,12 @@ public Descriptors.FieldDescriptor getFieldDescriptor() { } } - final class ProtoEnumConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + final class ProtoEnumConverter extends PrimitiveConverter { private final Descriptors.FieldDescriptor fieldType; private final Map enumLookup; private Descriptors.EnumValueDescriptor[] dict; - private ParentValueContainer parent; + final ParentValueContainer parent; private final Descriptors.EnumDescriptor enumType; private final String unknownEnumPrefix; private final boolean acceptUnknownEnum; @@ -577,21 +556,11 @@ public void setDictionary(Dictionary dictionary) { dict[i] = translateEnumValue(binaryValue); } } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoBinaryConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoBinaryConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoBinaryConverter(ParentValueContainer parent) { this.parent = parent; @@ -602,22 +571,11 @@ public void addBinary(Binary binary) { ByteString byteString = ByteString.copyFrom(binary.toByteBuffer()); parent.add(byteString); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoBooleanConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoBooleanConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoBooleanConverter(ParentValueContainer parent) { this.parent = parent; @@ -625,23 +583,13 @@ public ProtoBooleanConverter(ParentValueContainer parent) { @Override public void addBoolean(boolean value) { - parent.addBoolean(value); - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; + parent.add(value); } } - static final class ProtoDoubleConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoDoubleConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoDoubleConverter(ParentValueContainer parent) { this.parent = parent; @@ -649,23 +597,13 @@ public ProtoDoubleConverter(ParentValueContainer parent) { @Override public void addDouble(double value) { - parent.addDouble(value); - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; + parent.add(value); } } - static final class ProtoFloatConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoFloatConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoFloatConverter(ParentValueContainer parent) { this.parent = parent; @@ -673,23 +611,13 @@ public ProtoFloatConverter(ParentValueContainer parent) { @Override public void addFloat(float value) { - parent.addFloat(value); - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; + parent.add(value); } } - static final class ProtoIntConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoIntConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoIntConverter(ParentValueContainer parent) { this.parent = parent; @@ -697,23 +625,13 @@ public ProtoIntConverter(ParentValueContainer parent) { @Override public void addInt(int value) { - parent.addInt(value); - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; + parent.add(value); } } - static final class ProtoLongConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoLongConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoLongConverter(ParentValueContainer parent) { this.parent = parent; @@ -721,23 +639,13 @@ public ProtoLongConverter(ParentValueContainer parent) { @Override public void addLong(long value) { - parent.addLong(value); - } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; + parent.add(value); } } - static final class ProtoStringConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoStringConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoStringConverter(ParentValueContainer parent) { this.parent = parent; @@ -748,22 +656,11 @@ public void addBinary(Binary binary) { String str = binary.toStringUsingUTF8(); parent.add(str); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoTimestampConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoTimestampConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; final LogicalTypeAnnotation.TimestampLogicalTypeAnnotation logicalTypeAnnotation; public ProtoTimestampConverter( @@ -787,21 +684,11 @@ public void addLong(long value) { break; } } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoDateConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoDateConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoDateConverter(ParentValueContainer parent) { this.parent = parent; @@ -817,21 +704,11 @@ public void addInt(int value) { .build(); parent.add(date); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoTimeConverter extends PrimitiveConverter implements ModifiableParentValueContainerHolder { + static final class ProtoTimeConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; final LogicalTypeAnnotation.TimeLogicalTypeAnnotation logicalTypeAnnotation; public ProtoTimeConverter( @@ -869,22 +746,11 @@ public void addLong(long value) { .build(); parent.add(timeOfDay); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoDoubleValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoDoubleValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoDoubleValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -894,22 +760,11 @@ public ProtoDoubleValueConverter(ParentValueContainer parent) { public void addDouble(double value) { parent.add(DoubleValue.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoFloatValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoFloatValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoFloatValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -919,22 +774,11 @@ public ProtoFloatValueConverter(ParentValueContainer parent) { public void addFloat(float value) { parent.add(FloatValue.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoInt64ValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoInt64ValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoInt64ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -944,22 +788,11 @@ public ProtoInt64ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(Int64Value.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoUInt64ValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoUInt64ValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoUInt64ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -969,22 +802,11 @@ public ProtoUInt64ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(UInt64Value.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoInt32ValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoInt32ValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoInt32ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -994,22 +816,11 @@ public ProtoInt32ValueConverter(ParentValueContainer parent) { public void addInt(int value) { parent.add(Int32Value.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoUInt32ValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoUInt32ValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoUInt32ValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -1019,22 +830,11 @@ public ProtoUInt32ValueConverter(ParentValueContainer parent) { public void addLong(long value) { parent.add(UInt32Value.of(Math.toIntExact(value))); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoBoolValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoBoolValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoBoolValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -1044,22 +844,11 @@ public ProtoBoolValueConverter(ParentValueContainer parent) { public void addBoolean(boolean value) { parent.add(BoolValue.of(value)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoStringValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoStringValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoStringValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -1070,22 +859,11 @@ public void addBinary(Binary binary) { String str = binary.toStringUsingUTF8(); parent.add(StringValue.of(str)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } - static final class ProtoBytesValueConverter extends PrimitiveConverter - implements ModifiableParentValueContainerHolder { + static final class ProtoBytesValueConverter extends PrimitiveConverter { - ParentValueContainer parent; + final ParentValueContainer parent; public ProtoBytesValueConverter(ParentValueContainer parent) { this.parent = parent; @@ -1096,16 +874,6 @@ public void addBinary(Binary binary) { ByteString byteString = ByteString.copyFrom(binary.toByteBuffer()); parent.add(BytesValue.of(byteString)); } - - @Override - public ParentValueContainer getParentValueContainer() { - return parent; - } - - @Override - public void setParentValueContainer(ParentValueContainer parentValueContainer) { - parent = parentValueContainer; - } } /** @@ -1130,29 +898,16 @@ public void setParentValueContainer(ParentValueContainer parentValueContainer) { * a repeated group named 'list', itself containing only one field called 'element' of the type of the repeated * object (can be a primitive as in this example or a group in case of a repeated message in protobuf). */ - final class ListConverter extends ModifiableGroupConverter { - private Converter converter; + final class ListConverter extends GroupConverter { + final ListWrapperConverter converter; - final class ListWrapperConverter extends ModifiableGroupConverter { - private Converter converter; + final class ListWrapperConverter extends GroupConverter { + final Converter converter; ListWrapperConverter(Converter converter) { this.converter = converter; } - @Override - int getFieldCount() { - return 1; - } - - @Override - void setFieldConverter(int fieldIndex, Converter converter) { - if (fieldIndex > 0) { - throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); - } - this.converter = converter; - } - @Override public Converter getConverter(int fieldIndex) { return converter; @@ -1204,23 +959,10 @@ public void start() {} @Override public void end() {} - - @Override - int getFieldCount() { - return 1; - } - - @Override - void setFieldConverter(int fieldIndex, Converter converter) { - if (fieldIndex > 0) { - throw new ParquetDecodingException("Unexpected multiple fields in the LIST wrapper"); - } - this.converter = converter; - } } - final class MapConverter extends ModifiableGroupConverter { - private Converter converter; + final class MapConverter extends GroupConverter { + final Converter converter; public MapConverter( Message.Builder parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Type parquetType) { @@ -1253,18 +995,5 @@ public void start() {} @Override public void end() {} - - @Override - int getFieldCount() { - return 1; - } - - @Override - void setFieldConverter(int fieldIndex, Converter converter) { - if (fieldIndex > 0) { - throw new ParquetDecodingException("Unexpected multiple fields in the MAP wrapper"); - } - MapConverter.this.converter = converter; - } } } diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java index cb6f89b056..c520c7bb69 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java @@ -58,16 +58,4 @@ public T getCurrentRecord() { public GroupConverter getRootConverter() { return root; } - - interface ModifiableParentValueContainerHolder { - ProtoMessageConverter.ParentValueContainer getParentValueContainer(); - - void setParentValueContainer(ProtoMessageConverter.ParentValueContainer parentValueContainer); - } - - abstract static class ModifiableGroupConverter extends GroupConverter { - abstract int getFieldCount(); - - abstract void setFieldConverter(int fieldIndex, Converter converter); - } } From 263cf6bf94a2138390855eab239d150b8eeb83b9 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Sat, 3 May 2025 22:47:15 +0100 Subject: [PATCH 13/15] clone converter structure - before implementing generate pvc --- .../benchmarks/ProtoDataGenerator.java | 117 +++++++-- .../benchmarks/ProtoReadBenchmarks.java | 9 +- .../parquet/proto/ByteBuddyCodeGen.java | 243 ++++++++++-------- .../parquet/proto/ProtoMessageConverter.java | 27 +- .../proto/ProtoRecordMaterializer.java | 1 - 5 files changed, 259 insertions(+), 138 deletions(-) diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java index c06a07fd8d..0a7c402e5b 100644 --- a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoDataGenerator.java @@ -75,8 +75,7 @@ private interface RecordGenerator extends IntFunction private static final RecordGeneratorFactory TEST_30_INT32 = fixedLenByteArraySize -> { final Test30Int32.Builder builder = Test30Int32.newBuilder(); - return i -> builder - .setField1(i) + return i -> builder.setField1(i) .setField2(i) .setField3(i) .setField4(i) @@ -111,24 +110,112 @@ private interface RecordGenerator extends IntFunction private static final RecordGeneratorFactory TEST_100_INT32 = fixedLenByteArraySize -> { final Test100Int32.Builder builder = Test100Int32.newBuilder(); - return i -> builder - .setF1(i).setF2(i).setF3(i).setF4(i).setF5(i).setF6(i).setF7(i).setF8(i).setF9(i).setF10(i) - .setF11(i).setF12(i).setF13(i).setF14(i).setF15(i).setF16(i).setF17(i).setF18(i).setF19(i).setF20(i) - .setF21(i).setF22(i).setF23(i).setF24(i).setF25(i).setF26(i).setF27(i).setF28(i).setF29(i).setF30(i) - .setF31(i).setF32(i).setF33(i).setF34(i).setF35(i).setF36(i).setF37(i).setF38(i).setF39(i).setF40(i) - .setF41(i).setF42(i).setF43(i).setF44(i).setF45(i).setF46(i).setF47(i).setF48(i).setF49(i).setF50(i) - .setF51(i).setF52(i).setF53(i).setF54(i).setF55(i).setF56(i).setF57(i).setF58(i).setF59(i).setF60(i) - .setF61(i).setF62(i).setF63(i).setF64(i).setF65(i).setF66(i).setF67(i).setF68(i).setF69(i).setF70(i) - .setF71(i).setF72(i).setF73(i).setF74(i).setF75(i).setF76(i).setF77(i).setF78(i).setF79(i).setF80(i) - .setF81(i).setF82(i).setF83(i).setF84(i).setF85(i).setF86(i).setF87(i).setF88(i).setF89(i).setF90(i) - .setF91(i).setF92(i).setF93(i).setF94(i).setF95(i).setF96(i).setF97(i).setF98(i).setF99(i).setF100(i); + return i -> builder.setF1(i) + .setF2(i) + .setF3(i) + .setF4(i) + .setF5(i) + .setF6(i) + .setF7(i) + .setF8(i) + .setF9(i) + .setF10(i) + .setF11(i) + .setF12(i) + .setF13(i) + .setF14(i) + .setF15(i) + .setF16(i) + .setF17(i) + .setF18(i) + .setF19(i) + .setF20(i) + .setF21(i) + .setF22(i) + .setF23(i) + .setF24(i) + .setF25(i) + .setF26(i) + .setF27(i) + .setF28(i) + .setF29(i) + .setF30(i) + .setF31(i) + .setF32(i) + .setF33(i) + .setF34(i) + .setF35(i) + .setF36(i) + .setF37(i) + .setF38(i) + .setF39(i) + .setF40(i) + .setF41(i) + .setF42(i) + .setF43(i) + .setF44(i) + .setF45(i) + .setF46(i) + .setF47(i) + .setF48(i) + .setF49(i) + .setF50(i) + .setF51(i) + .setF52(i) + .setF53(i) + .setF54(i) + .setF55(i) + .setF56(i) + .setF57(i) + .setF58(i) + .setF59(i) + .setF60(i) + .setF61(i) + .setF62(i) + .setF63(i) + .setF64(i) + .setF65(i) + .setF66(i) + .setF67(i) + .setF68(i) + .setF69(i) + .setF70(i) + .setF71(i) + .setF72(i) + .setF73(i) + .setF74(i) + .setF75(i) + .setF76(i) + .setF77(i) + .setF78(i) + .setF79(i) + .setF80(i) + .setF81(i) + .setF82(i) + .setF83(i) + .setF84(i) + .setF85(i) + .setF86(i) + .setF87(i) + .setF88(i) + .setF89(i) + .setF90(i) + .setF91(i) + .setF92(i) + .setF93(i) + .setF94(i) + .setF95(i) + .setF96(i) + .setF97(i) + .setF98(i) + .setF99(i) + .setF100(i); }; private static final RecordGeneratorFactory TEST_30_STRING = fixedLenByteArraySize -> { final Test30String.Builder builder = Test30String.newBuilder(); - return i -> builder - .setField1("setField1:" + i) + return i -> builder.setField1("setField1:" + i) .setField2("setField2:" + i) .setField3("setField3:" + i) .setField4("setField4:" + i) diff --git a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java index 17ea033fd3..7143bfe130 100644 --- a/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java +++ b/parquet-benchmarks/src/main/java/org/apache/parquet/benchmarks/ProtoReadBenchmarks.java @@ -19,7 +19,11 @@ package org.apache.parquet.benchmarks; +import static org.apache.parquet.benchmarks.BenchmarkFiles.configuration; +import static org.openjdk.jmh.annotations.Scope.Thread; + import com.google.protobuf.Message; +import java.io.IOException; import org.apache.hadoop.fs.Path; import org.apache.parquet.hadoop.ParquetReader; import org.apache.parquet.proto.ProtoParquetReader; @@ -29,12 +33,7 @@ import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.infra.Blackhole; -import java.io.IOException; - -import static org.apache.parquet.benchmarks.BenchmarkFiles.configuration; -import static org.openjdk.jmh.annotations.Scope.Thread; @State(Thread) public class ProtoReadBenchmarks extends ReadBenchmarks { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 2da5f618e8..34d4421bd5 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -29,7 +29,6 @@ import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; -import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.StringValue; @@ -109,27 +108,27 @@ import org.apache.parquet.proto.ProtoMessageConverter.ListConverter; import org.apache.parquet.proto.ProtoMessageConverter.MapConverter; import org.apache.parquet.proto.ProtoMessageConverter.ParentValueContainer; -import org.apache.parquet.proto.ProtoMessageConverter.SetFieldParentValueContainer; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoEnumConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoBinaryConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBoolValueConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoBooleanConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoBytesValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoDateConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoDoubleConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoDoubleValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoEnumConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoFloatConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoFloatValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt32ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt64ValueConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoIntConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoLongConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoStringConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoTimestampConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoDateConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoStringValueConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoTimeConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoDoubleValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoFloatValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt64ValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoUInt64ValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoInt32ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoTimestampConverter; import org.apache.parquet.proto.ProtoMessageConverter.ProtoUInt32ValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoBoolValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoStringValueConverter; -import org.apache.parquet.proto.ProtoMessageConverter.ProtoBytesValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.ProtoUInt64ValueConverter; +import org.apache.parquet.proto.ProtoMessageConverter.SetFieldParentValueContainer; import org.apache.parquet.schema.MessageType; public class ByteBuddyCodeGen { @@ -2797,8 +2796,8 @@ interface MapEntryBuilder { private ProtoReadSupport.CodegenMode codegenMode; private ParquetConfiguration configuration; - public ProtoRecordMaterializerTransformer(ProtoReadSupport.CodegenMode codegenMode, - ParquetConfiguration configuration) { + public ProtoRecordMaterializerTransformer( + ProtoReadSupport.CodegenMode codegenMode, ParquetConfiguration configuration) { this.codegenMode = codegenMode; this.configuration = configuration; } @@ -2832,7 +2831,8 @@ private Converter transformChildConverter(Object parentBuilder, Converter conver return transformChildConverterProtoBinaryConverter(parentBuilder, (ProtoBinaryConverter) converter); } if (converter instanceof ProtoBooleanConverter) { - return transformChildConverterProtoBooleanConverter(parentBuilder, (ProtoBooleanConverter) converter); + return transformChildConverterProtoBooleanConverter( + parentBuilder, (ProtoBooleanConverter) converter); } if (converter instanceof ProtoDoubleConverter) { return transformChildConverterProtoDoubleConverter(parentBuilder, (ProtoDoubleConverter) converter); @@ -2850,7 +2850,8 @@ private Converter transformChildConverter(Object parentBuilder, Converter conver return transformChildConverterProtoStringConverter(parentBuilder, (ProtoStringConverter) converter); } if (converter instanceof ProtoTimestampConverter) { - return transformChildConverterProtoTimestampConverter(parentBuilder, (ProtoTimestampConverter) converter); + return transformChildConverterProtoTimestampConverter( + parentBuilder, (ProtoTimestampConverter) converter); } if (converter instanceof ProtoDateConverter) { return transformChildConverterProtoDateConverter(parentBuilder, (ProtoDateConverter) converter); @@ -2859,31 +2860,40 @@ private Converter transformChildConverter(Object parentBuilder, Converter conver return transformChildConverterProtoTimeConverter(parentBuilder, (ProtoTimeConverter) converter); } if (converter instanceof ProtoDoubleValueConverter) { - return transformChildConverterProtoDoubleValueConverter(parentBuilder, (ProtoDoubleValueConverter) converter); + return transformChildConverterProtoDoubleValueConverter( + parentBuilder, (ProtoDoubleValueConverter) converter); } if (converter instanceof ProtoFloatValueConverter) { - return transformChildConverterProtoFloatValueConverter(parentBuilder, (ProtoFloatValueConverter) converter); + return transformChildConverterProtoFloatValueConverter( + parentBuilder, (ProtoFloatValueConverter) converter); } if (converter instanceof ProtoInt64ValueConverter) { - return transformChildConverterProtoInt64ValueConverter(parentBuilder, (ProtoInt64ValueConverter) converter); + return transformChildConverterProtoInt64ValueConverter( + parentBuilder, (ProtoInt64ValueConverter) converter); } if (converter instanceof ProtoUInt64ValueConverter) { - return transformChildConverterProtoUInt64ValueConverter(parentBuilder, (ProtoUInt64ValueConverter) converter); + return transformChildConverterProtoUInt64ValueConverter( + parentBuilder, (ProtoUInt64ValueConverter) converter); } if (converter instanceof ProtoInt32ValueConverter) { - return transformChildConverterProtoInt32ValueConverter(parentBuilder, (ProtoInt32ValueConverter) converter); + return transformChildConverterProtoInt32ValueConverter( + parentBuilder, (ProtoInt32ValueConverter) converter); } if (converter instanceof ProtoUInt32ValueConverter) { - return transformChildConverterProtoUInt32ValueConverter(parentBuilder, (ProtoUInt32ValueConverter) converter); + return transformChildConverterProtoUInt32ValueConverter( + parentBuilder, (ProtoUInt32ValueConverter) converter); } if (converter instanceof ProtoBoolValueConverter) { - return transformChildConverterProtoBoolValueConverter(parentBuilder, (ProtoBoolValueConverter) converter); + return transformChildConverterProtoBoolValueConverter( + parentBuilder, (ProtoBoolValueConverter) converter); } if (converter instanceof ProtoStringValueConverter) { - return transformChildConverterProtoStringValueConverter(parentBuilder, (ProtoStringValueConverter) converter); + return transformChildConverterProtoStringValueConverter( + parentBuilder, (ProtoStringValueConverter) converter); } if (converter instanceof ProtoBytesValueConverter) { - return transformChildConverterProtoBytesValueConverter(parentBuilder, (ProtoBytesValueConverter) converter); + return transformChildConverterProtoBytesValueConverter( + parentBuilder, (ProtoBytesValueConverter) converter); } if (converter instanceof MapConverter) { return transformChildConverterMapConverter(parentBuilder, (MapConverter) converter); @@ -2894,24 +2904,31 @@ private Converter transformChildConverter(Object parentBuilder, Converter conver return converter; } - private Converter transformChildConverterProtoEnumConverter(Object parentBuilder, ProtoEnumConverter converter) { - return converter; + private Converter transformChildConverterProtoEnumConverter( + Object parentBuilder, ProtoEnumConverter converter) { + Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); + + if (fieldDescriptor == null) { + return converter; + } + + return new ProtoEnumConverter(generatePvc(parentBuilder, fieldDescriptor), converter); } private Converter transformChildConverterListConverter(Object parentBuilder, ListConverter converter) { + Converter listConverter = transformChildConverter(parentBuilder, converter.converter.converter); + GroupConverter wrapperConverter = new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { - return transformChildConverter(parentBuilder, converter.converter.converter); + return listConverter; } @Override - public void start() { - } + public void start() {} @Override - public void end() { - } + public void end() {} }; return new GroupConverter() { @@ -2921,243 +2938,243 @@ public Converter getConverter(int fieldIndex) { } @Override - public void start() { - } + public void start() {} @Override - public void end() { - } + public void end() {} }; } private Converter transformChildConverterMapConverter(Object parentBuilder, MapConverter converter) { + Converter mapConverter = transformChildConverter(parentBuilder, converter.converter); + return new GroupConverter() { @Override public Converter getConverter(int fieldIndex) { - return transformChildConverter(parentBuilder, converter.converter); + return mapConverter; } @Override - public void start() { - } + public void start() {} @Override - public void end() { - } + public void end() {} }; } - private Converter transformChildConverterProtoBytesValueConverter(Object parentBuilder, - ProtoBytesValueConverter converter) { + private Converter transformChildConverterProtoBytesValueConverter( + Object parentBuilder, ProtoBytesValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoBytesValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoBytesValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoStringValueConverter(Object parentBuilder, - ProtoStringValueConverter converter) { + private Converter transformChildConverterProtoStringValueConverter( + Object parentBuilder, ProtoStringValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoStringValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoStringValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoBoolValueConverter(Object parentBuilder, - ProtoBoolValueConverter converter) { + private Converter transformChildConverterProtoBoolValueConverter( + Object parentBuilder, ProtoBoolValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoBoolValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoBoolValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoUInt32ValueConverter(Object parentBuilder, - ProtoUInt32ValueConverter converter) { + private Converter transformChildConverterProtoUInt32ValueConverter( + Object parentBuilder, ProtoUInt32ValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoUInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoUInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoInt32ValueConverter(Object parentBuilder, - ProtoInt32ValueConverter converter) { + private Converter transformChildConverterProtoInt32ValueConverter( + Object parentBuilder, ProtoInt32ValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoUInt64ValueConverter(Object parentBuilder, - ProtoUInt64ValueConverter converter) { + private Converter transformChildConverterProtoUInt64ValueConverter( + Object parentBuilder, ProtoUInt64ValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoUInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoUInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoInt64ValueConverter(Object parentBuilder, - ProtoInt64ValueConverter converter) { + private Converter transformChildConverterProtoInt64ValueConverter( + Object parentBuilder, ProtoInt64ValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoFloatValueConverter(Object parentBuilder, - ProtoFloatValueConverter converter) { + private Converter transformChildConverterProtoFloatValueConverter( + Object parentBuilder, ProtoFloatValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoFloatValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoFloatValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoDoubleValueConverter(Object parentBuilder, - ProtoDoubleValueConverter converter) { + private Converter transformChildConverterProtoDoubleValueConverter( + Object parentBuilder, ProtoDoubleValueConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoDoubleValueConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoDoubleValueConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoTimeConverter(Object parentBuilder, - ProtoTimeConverter converter) { + private Converter transformChildConverterProtoTimeConverter( + Object parentBuilder, ProtoTimeConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoTimeConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent), converter.logicalTypeAnnotation); + return new ProtoTimeConverter( + generatePvc(parentBuilder, fieldDescriptor), converter.logicalTypeAnnotation); } - private Converter transformChildConverterProtoDateConverter(Object parentBuilder, - ProtoDateConverter converter) { + private Converter transformChildConverterProtoDateConverter( + Object parentBuilder, ProtoDateConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoDateConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoDateConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoTimestampConverter(Object parentBuilder, - ProtoTimestampConverter converter) { + private Converter transformChildConverterProtoTimestampConverter( + Object parentBuilder, ProtoTimestampConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoTimestampConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent), converter.logicalTypeAnnotation); + return new ProtoTimestampConverter( + generatePvc(parentBuilder, fieldDescriptor), converter.logicalTypeAnnotation); } - private Converter transformChildConverterProtoBinaryConverter(Object parentBuilder, - ProtoBinaryConverter converter) { + private Converter transformChildConverterProtoBinaryConverter( + Object parentBuilder, ProtoBinaryConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoBinaryConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoBinaryConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoStringConverter(Object parentBuilder, - ProtoStringConverter converter) { + private Converter transformChildConverterProtoStringConverter( + Object parentBuilder, ProtoStringConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoStringConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoStringConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoLongConverter(Object parentBuilder, - ProtoLongConverter converter) { + private Converter transformChildConverterProtoLongConverter( + Object parentBuilder, ProtoLongConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoLongConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoLongConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoIntConverter(Object parentBuilder, - ProtoIntConverter converter) { + private Converter transformChildConverterProtoIntConverter( + Object parentBuilder, ProtoIntConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoIntConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoIntConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoFloatConverter(Object parentBuilder, - ProtoFloatConverter converter) { + private Converter transformChildConverterProtoFloatConverter( + Object parentBuilder, ProtoFloatConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoFloatConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoFloatConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoDoubleConverter(Object parentBuilder, - ProtoDoubleConverter converter) { + private Converter transformChildConverterProtoDoubleConverter( + Object parentBuilder, ProtoDoubleConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoDoubleConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoDoubleConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoBooleanConverter(Object parentBuilder, - ProtoBooleanConverter converter) { + private Converter transformChildConverterProtoBooleanConverter( + Object parentBuilder, ProtoBooleanConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { return converter; } - return new ProtoBooleanConverter(generatePvc(parentBuilder, fieldDescriptor, converter.parent)); + return new ProtoBooleanConverter(generatePvc(parentBuilder, fieldDescriptor)); } - private Converter transformChildConverterProtoMessageConverter(Object parentBuilder, - ProtoMessageConverter converter) { + private Converter transformChildConverterProtoMessageConverter( + Object parentBuilder, ProtoMessageConverter converter) { Descriptors.FieldDescriptor fieldDescriptor = getFieldDescriptor(converter.parent); if (fieldDescriptor == null) { @@ -3165,18 +3182,17 @@ private Converter transformChildConverterProtoMessageConverter(Object parentBuil } Object myBuilder = fieldDescriptor.isMapField() - // TODO - ? converter.myBuilder.build().toBuilder()//newMapEntryBuilder(parentBuilder, fieldDescriptor) + ? newMapEntryBuilder(parentBuilder, fieldDescriptor) : converter.myBuilder; Converter[] converters = converter.converters; Converter[] newConverters = new Converter[converters.length]; for (int i = 0; i < converters.length; i++) { Converter childConverter = converters[i]; - newConverters[i] = transformChildConverter(converter, childConverter); + newConverters[i] = transformChildConverter(myBuilder, childConverter); } - ParentValueContainer newPvc = generatePvc(parentBuilder, fieldDescriptor, converter.parent); + ParentValueContainer newPvc = generatePvc(parentBuilder, fieldDescriptor); return new PreBuiltProtoMessageConverter(newConverters, newPvc, myBuilder); } @@ -3186,8 +3202,11 @@ private Object newMapEntryBuilder(Object parentBuilder, Descriptors.FieldDescrip return messageBuilder.newBuilderForField(fieldDescriptor); } - private ParentValueContainer generatePvc(Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, - ParentValueContainer fallbackPvc) { + private ParentValueContainer generatePvc( + Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor) { + ParentValueContainer fallbackPvc = fieldDescriptor.isRepeated() + ? new AddRepeatedFieldParentValueContainer((Message.Builder) parentBuilder, fieldDescriptor) + : new SetFieldParentValueContainer((Message.Builder) parentBuilder, fieldDescriptor); return new ParentValueContainer() { @Override public void add(Object val) { @@ -3209,7 +3228,8 @@ private Descriptors.FieldDescriptor getFieldDescriptor(ParentValueContainer pvc) return setFieldPvc.getFieldDescriptor(); } if (pvc instanceof AddRepeatedFieldParentValueContainer) { - AddRepeatedFieldParentValueContainer addRepeatedFieldPvc = (AddRepeatedFieldParentValueContainer) pvc; + AddRepeatedFieldParentValueContainer addRepeatedFieldPvc = + (AddRepeatedFieldParentValueContainer) pvc; return addRepeatedFieldPvc.getFieldDescriptor(); } return null; @@ -3220,8 +3240,8 @@ private static class PreBuiltProtoMessageConverter extends GroupConverter { protected final ParentValueContainer parent; protected final Object myBuilder; - public PreBuiltProtoMessageConverter(Converter[] converters, ParentValueContainer parent, - Object myBuilder) { + public PreBuiltProtoMessageConverter( + Converter[] converters, ParentValueContainer parent, Object myBuilder) { this.converters = converters; this.parent = parent; this.myBuilder = myBuilder; @@ -3233,8 +3253,7 @@ public Converter getConverter(int fieldIndex) { } @Override - public void start() { - } + public void start() {} @Override public void end() { diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 230049a5c4..55c3e3ba65 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -298,7 +298,8 @@ protected Converter newScalarConverter( case BYTE_STRING: return new ProtoBinaryConverter(pvc); case ENUM: - return new ProtoEnumConverter(pvc, fieldDescriptor); + return new ProtoEnumConverter( + pvc, fieldDescriptor, extraMetadata, conf.getBoolean(CONFIG_ACCEPT_UNKNOWN_ENUM, false)); case INT: return new ProtoIntConverter(pvc); case LONG: @@ -432,7 +433,7 @@ public Descriptors.FieldDescriptor getFieldDescriptor() { } } - final class ProtoEnumConverter extends PrimitiveConverter { + static final class ProtoEnumConverter extends PrimitiveConverter { private final Descriptors.FieldDescriptor fieldType; private final Map enumLookup; @@ -441,14 +442,30 @@ final class ProtoEnumConverter extends PrimitiveConverter { private final Descriptors.EnumDescriptor enumType; private final String unknownEnumPrefix; private final boolean acceptUnknownEnum; + private final Map extraMetadata; - public ProtoEnumConverter(ParentValueContainer parent, Descriptors.FieldDescriptor fieldType) { + public ProtoEnumConverter( + ParentValueContainer parent, + Descriptors.FieldDescriptor fieldType, + Map extraMetadata, + boolean acceptUnknownEnum) { this.parent = parent; + this.extraMetadata = extraMetadata; this.fieldType = fieldType; this.enumType = fieldType.getEnumType(); this.enumLookup = makeLookupStructure(enumType); - unknownEnumPrefix = "UNKNOWN_ENUM_VALUE_" + enumType.getName() + "_"; - acceptUnknownEnum = conf.getBoolean(CONFIG_ACCEPT_UNKNOWN_ENUM, false); + this.unknownEnumPrefix = "UNKNOWN_ENUM_VALUE_" + enumType.getName() + "_"; + this.acceptUnknownEnum = acceptUnknownEnum; + } + + public ProtoEnumConverter(ParentValueContainer parent, ProtoEnumConverter from) { + this.parent = parent; + this.extraMetadata = from.extraMetadata; + this.fieldType = from.fieldType; + this.enumType = from.enumType; + this.enumLookup = from.enumLookup; + this.unknownEnumPrefix = from.unknownEnumPrefix; + this.acceptUnknownEnum = from.acceptUnknownEnum; } /** diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java index c520c7bb69..12386a4669 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoRecordMaterializer.java @@ -24,7 +24,6 @@ import org.apache.hadoop.conf.Configuration; import org.apache.parquet.conf.HadoopParquetConfiguration; import org.apache.parquet.conf.ParquetConfiguration; -import org.apache.parquet.io.api.Converter; import org.apache.parquet.io.api.GroupConverter; import org.apache.parquet.io.api.RecordMaterializer; import org.apache.parquet.schema.MessageType; From 13eb5e849d4f6abbc5af5c92987348c4c1af5544 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Tue, 6 May 2025 22:21:36 +0100 Subject: [PATCH 14/15] working codegen except for Map --- .../parquet/proto/ByteBuddyCodeGen.java | 323 ++++++++++++++++-- .../parquet/proto/ProtoMessageConverter.java | 10 +- 2 files changed, 303 insertions(+), 30 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 34d4421bd5..31f4b68d60 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -29,6 +29,7 @@ import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; +import com.google.protobuf.MapEntry; import com.google.protobuf.Message; import com.google.protobuf.MessageOrBuilder; import com.google.protobuf.StringValue; @@ -63,10 +64,12 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Stream; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.method.MethodList; import net.bytebuddy.description.modifier.Visibility; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.dynamic.DynamicType; @@ -93,6 +96,7 @@ import net.bytebuddy.jar.asm.Label; import net.bytebuddy.jar.asm.MethodVisitor; import net.bytebuddy.jar.asm.Opcodes; +import net.bytebuddy.matcher.ElementMatcher; import net.bytebuddy.matcher.ElementMatchers; import net.bytebuddy.utility.JavaConstant; import org.apache.parquet.conf.ParquetConfiguration; @@ -2912,7 +2916,11 @@ private Converter transformChildConverterProtoEnumConverter( return converter; } - return new ProtoEnumConverter(generatePvc(parentBuilder, fieldDescriptor), converter); + return new ProtoEnumConverter(generatePvc( + parentBuilder, + fieldDescriptor, + Descriptors.EnumValueDescriptor.class + ), converter); } private Converter transformChildConverterListConverter(Object parentBuilder, ListConverter converter) { @@ -2970,7 +2978,13 @@ private Converter transformChildConverterProtoBytesValueConverter( return converter; } - return new ProtoBytesValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoBytesValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + BytesValue.class + ) + ); } private Converter transformChildConverterProtoStringValueConverter( @@ -2981,7 +2995,13 @@ private Converter transformChildConverterProtoStringValueConverter( return converter; } - return new ProtoStringValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoStringValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + StringValue.class + ) + ); } private Converter transformChildConverterProtoBoolValueConverter( @@ -2992,7 +3012,13 @@ private Converter transformChildConverterProtoBoolValueConverter( return converter; } - return new ProtoBoolValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoBoolValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + BoolValue.class + ) + ); } private Converter transformChildConverterProtoUInt32ValueConverter( @@ -3003,7 +3029,13 @@ private Converter transformChildConverterProtoUInt32ValueConverter( return converter; } - return new ProtoUInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoUInt32ValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + UInt32Value.class + ) + ); } private Converter transformChildConverterProtoInt32ValueConverter( @@ -3014,7 +3046,13 @@ private Converter transformChildConverterProtoInt32ValueConverter( return converter; } - return new ProtoInt32ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoInt32ValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + Int32Value.class + ) + ); } private Converter transformChildConverterProtoUInt64ValueConverter( @@ -3025,7 +3063,13 @@ private Converter transformChildConverterProtoUInt64ValueConverter( return converter; } - return new ProtoUInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoUInt64ValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + UInt64Value.class + ) + ); } private Converter transformChildConverterProtoInt64ValueConverter( @@ -3036,7 +3080,13 @@ private Converter transformChildConverterProtoInt64ValueConverter( return converter; } - return new ProtoInt64ValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoInt64ValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + Int64Value.class + ) + ); } private Converter transformChildConverterProtoFloatValueConverter( @@ -3047,7 +3097,13 @@ private Converter transformChildConverterProtoFloatValueConverter( return converter; } - return new ProtoFloatValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoFloatValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + FloatValue.class + ) + ); } private Converter transformChildConverterProtoDoubleValueConverter( @@ -3058,7 +3114,13 @@ private Converter transformChildConverterProtoDoubleValueConverter( return converter; } - return new ProtoDoubleValueConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoDoubleValueConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + DoubleValue.class + ) + ); } private Converter transformChildConverterProtoTimeConverter( @@ -3070,7 +3132,12 @@ private Converter transformChildConverterProtoTimeConverter( } return new ProtoTimeConverter( - generatePvc(parentBuilder, fieldDescriptor), converter.logicalTypeAnnotation); + generatePvc( + parentBuilder, + fieldDescriptor, + TimeOfDay.class + ), converter.logicalTypeAnnotation + ); } private Converter transformChildConverterProtoDateConverter( @@ -3081,7 +3148,13 @@ private Converter transformChildConverterProtoDateConverter( return converter; } - return new ProtoDateConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoDateConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + Date.class + ) + ); } private Converter transformChildConverterProtoTimestampConverter( @@ -3093,7 +3166,12 @@ private Converter transformChildConverterProtoTimestampConverter( } return new ProtoTimestampConverter( - generatePvc(parentBuilder, fieldDescriptor), converter.logicalTypeAnnotation); + generatePvc( + parentBuilder, + fieldDescriptor, + Timestamp.class + ), converter.logicalTypeAnnotation + ); } private Converter transformChildConverterProtoBinaryConverter( @@ -3104,7 +3182,13 @@ private Converter transformChildConverterProtoBinaryConverter( return converter; } - return new ProtoBinaryConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoBinaryConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + ByteString.class + ) + ); } private Converter transformChildConverterProtoStringConverter( @@ -3115,7 +3199,13 @@ private Converter transformChildConverterProtoStringConverter( return converter; } - return new ProtoStringConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoStringConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + String.class + ) + ); } private Converter transformChildConverterProtoLongConverter( @@ -3126,7 +3216,13 @@ private Converter transformChildConverterProtoLongConverter( return converter; } - return new ProtoLongConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoLongConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + long.class + ) + ); } private Converter transformChildConverterProtoIntConverter( @@ -3137,7 +3233,13 @@ private Converter transformChildConverterProtoIntConverter( return converter; } - return new ProtoIntConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoIntConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + int.class + ) + ); } private Converter transformChildConverterProtoFloatConverter( @@ -3148,7 +3250,13 @@ private Converter transformChildConverterProtoFloatConverter( return converter; } - return new ProtoFloatConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoFloatConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + float.class + ) + ); } private Converter transformChildConverterProtoDoubleConverter( @@ -3159,7 +3267,13 @@ private Converter transformChildConverterProtoDoubleConverter( return converter; } - return new ProtoDoubleConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoDoubleConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + double.class + ) + ); } private Converter transformChildConverterProtoBooleanConverter( @@ -3170,7 +3284,13 @@ private Converter transformChildConverterProtoBooleanConverter( return converter; } - return new ProtoBooleanConverter(generatePvc(parentBuilder, fieldDescriptor)); + return new ProtoBooleanConverter( + generatePvc( + parentBuilder, + fieldDescriptor, + boolean.class + ) + ); } private Converter transformChildConverterProtoMessageConverter( @@ -3192,7 +3312,7 @@ private Converter transformChildConverterProtoMessageConverter( newConverters[i] = transformChildConverter(myBuilder, childConverter); } - ParentValueContainer newPvc = generatePvc(parentBuilder, fieldDescriptor); + ParentValueContainer newPvc = generatePvc(parentBuilder, fieldDescriptor, myBuilder.getClass()); return new PreBuiltProtoMessageConverter(newConverters, newPvc, myBuilder); } @@ -3203,14 +3323,167 @@ private Object newMapEntryBuilder(Object parentBuilder, Descriptors.FieldDescrip } private ParentValueContainer generatePvc( - Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor) { + Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Class valueType) { + if (!fieldDescriptor.isMapField() && !(parentBuilder instanceof MapEntry.Builder)) { + return getRegularFieldPvc(parentBuilder, fieldDescriptor, valueType); + } + return getDefaultPvc((Message.Builder) parentBuilder, fieldDescriptor, valueType); + } + + private ParentValueContainer getRegularFieldPvc(Object parentBuilder, + Descriptors.FieldDescriptor fieldDescriptor, + Class valueType) { + return new Supplier() { + private DynamicType.Builder classBuilder; + + @Override + public ParentValueContainer get() { + Class parentBuilderClass = parentBuilder.getClass(); + + String fieldNameForMethod = ReflectionUtil.getFieldNameForMethod(fieldDescriptor); + TypeDescription parentBuilderTypeDef = TypeDescription.ForLoadedType.of(parentBuilderClass); + MethodList parentBuilderMethods = parentBuilderTypeDef.getDeclaredMethods(); + String setterPrefix = fieldDescriptor.isRepeated() ? "add" : "set"; + boolean isEnum = fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.ENUM; + boolean supportUnknownValues = isEnum + && !fieldDescriptor.getEnumType().isClosed() + && !fieldDescriptor.legacyEnumFieldTreatedAsClosed(); + String setterSuffix = supportUnknownValues ? "Value" : ""; + TypeDescription enumType = isEnum ? + parentBuilderMethods + .filter(ElementMatchers.named(setterPrefix + fieldNameForMethod + setterSuffix)) + .getOnly() + .asTypeToken() + .getParameterTypes() + .get(0) + : null; + + boolean isMessageBuilder = Message.Builder.class.isAssignableFrom(valueType); + boolean isMessage = Message.class.isAssignableFrom(valueType); + ElementMatcher setterArgumentMatcher = + isMessageBuilder || isMessage + ? ElementMatchers.takesArguments(valueType) + : ElementMatchers.any(); + + MethodDescription.InDefinedShape parentBuilderSetter = + parentBuilderMethods.filter( + ElementMatchers.named(setterPrefix + fieldNameForMethod + setterSuffix).and( + setterArgumentMatcher + ) + ).getOnly(); + + classBuilder = new ByteBuddy() + .subclass(ParentValueContainer.class) + .modifiers(Visibility.PUBLIC) + .name(ParentValueContainer.class.getName() + "$Generated$" + + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); + + TypeDescription.Generic parentBuilderType = + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parentBuilderClass); + FieldDescription.Latent parentBuilderFieldDesc = new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "parent", Modifier.PRIVATE | Modifier.FINAL, parentBuilderType)); + classBuilder = classBuilder.define(parentBuilderFieldDesc); + + classBuilder = classBuilder + .define(new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + MethodDescription.CONSTRUCTOR_INTERNAL_NAME, + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(parentBuilderType)))) + .intercept(MethodCall.invoke(ReflectionUtil.getConstructor( + ParentValueContainer.class)) + .andThen(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar parentVar = localVars.register(parentBuilderClass)) { + add( + MethodVariableAccess.loadThis(), + parentVar.load(), + FieldAccess.forField(parentBuilderFieldDesc) + .write()); + } + } + add(Codegen.returnVoid()); + } + })); + + String pvcMethodNameSuffix = valueType.isPrimitive() ? + valueType.getName().substring(0, 1).toUpperCase() + valueType.getName().substring(1) : + ""; + + classBuilder = classBuilder + .method(ElementMatchers.named("add" + pvcMethodNameSuffix)) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar valueVar = localVars.register(valueType.isPrimitive() ? valueType : Object.class)) { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(parentBuilderFieldDesc) + .read(), + valueVar.load()); + if (isEnum) { + add(TypeCasting.to( + TypeDescription.ForLoadedType.of(valueType))); + if (supportUnknownValues) { + add(Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(valueType, "getNumber"))); + } else { + MethodDescription.InDefinedShape valueOfMethod = enumType.getDeclaredMethods() + .filter(ElementMatchers + .hasMethodName("valueOf") + .and(ElementMatchers.takesArguments(Descriptors.EnumValueDescriptor.class))) + .getOnly(); + add(MethodInvocation.invoke(valueOfMethod)); + } + } else if (!valueType.isPrimitive()) { + add(TypeCasting.to( + TypeDescription.ForLoadedType.of(valueType))); + } + add(MethodInvocation.invoke(parentBuilderSetter)); + if (isMessageBuilder) { + add(valueVar.load(), + TypeCasting.to( + TypeDescription.ForLoadedType.of(valueType)), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(valueType, "clear")) + ); + } + add(Codegen.returnVoid()); + } + } + } + }); + + DynamicType.Unloaded unloaded = classBuilder.make(); + Class pvcClass = unloaded.load( + this.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + .getLoaded(); + return ReflectionUtil.newInstance( + ReflectionUtil.getConstructor(pvcClass, parentBuilderClass), parentBuilder); + } + }.get(); + } + + private static ParentValueContainer getDefaultPvc(Message.Builder parentBuilder, + Descriptors.FieldDescriptor fieldDescriptor, + Class valueType) { ParentValueContainer fallbackPvc = fieldDescriptor.isRepeated() - ? new AddRepeatedFieldParentValueContainer((Message.Builder) parentBuilder, fieldDescriptor) - : new SetFieldParentValueContainer((Message.Builder) parentBuilder, fieldDescriptor); + ? new AddRepeatedFieldParentValueContainer(parentBuilder, fieldDescriptor) + : new SetFieldParentValueContainer(parentBuilder, fieldDescriptor); + + boolean isBuilder = Message.Builder.class.isAssignableFrom(valueType); + return new ParentValueContainer() { @Override public void add(Object val) { - if (val instanceof Message.Builder) { + if (isBuilder) { Message.Builder builder = (Message.Builder) val; Message message = builder.build(); fallbackPvc.add(message); diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java index 55c3e3ba65..fd8bece4be 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ProtoMessageConverter.java @@ -600,7 +600,7 @@ public ProtoBooleanConverter(ParentValueContainer parent) { @Override public void addBoolean(boolean value) { - parent.add(value); + parent.addBoolean(value); } } @@ -614,7 +614,7 @@ public ProtoDoubleConverter(ParentValueContainer parent) { @Override public void addDouble(double value) { - parent.add(value); + parent.addDouble(value); } } @@ -628,7 +628,7 @@ public ProtoFloatConverter(ParentValueContainer parent) { @Override public void addFloat(float value) { - parent.add(value); + parent.addFloat(value); } } @@ -642,7 +642,7 @@ public ProtoIntConverter(ParentValueContainer parent) { @Override public void addInt(int value) { - parent.add(value); + parent.addInt(value); } } @@ -656,7 +656,7 @@ public ProtoLongConverter(ParentValueContainer parent) { @Override public void addLong(long value) { - parent.add(value); + parent.addLong(value); } } From 1c3c15d8fa8d861deb906a32f055743a9ffcd624 Mon Sep 17 00:00:00 2001 From: Igor Kamyshnikov Date: Wed, 14 May 2025 00:40:48 +0100 Subject: [PATCH 15/15] custom MapBuilders need refactoring - a chain of ClassLoaders when Map is involved --- .../parquet/proto/ByteBuddyCodeGen.java | 435 +++++++++++++++++- 1 file changed, 431 insertions(+), 4 deletions(-) diff --git a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java index 31f4b68d60..165bd2e7cf 100644 --- a/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java +++ b/parquet-protobuf/src/main/java/org/apache/parquet/proto/ByteBuddyCodeGen.java @@ -65,6 +65,7 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.Stream; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; @@ -85,8 +86,11 @@ import net.bytebuddy.implementation.bytecode.assign.TypeCasting; import net.bytebuddy.implementation.bytecode.collection.ArrayAccess; import net.bytebuddy.implementation.bytecode.collection.ArrayFactory; +import net.bytebuddy.implementation.bytecode.constant.DoubleConstant; +import net.bytebuddy.implementation.bytecode.constant.FloatConstant; import net.bytebuddy.implementation.bytecode.constant.IntegerConstant; import net.bytebuddy.implementation.bytecode.constant.JavaConstantValue; +import net.bytebuddy.implementation.bytecode.constant.LongConstant; import net.bytebuddy.implementation.bytecode.constant.TextConstant; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; @@ -3318,8 +3322,319 @@ private Converter transformChildConverterProtoMessageConverter( } private Object newMapEntryBuilder(Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor) { - Message.Builder messageBuilder = (Message.Builder) parentBuilder; - return messageBuilder.newBuilderForField(fieldDescriptor); + return new Supplier() { + private DynamicType.Builder classBuilder; + + @Override + public Object get() { + List mapFields = fieldDescriptor.getMessageType().getFields(); + Descriptors.FieldDescriptor keyField = mapFields.get(0); + Descriptors.FieldDescriptor valueField = mapFields.get(1); + Class keyType = getMapEntryKeyType(parentBuilder.getClass(), keyField); + Class valueType = getMapEntryValueType(parentBuilder.getClass(), fieldDescriptor, valueField); + String setValueMethodName = valueField.getJavaType() == Descriptors.FieldDescriptor.JavaType.ENUM + && int.class.equals(valueType) + ? "setValueValue" : "setValue"; + + classBuilder = new ByteBuddy() + .subclass(Object.class) + .modifiers(Visibility.PUBLIC) + .name(ByteBuddyCodeGen.class.getName() + "$MapBuilder$Generated$" + + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); + + MethodDescription.Latent clearMethodDesc = new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + "clear", + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.emptyList())); + + TypeDescription.Generic keyTypeGen = TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(keyType); + TypeDescription.Generic valueTypeGen = TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(valueType); + + FieldDescription.Latent keyFieldDesc = new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "key", + Visibility.PRIVATE.getMask(), + keyTypeGen, + Collections.emptyList())); + + FieldDescription.Latent valueFieldDesc = new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "value", + Visibility.PRIVATE.getMask(), + valueTypeGen, + Collections.emptyList())); + + MethodDescription.Latent getKeyMethodDesc = new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + "getKey", + Visibility.PUBLIC.getMask(), + keyTypeGen, + Collections.emptyList())); + + MethodDescription.Latent getValueMethodDesc = new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + "getValue", + Visibility.PUBLIC.getMask(), + valueTypeGen, + Collections.emptyList())); + + MethodDescription.Latent setKeyMethodDesc = new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + "setKey", + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(keyTypeGen))); + + Class valueBuilderType; + if (valueField.getJavaType() == Descriptors.FieldDescriptor.JavaType.MESSAGE) { + try { + valueBuilderType = valueType.getDeclaredMethod("newBuilder").invoke(null).getClass(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + valueBuilderType = null; + } + + MethodDescription.Latent setValueMethodDesc = new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + setValueMethodName, + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(valueBuilderType != null ? valueBuilderType : valueType)))); + + + classBuilder = classBuilder + .constructor(ElementMatchers.any()) + .intercept(MethodCall.invoke(ReflectionUtil.getConstructor( + Object.class)) + .andThen(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + add( + MethodVariableAccess.loadThis(), + MethodInvocation.invoke(clearMethodDesc)); + } + add(Codegen.returnVoid()); + } + })); + + classBuilder = classBuilder.define(keyFieldDesc); + classBuilder = classBuilder.define(valueFieldDesc); + classBuilder = classBuilder.define(getKeyMethodDesc) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(keyFieldDesc).read()); + } + add(MethodReturn.of(keyTypeGen)); + } + }); + classBuilder = classBuilder.define(setKeyMethodDesc) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar v = localVars.register(TypeDescription.ForLoadedType.of(keyType))) { + add( + MethodVariableAccess.loadThis(), + v.load(), + FieldAccess.forField(keyFieldDesc).write()); + } + } + add(Codegen.returnVoid()); + } + }); + classBuilder = classBuilder.define(getValueMethodDesc) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(valueFieldDesc).read()); + } + add(MethodReturn.of(valueTypeGen)); + } + }); + classBuilder = classBuilder.define(setValueMethodDesc) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar v = localVars.register(TypeDescription.ForLoadedType.of(valueBuilderType != null ? valueBuilderType : valueType))) { + add( + MethodVariableAccess.loadThis(), + v.load()); + if (valueBuilderType != null) { + add(Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(valueBuilderType, "build"))); + } + add(FieldAccess.forField(valueFieldDesc).write()); + } + } + add(Codegen.returnVoid()); + } + }); + classBuilder = classBuilder.define(clearMethodDesc) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + add(MethodVariableAccess.loadThis()); + switch (keyField.getJavaType()) { + case INT: + add(IntegerConstant.forValue(0)); + break; + case LONG: + add(LongConstant.forValue(0L)); + break; + case FLOAT: + add(FloatConstant.forValue(0.0f)); + break; + case DOUBLE: + add(DoubleConstant.forValue(0.0)); + break; + case BOOLEAN: + add(IntegerConstant.forValue(false)); + break; + case STRING: + add(new TextConstant("")); + break; + default: + throw new IllegalStateException(); + } + add(FieldAccess.forField(keyFieldDesc).write()); + add(MethodVariableAccess.loadThis()); + switch (valueField.getJavaType()) { + case INT: + add(IntegerConstant.forValue(0)); + break; + case LONG: + add(LongConstant.forValue(0L)); + break; + case FLOAT: + add(FloatConstant.forValue(0.0f)); + break; + case DOUBLE: + add(DoubleConstant.forValue(0.0)); + break; + case BOOLEAN: + add(IntegerConstant.forValue(false)); + break; + case STRING: + add(new TextConstant("")); + break; + case ENUM: + if (valueType.equals(int.class)) { + add(IntegerConstant.forValue(0)); + } else { + add(IntegerConstant.forValue(0), + Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(valueType, "forNumber", int.class))); + } + break; + case MESSAGE: + add(Codegen.invokeMethod( + ReflectionUtil.getDeclaredMethod(valueType, "getDefaultInstance"))); + break; + default: + throw new IllegalStateException(); + } + add(FieldAccess.forField(valueFieldDesc).write()); + } + add(Codegen.returnVoid()); + } + }); + + DynamicType.Unloaded unloaded = classBuilder.make(); + Class mapBuilderClass = unloaded.load( + parentBuilder.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + .getLoaded(); + return ReflectionUtil.newInstance(ReflectionUtil.getConstructor(mapBuilderClass)); + } + }.get(); + } + + private Class getMapEntryValueType(Class messageBuilderClass, Descriptors.FieldDescriptor mapFieldDescriptor, Descriptors.FieldDescriptor valueField) { + switch (valueField.getJavaType()) { + case INT: + return int.class; + case LONG: + return long.class; + case FLOAT: + return float.class; + case DOUBLE: + return double.class; + case BOOLEAN: + return boolean.class; + case STRING: + return String.class; + case BYTE_STRING: + return ByteString.class; + case ENUM: { + Descriptors.EnumDescriptor enumType = valueField.getEnumType(); + boolean hasValueSetter = !enumType.isClosed() && !valueField.legacyEnumFieldTreatedAsClosed(); + if (hasValueSetter) { + return int.class; + } + } + } + switch (valueField.getJavaType()) { + case ENUM: + case MESSAGE: { + String mapField = ReflectionUtil.getFieldNameForMethod(mapFieldDescriptor); + List putMethods = Arrays.stream(messageBuilderClass.getDeclaredMethods()).filter(x -> x.getName().equals("put" + mapField)) + .collect(Collectors.toList()); + if (putMethods.size() != 1) { + throw new IllegalStateException("Expected one put method for map field: " + mapField); + } + Method putMethod = putMethods.get(0); + Class[] parameterTypes = putMethod.getParameterTypes(); + if (parameterTypes.length != 2) { + throw new IllegalStateException("Expected two parameters for put method: " + putMethod); + } + return parameterTypes[1]; + } + } + throw new IllegalStateException("Unsupported value type: " + valueField.getJavaType()); + }; + + private Class getMapEntryKeyType(Class messageBuilderClass, Descriptors.FieldDescriptor keyField) { + switch (keyField.getJavaType()) { + case INT: + return int.class; + case LONG: + return long.class; + case FLOAT: + return float.class; + case DOUBLE: + return double.class; + case BOOLEAN: + return boolean.class; + case STRING: + return String.class; + default: + throw new IllegalStateException("Unsupported key type: " + keyField.getJavaType()); + } } private ParentValueContainer generatePvc( @@ -3327,7 +3642,119 @@ private ParentValueContainer generatePvc( if (!fieldDescriptor.isMapField() && !(parentBuilder instanceof MapEntry.Builder)) { return getRegularFieldPvc(parentBuilder, fieldDescriptor, valueType); } - return getDefaultPvc((Message.Builder) parentBuilder, fieldDescriptor, valueType); + return getMapFieldPvc(parentBuilder, fieldDescriptor, valueType); +// return getDefaultPvc((Message.Builder) parentBuilder, fieldDescriptor, valueType); + } + + private ParentValueContainer getMapFieldPvc(Object parentBuilder, Descriptors.FieldDescriptor fieldDescriptor, Class mapBuilderType) { + return new Supplier() { + private DynamicType.Builder classBuilder; + + @Override + public ParentValueContainer get() { + Class parentBuilderClass = parentBuilder.getClass(); + String fieldNameForMethod = ReflectionUtil.getFieldNameForMethod(fieldDescriptor); + TypeDescription parentBuilderTypeDef = TypeDescription.ForLoadedType.of(parentBuilderClass); + MethodList parentBuilderMethods = parentBuilderTypeDef.getDeclaredMethods(); + String setterPrefix = "put"; + String setterSuffix = + Arrays.stream(mapBuilderType.getDeclaredMethods()).anyMatch(x -> x.getName().equals("setValueValue")) + ? "Value" : ""; + + ElementMatcher setterArgumentMatcher = + ElementMatchers.takesArguments( + Arrays.stream(mapBuilderType.getDeclaredMethods()).filter(x -> x.getName().equals("getKey")) + .map(x -> x.getReturnType()).findFirst().get(), + Arrays.stream(mapBuilderType.getDeclaredMethods()).filter(x -> x.getName().equals("getValue")) + .map(x -> x.getReturnType()).findFirst().get() + ); + + MethodDescription.InDefinedShape parentBuilderSetter = + parentBuilderMethods.filter( + ElementMatchers.named(setterPrefix + fieldNameForMethod + setterSuffix).and( + setterArgumentMatcher + ) + ).getOnly(); + + classBuilder = new ByteBuddy() + .subclass(ParentValueContainer.class) + .modifiers(Visibility.PUBLIC) + .name(ParentValueContainer.class.getName() + "$Generated$" + + BYTE_BUDDY_CLASS_SEQUENCE.incrementAndGet()); + + TypeDescription.Generic parentBuilderType = + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(parentBuilderClass); + FieldDescription.Latent parentBuilderFieldDesc = new FieldDescription.Latent( + classBuilder.toTypeDescription(), + new FieldDescription.Token( + "parent", Modifier.PRIVATE | Modifier.FINAL, parentBuilderType)); + classBuilder = classBuilder.define(parentBuilderFieldDesc); + + classBuilder = classBuilder + .define(new MethodDescription.Latent( + classBuilder.toTypeDescription(), + new MethodDescription.Token( + MethodDescription.CONSTRUCTOR_INTERNAL_NAME, + Visibility.PUBLIC.getMask(), + TypeDescription.Generic.OfNonGenericType.ForLoadedType.of(void.class), + Collections.singletonList(parentBuilderType)))) + .intercept(MethodCall.invoke(ReflectionUtil.getConstructor( + ParentValueContainer.class)) + .andThen(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar parentVar = localVars.register(parentBuilderClass)) { + add( + MethodVariableAccess.loadThis(), + parentVar.load(), + FieldAccess.forField(parentBuilderFieldDesc) + .write()); + } + } + add(Codegen.returnVoid()); + } + })); + + String pvcMethodNameSuffix = ""; + + classBuilder = classBuilder + .method(ElementMatchers.named("add" + pvcMethodNameSuffix)) + .intercept(new Implementations() { + { + CodeGenUtils.LocalVars localVars = new CodeGenUtils.LocalVars(); + try (LocalVar thisLocalVar = + localVars.register(classBuilder.toTypeDescription())) { + try (LocalVar valueVar = localVars.register(Object.class)) { + add( + MethodVariableAccess.loadThis(), + FieldAccess.forField(parentBuilderFieldDesc) + .read(), + valueVar.load(), + TypeCasting.to(TypeDescription.ForLoadedType.of(mapBuilderType)), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(mapBuilderType, "getKey")), + valueVar.load(), + TypeCasting.to(TypeDescription.ForLoadedType.of(mapBuilderType)), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(mapBuilderType, "getValue")), + MethodInvocation.invoke(parentBuilderSetter), + valueVar.load(), + TypeCasting.to(TypeDescription.ForLoadedType.of(mapBuilderType)), + Codegen.invokeMethod(ReflectionUtil.getDeclaredMethod(mapBuilderType, "clear"))); + add(Codegen.returnVoid()); + } + } + } + }); + + DynamicType.Unloaded unloaded = classBuilder.make(); + Class pvcClass = unloaded.load( + mapBuilderType.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + .getLoaded(); + return ReflectionUtil.newInstance( + ReflectionUtil.getConstructor(pvcClass, parentBuilderClass), parentBuilder); + } + }.get(); } private ParentValueContainer getRegularFieldPvc(Object parentBuilder, @@ -3463,7 +3890,7 @@ public ParentValueContainer get() { DynamicType.Unloaded unloaded = classBuilder.make(); Class pvcClass = unloaded.load( - this.getClass().getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + parentBuilderClass.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) .getLoaded(); return ReflectionUtil.newInstance( ReflectionUtil.getConstructor(pvcClass, parentBuilderClass), parentBuilder);