From 7fdcb9657d6829bca830f08d1ba64fb2b38abda4 Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sun, 1 Jun 2025 18:00:48 +0100 Subject: [PATCH 01/11] Output dict encoded vectors as Avro enums (or decode them if that is not possible) --- .../arrow/adapter/avro/ArrowToAvroUtils.java | 212 +++++++++++++++--- 1 file changed, 179 insertions(+), 33 deletions(-) diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java index 87b594af9e..67f0c66112 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java @@ -17,7 +17,10 @@ package org.apache.arrow.adapter.avro; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer; import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer; import org.apache.arrow.adapter.avro.producers.AvroBytesProducer; @@ -96,11 +99,15 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.Text; import org.apache.avro.LogicalType; import org.apache.avro.LogicalTypes; import org.apache.avro.Schema; @@ -162,17 +169,29 @@ public class ArrowToAvroUtils { * may be nullable. Record types must contain at least one child field and cannot contain multiple * fields with the same name * + *

String fields that are dictionary-encoded will be represented as an Avro enum, so long as all + * the values meet the restrictions on Avro enums (non-null, valid identifiers). Other data types + * that are dictionary encoded, or string fields that do not meet the avro requirements, will be output + * as their decoded type. + * * @param arrowFields The arrow fields used to generate the Avro schema * @param typeName Name of the top level Avro record type * @param namespace Namespace of the top level Avro record type + * @param dictionaries A dictionary provider is required if any fields use dictionary encoding * @return An Avro record schema for the given list of fields, with the specified name and * namespace */ public static Schema createAvroSchema( - List arrowFields, String typeName, String namespace) { + List arrowFields, String typeName, String namespace, DictionaryProvider dictionaries) { SchemaBuilder.RecordBuilder assembler = SchemaBuilder.record(typeName).namespace(namespace); - return buildRecordSchema(assembler, arrowFields, namespace); + return buildRecordSchema(assembler, arrowFields, namespace, dictionaries); + } + + /** Overload provided for convenience, sets dictionaries = null. */ + public static Schema createAvroSchema( + List arrowFields, String typeName, String namespace) { + return createAvroSchema(arrowFields, typeName, namespace, null); } /** Overload provided for convenience, sets namespace = null. */ @@ -185,61 +204,83 @@ public static Schema createAvroSchema(List arrowFields) { return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME); } + /** + * Overload provided for convenience, sets name = GENERIC_RECORD_TYPE_NAME and namespace = null. + */ + public static Schema createAvroSchema(List arrowFields, DictionaryProvider dictionaries) { + return createAvroSchema(arrowFields, GENERIC_RECORD_TYPE_NAME, null, dictionaries); + } + private static T buildRecordSchema( - SchemaBuilder.RecordBuilder builder, List fields, String namespace) { + SchemaBuilder.RecordBuilder builder, + List fields, + String namespace, + DictionaryProvider dictionaries) { if (fields.isEmpty()) { throw new IllegalArgumentException("Record field must have at least one child field"); } SchemaBuilder.FieldAssembler assembler = builder.namespace(namespace).fields(); for (Field field : fields) { - assembler = buildFieldSchema(assembler, field, namespace); + assembler = buildFieldSchema(assembler, field, namespace, dictionaries); } return assembler.endRecord(); } private static SchemaBuilder.FieldAssembler buildFieldSchema( - SchemaBuilder.FieldAssembler assembler, Field field, String namespace) { + SchemaBuilder.FieldAssembler assembler, + Field field, + String namespace, + DictionaryProvider dictionaries) { return assembler .name(field.getName()) - .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace)) + .type(buildTypeSchema(SchemaBuilder.builder(), field, namespace, dictionaries)) .noDefault(); } private static T buildTypeSchema( - SchemaBuilder.TypeBuilder builder, Field field, String namespace) { + SchemaBuilder.TypeBuilder builder, + Field field, + String namespace, + DictionaryProvider dictionaries) { // Nullable unions need special handling, since union types cannot be directly nested if (field.getType().getTypeID() == ArrowType.ArrowTypeID.Union) { boolean unionNullable = field.getChildren().stream().anyMatch(Field::isNullable); if (unionNullable) { SchemaBuilder.UnionAccumulator union = builder.unionOf().nullType(); - return addTypesToUnion(union, field.getChildren(), 
namespace); + return addTypesToUnion(union, field.getChildren(), namespace, dictionaries); } else { Field headType = field.getChildren().get(0); List tailTypes = field.getChildren().subList(1, field.getChildren().size()); SchemaBuilder.UnionAccumulator union = - buildBaseTypeSchema(builder.unionOf(), headType, namespace); - return addTypesToUnion(union, tailTypes, namespace); + buildBaseTypeSchema(builder.unionOf(), headType, namespace, dictionaries); + return addTypesToUnion(union, tailTypes, namespace, dictionaries); } } else if (field.isNullable()) { - return buildBaseTypeSchema(builder.nullable(), field, namespace); + return buildBaseTypeSchema(builder.nullable(), field, namespace, dictionaries); } else { - return buildBaseTypeSchema(builder, field, namespace); + return buildBaseTypeSchema(builder, field, namespace, dictionaries); } } private static T buildArraySchema( - SchemaBuilder.ArrayBuilder builder, Field listField, String namespace) { + SchemaBuilder.ArrayBuilder builder, + Field listField, + String namespace, + DictionaryProvider dictionaries) { if (listField.getChildren().size() != 1) { throw new IllegalArgumentException("List field must have exactly one child field"); } Field itemField = listField.getChildren().get(0); - return buildTypeSchema(builder.items(), itemField, namespace); + return buildTypeSchema(builder.items(), itemField, namespace, dictionaries); } private static T buildMapSchema( - SchemaBuilder.MapBuilder builder, Field mapField, String namespace) { + SchemaBuilder.MapBuilder builder, + Field mapField, + String namespace, + DictionaryProvider dictionaries) { if (mapField.getChildren().size() != 1) { throw new IllegalArgumentException("Map field must have exactly one child field"); } @@ -253,11 +294,14 @@ private static T buildMapSchema( throw new IllegalArgumentException( "Map keys must be of type string and cannot be nullable for conversion to Avro"); } - return buildTypeSchema(builder.values(), valueField, namespace); + return buildTypeSchema(builder.values(), valueField, namespace, dictionaries); } private static T buildBaseTypeSchema( - SchemaBuilder.BaseTypeBuilder builder, Field field, String namespace) { + SchemaBuilder.BaseTypeBuilder builder, + Field field, + String namespace, + DictionaryProvider dictionaries) { ArrowType.ArrowTypeID typeID = field.getType().getTypeID(); @@ -269,6 +313,33 @@ private static T buildBaseTypeSchema( return builder.booleanType(); case Int: + if (field.getDictionary() != null) { + if (dictionaries == null) { + throw new IllegalArgumentException( + "Field references a dictionary but no dictionaries were provided: " + + field.getName()); + } + Dictionary dictionary = dictionaries.lookup(field.getDictionary().getId()); + if (dictionary == null) { + throw new IllegalArgumentException( + "Field references a dictionary that does not exist: " + + field.getName() + + ", dictionary ID = " + + field.getDictionary().getId()); + } + if (dictionaryIsValidEnum(dictionary)) { + String[] symbols = dictionarySymbols(dictionary); + return builder.enumeration(field.getName()).symbols(symbols); + } else { + Field decodedField = + new Field( + field.getName(), + dictionary.getVector().getField().getFieldType(), + dictionary.getVector().getField().getChildren()); + return buildBaseTypeSchema(builder, decodedField, namespace, dictionaries); + } + } + ArrowType.Int intType = (ArrowType.Int) field.getType(); if (intType.getBitWidth() > 32 || (intType.getBitWidth() == 32 && !intType.getIsSigned())) { return builder.longType(); @@ -328,7 +399,7 @@ 
private static T buildBaseTypeSchema( String childNamespace = namespace == null ? field.getName() : namespace + "." + field.getName(); return buildRecordSchema( - builder.record(field.getName()), field.getChildren(), childNamespace); + builder.record(field.getName()), field.getChildren(), childNamespace, dictionaries); case List: case FixedSizeList: @@ -339,13 +410,13 @@ private static T buildBaseTypeSchema( new Field("item", itemField.getFieldType(), itemField.getChildren()); Field safeListField = new Field(field.getName(), field.getFieldType(), List.of(safeItemField)); - return buildArraySchema(builder.array(), safeListField, namespace); + return buildArraySchema(builder.array(), safeListField, namespace, dictionaries); } else { - return buildArraySchema(builder.array(), field, namespace); + return buildArraySchema(builder.array(), field, namespace, dictionaries); } case Map: - return buildMapSchema(builder.map(), field, namespace); + return buildMapSchema(builder.map(), field, namespace, dictionaries); default: throw new IllegalArgumentException( @@ -354,9 +425,12 @@ private static T buildBaseTypeSchema( } private static T addTypesToUnion( - SchemaBuilder.UnionAccumulator accumulator, List unionFields, String namespace) { + SchemaBuilder.UnionAccumulator accumulator, + List unionFields, + String namespace, + DictionaryProvider dictionaries) { for (var field : unionFields) { - accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace); + accumulator = buildBaseTypeSchema(accumulator.and(), field, namespace, dictionaries); } return accumulator.endUnion(); } @@ -373,30 +447,78 @@ private static LogicalType timestampLogicalType(ArrowType.Timestamp timestampTyp } } + private static boolean dictionaryIsValidEnum(Dictionary dictionary) { + + if (dictionary.getVectorType().getTypeID() != ArrowType.ArrowTypeID.Utf8) { + return false; + } + + VarCharVector vector = (VarCharVector) dictionary.getVector(); + Set symbols = new HashSet<>(); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) return false; + Text text = vector.getObject(i); + if (text == null) return false; + String symbol = text.toString(); + if (!ENUM_REGEX.matcher(symbol).matches()) return false; + if (symbols.contains(symbol)) return false; + symbols.add(symbol); + } + + return true; + } + + private static String[] dictionarySymbols(Dictionary dictionary) { + + VarCharVector vector = (VarCharVector) dictionary.getVector(); + String[] symbols = new String[vector.getValueCount()]; + + for (int i = 0; i < vector.getValueCount(); i++) { + Text text = vector.getObject(i); + // This should never happen if dictionaryIsValidEnum() succeeded + if (text == null) throw new IllegalArgumentException("Illegal null value in enum"); + symbols[i] = text.toString(); + } + + return symbols; + } + + private static final Pattern ENUM_REGEX = Pattern.compile("^[A-Za-z_][A-Za-z0-9_]*$"); + /** * Create a composite Avro producer for a set of field vectors (typically the root set of a VSR). 
* * @param vectors The vectors that will be used to produce Avro data * @return The resulting composite Avro producer */ - public static CompositeAvroProducer createCompositeProducer(List vectors) { + public static CompositeAvroProducer createCompositeProducer( + List vectors, DictionaryProvider dictionaries) { List> producers = new ArrayList<>(vectors.size()); for (FieldVector vector : vectors) { - BaseAvroProducer producer = createProducer(vector); + BaseAvroProducer producer = createProducer(vector, dictionaries); producers.add(producer); } return new CompositeAvroProducer(producers); } - private static BaseAvroProducer createProducer(FieldVector vector) { + /** Overload provided for convenience, sets dictionaries = null. */ + public static CompositeAvroProducer createCompositeProducer(List vectors) { + + return createCompositeProducer(vectors, null); + } + + private static BaseAvroProducer createProducer( + FieldVector vector, DictionaryProvider dictionaries) { boolean nullable = vector.getField().isNullable(); - return createProducer(vector, nullable); + return createProducer(vector, nullable, dictionaries); } - private static BaseAvroProducer createProducer(FieldVector vector, boolean nullable) { + private static BaseAvroProducer createProducer( + FieldVector vector, boolean nullable, DictionaryProvider dictionaries) { Preconditions.checkNotNull(vector, "Arrow vector object can't be null"); @@ -405,10 +527,32 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu // Avro understands nullable types as a union of type | null // Most nullable fields in a VSR will not be unions, so provide a special wrapper if (nullable && minorType != Types.MinorType.UNION) { - final BaseAvroProducer innerProducer = createProducer(vector, false); + final BaseAvroProducer innerProducer = createProducer(vector, false, dictionaries); return new AvroNullableProducer<>(innerProducer); } + if (vector.getField().getDictionary() != null) { + if (dictionaries == null) { + throw new IllegalArgumentException( + "Field references a dictionary but no dictionaries were provided: " + + vector.getField().getName()); + } + Dictionary dictionary = dictionaries.lookup(vector.getField().getDictionary().getId()); + if (dictionary == null) { + throw new IllegalArgumentException( + "Field references a dictionary that does not exist: " + + vector.getField().getName() + + ", dictionary ID = " + + vector.getField().getDictionary().getId()); + } + // If a field is dictionary-encoded but cannot be represented as an Avro enum, + // then decode it before writing + if (!dictionaryIsValidEnum(dictionary)) { + FieldVector decodedVector = (FieldVector) DictionaryEncoder.decode(vector, dictionary); + return createProducer(decodedVector, nullable, dictionaries); + } + } + switch (minorType) { case NULL: return new AvroNullProducer((NullVector) vector); @@ -486,21 +630,23 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu Producer[] childProducers = new Producer[childVectors.size()]; for (int i = 0; i < childVectors.size(); i++) { FieldVector childVector = childVectors.get(i); - childProducers[i] = createProducer(childVector, childVector.getField().isNullable()); + childProducers[i] = + createProducer(childVector, childVector.getField().isNullable(), dictionaries); } return new AvroStructProducer(structVector, childProducers); case LIST: ListVector listVector = (ListVector) vector; FieldVector itemVector = listVector.getDataVector(); - Producer itemProducer = 
createProducer(itemVector, itemVector.getField().isNullable());
+ Producer itemProducer =
+ createProducer(itemVector, itemVector.getField().isNullable(), dictionaries);
return new AvroListProducer(listVector, itemProducer);
case FIXED_SIZE_LIST:
FixedSizeListVector fixedListVector = (FixedSizeListVector) vector;
FieldVector fixedItemVector = fixedListVector.getDataVector();
Producer fixedItemProducer =
- createProducer(fixedItemVector, fixedItemVector.getField().isNullable());
+ createProducer(fixedItemVector, fixedItemVector.getField().isNullable(), dictionaries);
return new AvroFixedSizeListProducer(fixedListVector, fixedItemProducer);
case MAP:
@@ -514,7 +660,7 @@ private static BaseAvroProducer createProducer(FieldVector vector, boolean nu
FieldVector valueVector = entryVector.getChildrenFromFields().get(1);
Producer keyProducer = new AvroStringProducer(keyVector);
Producer valueProducer =
- createProducer(valueVector, valueVector.getField().isNullable());
+ createProducer(valueVector, valueVector.getField().isNullable(), dictionaries);
Producer entryProducer =
new AvroStructProducer(entryVector, new Producer[] {keyProducer, valueProducer});
return new AvroMapProducer(mapVector, entryProducer);
From b32070adfa898fbdc3f9e881a6481b0393d89834 Mon Sep 17 00:00:00 2001
From: Martin Traverse
Date: Sun, 1 Jun 2025 18:07:03 +0100
Subject: [PATCH 02/11] Round trip schema test for enums
---
.../adapter/avro/RoundTripSchemaTest.java | 63 ++++++++++++++++++-
1 file changed, 60 insertions(+), 3 deletions(-)
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
index 864e2c8b59..37c0b4d9fe 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripSchemaTest.java
@@ -21,27 +21,50 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.avro.Schema;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class RoundTripSchemaTest {
private void doRoundTripTest(List fields) {
+ doRoundTripTest(fields, null);
+ }
- AvroToArrowConfig config = new AvroToArrowConfig(null, 1, null, Collections.emptySet(), false);
+ private void doRoundTripTest(List fields, DictionaryProvider dictionaries) {
- Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord");
+ DictionaryProvider.MapDictionaryProvider decodeDictionaries =
+ new DictionaryProvider.MapDictionaryProvider();
+ AvroToArrowConfig decodeConfig =
+ new AvroToArrowConfig(null, 1, decodeDictionaries, Collections.emptySet(), false);
+
+ Schema avroSchema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
org.apache.arrow.vector.types.pojo.Schema arrowSchema =
-
AvroToArrowUtils.createArrowSchema(avroSchema, config);
+ AvroToArrowUtils.createArrowSchema(avroSchema, decodeConfig);
// Compare string representations - equality not defined for logical types
assertEquals(fields, arrowSchema.getFields());
+
+ for (int i = 0; i < fields.size(); i++) {
+ Field field = fields.get(i);
+ Field rtField = arrowSchema.getFields().get(i);
+ if (field.getDictionary() != null) {
+ // Dictionary content is not decoded until the data is consumed
+ Assertions.assertNotNull(rtField.getDictionary());
+ }
+ }
}
// Schema round trip for primitive types, nullable and non-nullable
@@ -440,4 +463,38 @@ public void testRoundTripStructType() {
doRoundTripTest(fields);
}
+
+ @Test
+ public void testRoundTripEnumType() {
+
+ BufferAllocator allocator = new RootAllocator();
+
+ FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null);
+ VarCharVector dictionaryVector =
+ new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+ dictionaryVector.allocateNew(3);
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding
+ Dictionary dictionary =
+ new Dictionary(
+ dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+ DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+ List fields =
+ Arrays.asList(
+ new Field(
+ "enumField",
+ new FieldType(
+ true,
+ new ArrowType.Int(8, true),
+ new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))),
+ null));
+
+ doRoundTripTest(fields, dictionaries);
+ }
}
From a686768f1181923236731dc41ccb1d2ba4d64cee Mon Sep 17 00:00:00 2001
From: Martin Traverse
Date: Sun, 1 Jun 2025 18:08:32 +0100
Subject: [PATCH 03/11] Round trip data test for enums
---
.../arrow/adapter/avro/RoundTripDataTest.java | 110 ++++++++++++++++--
1 file changed, 102 insertions(+), 8 deletions(-)
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
index 85e6a960b0..ceaf59aa72 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/RoundTripDataTest.java
@@ -52,6 +52,7 @@
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampNanoTZVector;
import org.apache.arrow.vector.TimeStampNanoVector;
+import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -60,10 +61,14 @@
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.writer.BaseWriter;
import org.apache.arrow.vector.complex.writer.FieldWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import
org.apache.arrow.vector.types.pojo.FieldType; import org.apache.avro.Schema; @@ -78,16 +83,21 @@ public class RoundTripDataTest { @TempDir public static File TMP; - private static AvroToArrowConfig basicConfig(BufferAllocator allocator) { - return new AvroToArrowConfig(allocator, 1000, null, Collections.emptySet(), false); + private static AvroToArrowConfig basicConfig( + BufferAllocator allocator, DictionaryProvider.MapDictionaryProvider dictionaries) { + return new AvroToArrowConfig(allocator, 1000, dictionaries, Collections.emptySet(), false); } private static VectorSchemaRoot readDataFile( - Schema schema, File dataFile, BufferAllocator allocator) throws Exception { + Schema schema, + File dataFile, + BufferAllocator allocator, + DictionaryProvider.MapDictionaryProvider dictionaries) + throws Exception { try (FileInputStream fis = new FileInputStream(dataFile)) { BinaryDecoder decoder = new DecoderFactory().directBinaryDecoder(fis, null); - return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator)); + return AvroToArrow.avroToArrow(schema, decoder, basicConfig(allocator, dictionaries)); } } @@ -95,11 +105,22 @@ private static void roundTripTest( VectorSchemaRoot root, BufferAllocator allocator, File dataFile, int rowCount) throws Exception { + roundTripTest(root, allocator, dataFile, rowCount, null); + } + + private static void roundTripTest( + VectorSchemaRoot root, + BufferAllocator allocator, + File dataFile, + int rowCount, + DictionaryProvider dictionaries) + throws Exception { + // Write an AVRO block using the producer classes try (FileOutputStream fos = new FileOutputStream(dataFile)) { BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); CompositeAvroProducer producer = - ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors()); + ArrowToAvroUtils.createCompositeProducer(root.getFieldVectors(), dictionaries); for (int row = 0; row < rowCount; row++) { producer.produce(encoder); } @@ -107,10 +128,14 @@ private static void roundTripTest( } // Generate AVRO schema - Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields()); + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); + + DictionaryProvider.MapDictionaryProvider roundTripDictionaries = + new DictionaryProvider.MapDictionaryProvider(); // Read back in and compare - try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) { + try (VectorSchemaRoot roundTrip = + readDataFile(schema, dataFile, allocator, roundTripDictionaries)) { assertEquals(root.getSchema(), roundTrip.getSchema()); assertEquals(rowCount, roundTrip.getRowCount()); @@ -119,6 +144,21 @@ private static void roundTripTest( for (int row = 0; row < rowCount; row++) { assertEquals(root.getVector(0).getObject(row), roundTrip.getVector(0).getObject(row)); } + + if (dictionaries != null) { + for (long id : dictionaries.getDictionaryIds()) { + Dictionary originalDictionary = dictionaries.lookup(id); + Dictionary roundTripDictionary = roundTripDictionaries.lookup(id); + assertEquals( + originalDictionary.getVector().getValueCount(), + roundTripDictionary.getVector().getValueCount()); + for (int j = 0; j < originalDictionary.getVector().getValueCount(); j++) { + assertEquals( + originalDictionary.getVector().getObject(j), + roundTripDictionary.getVector().getObject(j)); + } + } + } } } @@ -141,7 +181,7 @@ private static void roundTripByteArrayTest( Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields()); // Read back 
in and compare - try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator)) { + try (VectorSchemaRoot roundTrip = readDataFile(schema, dataFile, allocator, null)) { assertEquals(root.getSchema(), roundTrip.getSchema()); assertEquals(rowCount, roundTrip.getRowCount()); @@ -1603,4 +1643,58 @@ public void testRoundTripNullableStructs() throws Exception { roundTripTest(root, allocator, dataFile, rowCount); } } + + @Test + public void testRoundTripEnum() throws Exception { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // For simplicity, ensure the index type matches what will be decoded during Avro enum decoding + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + // Field definition + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector stringVector = + new VarCharVector(new Field("enumField", stringField, null), allocator); + stringVector.allocateNew(10); + stringVector.setSafe(0, "apple".getBytes()); + stringVector.setSafe(1, "banana".getBytes()); + stringVector.setSafe(2, "cherry".getBytes()); + stringVector.setSafe(3, "cherry".getBytes()); + stringVector.setSafe(4, "apple".getBytes()); + stringVector.setSafe(5, "banana".getBytes()); + stringVector.setSafe(6, "apple".getBytes()); + stringVector.setSafe(7, "cherry".getBytes()); + stringVector.setSafe(8, "banana".getBytes()); + stringVector.setSafe(9, "apple".getBytes()); + stringVector.setValueCount(10); + + TinyIntVector encodedVector = + (TinyIntVector) DictionaryEncoder.encode(stringVector, dictionary); + + // Set up VSR + List vectors = Arrays.asList(encodedVector); + int rowCount = 10; + + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { + + File dataFile = new File(TMP, "testRoundTripEnums.avro"); + + roundTripTest(root, allocator, dataFile, rowCount, dictionaries); + } + } } From 64d5b7a52757b45bcbf27383d29f4037d92bd3dd Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sun, 1 Jun 2025 18:09:48 +0100 Subject: [PATCH 04/11] Avro write test for enums --- .../adapter/avro/ArrowToAvroDataTest.java | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java index 2d70b45021..7b6ede7006 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java @@ -76,10 +76,14 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import 
org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.JsonStringArrayList; @@ -2817,4 +2821,184 @@ record = datumReader.read(record, decoder); } } } + + @Test + public void testWriteDictEnumEncoded() throws Exception { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + // Field definition + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector stringVector = + new VarCharVector(new Field("enumField", stringField, null), allocator); + stringVector.allocateNew(10); + stringVector.setSafe(0, "apple".getBytes()); + stringVector.setSafe(1, "banana".getBytes()); + stringVector.setSafe(2, "cherry".getBytes()); + stringVector.setSafe(3, "cherry".getBytes()); + stringVector.setSafe(4, "apple".getBytes()); + stringVector.setSafe(5, "banana".getBytes()); + stringVector.setSafe(6, "apple".getBytes()); + stringVector.setSafe(7, "cherry".getBytes()); + stringVector.setSafe(8, "banana".getBytes()); + stringVector.setSafe(9, "apple".getBytes()); + stringVector.setValueCount(10); + + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); + + // Set up VSR + List vectors = Arrays.asList(encodedVector); + int rowCount = 10; + + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { + + File dataFile = new File(TMP, "testWriteEnumEncoded.avro"); + + // Write an AVRO block using the producer classes + try (FileOutputStream fos = new FileOutputStream(dataFile)) { + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); + CompositeAvroProducer producer = + ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); + for (int row = 0; row < rowCount; row++) { + producer.produce(encoder); + } + encoder.flush(); + } + + // Set up reading the AVRO block as a GenericRecord + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); + GenericDatumReader datumReader = new GenericDatumReader<>(schema); + + try (InputStream inputStream = new FileInputStream(dataFile)) { + + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); + GenericRecord record = null; + + // Read and check values + for (int row = 0; row < rowCount; row++) { + record = datumReader.read(record, decoder); + // Values read from Avro should be the decoded enum values + assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); + } + } + } + } + + @Test + public void testWriteEnumDecoded() throws Exception { + + // Dict encoded fields that are not valid Avro enums should be decoded on write + + BufferAllocator allocator = new 
RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "passion fruit".getBytes()); // spaced not allowed + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + FieldType dictionaryField2 = new FieldType(false, new ArrowType.Int(64, true), null); + BigIntVector dictionaryVector2 = + new BigIntVector(new Field("dictionary2", dictionaryField2, null), allocator); + + dictionaryVector2.allocateNew(3); + dictionaryVector2.set(0, 0L); + dictionaryVector2.set(1, 1L); + dictionaryVector2.set(2, 2L); + dictionaryVector2.setValueCount(3); + + Dictionary dictionary2 = + new Dictionary(dictionaryVector2, new DictionaryEncoding(2L, false, null)); + + DictionaryProvider dictionaries = + new DictionaryProvider.MapDictionaryProvider(dictionary, dictionary2); + + // Field definition + FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector stringVector = + new VarCharVector(new Field("enumField", stringField, null), allocator); + stringVector.allocateNew(10); + stringVector.setSafe(0, "passion fruit".getBytes()); + stringVector.setSafe(1, "banana".getBytes()); + stringVector.setSafe(2, "cherry".getBytes()); + stringVector.setSafe(3, "cherry".getBytes()); + stringVector.setSafe(4, "passion fruit".getBytes()); + stringVector.setSafe(5, "banana".getBytes()); + stringVector.setSafe(6, "passion fruit".getBytes()); + stringVector.setSafe(7, "cherry".getBytes()); + stringVector.setSafe(8, "banana".getBytes()); + stringVector.setSafe(9, "passion fruit".getBytes()); + stringVector.setValueCount(10); + + FieldType longField = new FieldType(false, new ArrowType.Int(64, true), null); + BigIntVector longVector = new BigIntVector(new Field("enumField2", longField, null), allocator); + longVector.allocateNew(10); + for (int i = 0; i < 10; i++) { + longVector.setSafe(i, (long) i % 3); + } + longVector.setValueCount(10); + + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); + IntVector encodedVector2 = (IntVector) DictionaryEncoder.encode(longVector, dictionary2); + + // Set up VSR + List vectors = Arrays.asList(encodedVector, encodedVector2); + int rowCount = 10; + + try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { + + File dataFile = new File(TMP, "testWriteEnumDecodedavro"); + + // Write an AVRO block using the producer classes + try (FileOutputStream fos = new FileOutputStream(dataFile)) { + BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); + CompositeAvroProducer producer = + ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); + for (int row = 0; row < rowCount; row++) { + producer.produce(encoder); + } + encoder.flush(); + } + + // Set up reading the AVRO block as a GenericRecord + Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); + GenericDatumReader datumReader = new GenericDatumReader<>(schema); + + try (InputStream inputStream = new FileInputStream(dataFile)) { + + BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); + GenericRecord record = null; + + // Read and check values + for (int row = 
0; row < rowCount; row++) { + record = datumReader.read(record, decoder); + assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); + assertEquals(longVector.getObject(row), record.get("enumField2")); + } + } + } + } } From 58b4f73be683dfd63c970d0e5122aaca74f0cbaf Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sun, 1 Jun 2025 18:17:52 +0100 Subject: [PATCH 05/11] Avro write test for enums schema --- .../adapter/avro/ArrowToAvroSchemaTest.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java index d3e12e763a..496534f306 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java @@ -20,11 +20,17 @@ import java.util.Arrays; import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.avro.LogicalTypes; @@ -1389,4 +1395,46 @@ public void testConvertUnionTypes() { Schema.Type.STRING, schema.getField("nullableDenseUnionField").schema().getTypes().get(3).getType()); } + + @Test + public void testWriteDictEnumEncoded() { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector dictionaryVector = + new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); + + dictionaryVector.allocateNew(3); + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = + new Dictionary( + dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true))); + DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary); + + List fields = + Arrays.asList( + new Field( + "enumField", + new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null), + null)); + + Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries); + + assertEquals(Schema.Type.RECORD, schema.getType()); + assertEquals(1, schema.getFields().size()); + + Schema.Field enumField = schema.getField("enumField"); + + assertEquals(Schema.Type.ENUM, enumField.schema().getType()); + assertEquals(3, enumField.schema().getEnumSymbols().size()); + assertEquals("apple", enumField.schema().getEnumSymbols().get(0)); + assertEquals("banana", enumField.schema().getEnumSymbols().get(1)); + assertEquals("cherry", enumField.schema().getEnumSymbols().get(2)); + } } From a917b44fd5c222145cddf3e38a7c7b2718af77c3 Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sun, 1 Jun 2025 18:18:17 +0100 
Subject: [PATCH 06/11] Add a doc comment about enum dict encoding --- .../org/apache/arrow/adapter/avro/ArrowToAvroUtils.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java index 67f0c66112..9e698961ac 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java @@ -169,10 +169,10 @@ public class ArrowToAvroUtils { * may be nullable. Record types must contain at least one child field and cannot contain multiple * fields with the same name * - *

String fields that are dictionary-encoded will be represented as an Avro enum, so long as all - * the values meet the restrictions on Avro enums (non-null, valid identifiers). Other data types - * that are dictionary encoded, or string fields that do not meet the avro requirements, will be output - * as their decoded type. + *

String fields that are dictionary-encoded will be represented as an Avro enum, so long as + * all the values meet the restrictions on Avro enums (non-null, valid identifiers). Other data + * types that are dictionary encoded, or string fields that do not meet the avro requirements, + * will be output as their decoded type. * * @param arrowFields The arrow fields used to generate the Avro schema * @param typeName Name of the top level Avro record type From f4134d10cc59d866bbf3f31acbf7de73631d8fe5 Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sun, 1 Jun 2025 18:21:12 +0100 Subject: [PATCH 07/11] Fix code style checks --- .../arrow/adapter/avro/ArrowToAvroUtils.java | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java index 9e698961ac..d545aa2eaa 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java @@ -457,12 +457,20 @@ private static boolean dictionaryIsValidEnum(Dictionary dictionary) { Set symbols = new HashSet<>(); for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) return false; + if (vector.isNull(i)) { + return false; + } Text text = vector.getObject(i); - if (text == null) return false; + if (text == null) { + return false; + } String symbol = text.toString(); - if (!ENUM_REGEX.matcher(symbol).matches()) return false; - if (symbols.contains(symbol)) return false; + if (!ENUM_REGEX.matcher(symbol).matches()) { + return false; + } + if (symbols.contains(symbol)) { + return false; + } symbols.add(symbol); } @@ -477,7 +485,9 @@ private static String[] dictionarySymbols(Dictionary dictionary) { for (int i = 0; i < vector.getValueCount(); i++) { Text text = vector.getObject(i); // This should never happen if dictionaryIsValidEnum() succeeded - if (text == null) throw new IllegalArgumentException("Illegal null value in enum"); + if (text == null) { + throw new IllegalArgumentException("Illegal null value in enum"); + } symbols[i] = text.toString(); } From 735d3c15dc68d11bdfcb7208bdf8c1de1719d5a3 Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sat, 7 Jun 2025 21:43:17 +0100 Subject: [PATCH 08/11] Use dictionary decoding producer to handle dictionary fields that are not valid avro enums --- .../arrow/adapter/avro/ArrowToAvroUtils.java | 12 +++-- .../avro/producers/AvroEnumProducer.java | 10 ++--- .../producers/DictionaryDecodingProducer.java | 45 +++++++++++++++++++ 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java index d545aa2eaa..e09b99f670 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/ArrowToAvroUtils.java @@ -24,6 +24,7 @@ import org.apache.arrow.adapter.avro.producers.AvroBigIntProducer; import org.apache.arrow.adapter.avro.producers.AvroBooleanProducer; import org.apache.arrow.adapter.avro.producers.AvroBytesProducer; +import org.apache.arrow.adapter.avro.producers.AvroEnumProducer; import 
org.apache.arrow.adapter.avro.producers.AvroFixedSizeBinaryProducer; import org.apache.arrow.adapter.avro.producers.AvroFixedSizeListProducer; import org.apache.arrow.adapter.avro.producers.AvroFloat2Producer; @@ -44,6 +45,7 @@ import org.apache.arrow.adapter.avro.producers.AvroUint8Producer; import org.apache.arrow.adapter.avro.producers.BaseAvroProducer; import org.apache.arrow.adapter.avro.producers.CompositeAvroProducer; +import org.apache.arrow.adapter.avro.producers.DictionaryDecodingProducer; import org.apache.arrow.adapter.avro.producers.Producer; import org.apache.arrow.adapter.avro.producers.logical.AvroDateDayProducer; import org.apache.arrow.adapter.avro.producers.logical.AvroDateMilliProducer; @@ -62,6 +64,7 @@ import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecProducer; import org.apache.arrow.adapter.avro.producers.logical.AvroTimestampSecTzProducer; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; @@ -100,7 +103,6 @@ import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.dictionary.Dictionary; -import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; @@ -557,9 +559,11 @@ private static BaseAvroProducer createProducer( } // If a field is dictionary-encoded but cannot be represented as an Avro enum, // then decode it before writing - if (!dictionaryIsValidEnum(dictionary)) { - FieldVector decodedVector = (FieldVector) DictionaryEncoder.decode(vector, dictionary); - return createProducer(decodedVector, nullable, dictionaries); + if (dictionaryIsValidEnum(dictionary)) { + return new AvroEnumProducer((BaseIntVector) vector); + } else { + BaseAvroProducer dictProducer = createProducer(dictionary.getVector(), false, null); + return new DictionaryDecodingProducer<>((BaseIntVector) vector, dictProducer); } } diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java index 068566493e..2f63654c63 100644 --- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java @@ -17,22 +17,22 @@ package org.apache.arrow.adapter.avro.producers; import java.io.IOException; -import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BaseIntVector; import org.apache.avro.io.Encoder; /** - * Producer that produces enum values from a dictionary-encoded {@link IntVector}, writes data to an + * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data to an * Avro encoder. */ -public class AvroEnumProducer extends BaseAvroProducer { +public class AvroEnumProducer extends BaseAvroProducer { /** Instantiate an AvroEnumProducer. 
*/ - public AvroEnumProducer(IntVector vector) { + public AvroEnumProducer(BaseIntVector vector) { super(vector); } @Override public void produce(Encoder encoder) throws IOException { - encoder.writeEnum(vector.get(currentIndex++)); + encoder.writeEnum((int) vector.getValueAsLong(currentIndex++)); } } diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java new file mode 100644 index 0000000000..07415ebd1b --- /dev/null +++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adapter.avro.producers; + +import java.io.IOException; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.avro.io.Encoder; + +/** + * Producer that produces decoded values from a dictionary-encoded {@link BaseIntVector}, writes data to an + * Avro encoder. + */ +public class DictionaryDecodingProducer + extends BaseAvroProducer { + + private final Producer dictProducer; + + /** Instantiate a DictionaryDecodingProducer. 
*/ + public DictionaryDecodingProducer(BaseIntVector indexVector, Producer dictProducer) { + super(indexVector); + this.dictProducer = dictProducer; + } + + @Override + public void produce(Encoder encoder) throws IOException { + int dicIndex = (int) vector.getValueAsLong(currentIndex++); + dictProducer.setPosition(dicIndex); + dictProducer.produce(encoder); + } +} From 6e9168b930765ce67b642f19001dd451629ca5fb Mon Sep 17 00:00:00 2001 From: Martin Traverse Date: Sat, 7 Jun 2025 21:46:29 +0100 Subject: [PATCH 09/11] Test fixes --- .../adapter/avro/ArrowToAvroDataTest.java | 106 +----------------- .../adapter/avro/ArrowToAvroSchemaTest.java | 82 ++++++++++++++ 2 files changed, 83 insertions(+), 105 deletions(-) diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java index 7b6ede7006..704edb85e1 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java @@ -2872,8 +2872,7 @@ public void testWriteDictEnumEncoded() throws Exception { // Write an AVRO block using the producer classes try (FileOutputStream fos = new FileOutputStream(dataFile)) { BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); - CompositeAvroProducer producer = - ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); + CompositeAvroProducer producer = ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); for (int row = 0; row < rowCount; row++) { producer.produce(encoder); } @@ -2898,107 +2897,4 @@ record = datumReader.read(record, decoder); } } } - - @Test - public void testWriteEnumDecoded() throws Exception { - - // Dict encoded fields that are not valid Avro enums should be decoded on write - - BufferAllocator allocator = new RootAllocator(); - - // Create a dictionary - FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); - VarCharVector dictionaryVector = - new VarCharVector(new Field("dictionary", dictionaryField, null), allocator); - - dictionaryVector.allocateNew(3); - dictionaryVector.set(0, "passion fruit".getBytes()); // spaced not allowed - dictionaryVector.set(1, "banana".getBytes()); - dictionaryVector.set(2, "cherry".getBytes()); - dictionaryVector.setValueCount(3); - - Dictionary dictionary = - new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); - - FieldType dictionaryField2 = new FieldType(false, new ArrowType.Int(64, true), null); - BigIntVector dictionaryVector2 = - new BigIntVector(new Field("dictionary2", dictionaryField2, null), allocator); - - dictionaryVector2.allocateNew(3); - dictionaryVector2.set(0, 0L); - dictionaryVector2.set(1, 1L); - dictionaryVector2.set(2, 2L); - dictionaryVector2.setValueCount(3); - - Dictionary dictionary2 = - new Dictionary(dictionaryVector2, new DictionaryEncoding(2L, false, null)); - - DictionaryProvider dictionaries = - new DictionaryProvider.MapDictionaryProvider(dictionary, dictionary2); - - // Field definition - FieldType stringField = new FieldType(false, new ArrowType.Utf8(), null); - VarCharVector stringVector = - new VarCharVector(new Field("enumField", stringField, null), allocator); - stringVector.allocateNew(10); - stringVector.setSafe(0, "passion fruit".getBytes()); - stringVector.setSafe(1, "banana".getBytes()); - stringVector.setSafe(2, "cherry".getBytes()); - stringVector.setSafe(3, "cherry".getBytes()); - 
stringVector.setSafe(4, "passion fruit".getBytes()); - stringVector.setSafe(5, "banana".getBytes()); - stringVector.setSafe(6, "passion fruit".getBytes()); - stringVector.setSafe(7, "cherry".getBytes()); - stringVector.setSafe(8, "banana".getBytes()); - stringVector.setSafe(9, "passion fruit".getBytes()); - stringVector.setValueCount(10); - - FieldType longField = new FieldType(false, new ArrowType.Int(64, true), null); - BigIntVector longVector = new BigIntVector(new Field("enumField2", longField, null), allocator); - longVector.allocateNew(10); - for (int i = 0; i < 10; i++) { - longVector.setSafe(i, (long) i % 3); - } - longVector.setValueCount(10); - - IntVector encodedVector = (IntVector) DictionaryEncoder.encode(stringVector, dictionary); - IntVector encodedVector2 = (IntVector) DictionaryEncoder.encode(longVector, dictionary2); - - // Set up VSR - List vectors = Arrays.asList(encodedVector, encodedVector2); - int rowCount = 10; - - try (VectorSchemaRoot root = new VectorSchemaRoot(vectors)) { - - File dataFile = new File(TMP, "testWriteEnumDecodedavro"); - - // Write an AVRO block using the producer classes - try (FileOutputStream fos = new FileOutputStream(dataFile)) { - BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null); - CompositeAvroProducer producer = - ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries); - for (int row = 0; row < rowCount; row++) { - producer.produce(encoder); - } - encoder.flush(); - } - - // Set up reading the AVRO block as a GenericRecord - Schema schema = ArrowToAvroUtils.createAvroSchema(root.getSchema().getFields(), dictionaries); - GenericDatumReader datumReader = new GenericDatumReader<>(schema); - - try (InputStream inputStream = new FileInputStream(dataFile)) { - - BinaryDecoder decoder = DecoderFactory.get().binaryDecoder(inputStream, null); - GenericRecord record = null; - - // Read and check values - for (int row = 0; row < rowCount; row++) { - record = datumReader.read(record, decoder); - assertEquals(stringVector.getObject(row).toString(), record.get("enumField").toString()); - assertEquals(longVector.getObject(row), record.get("enumField2")); - } - } - } - } } diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java index 496534f306..1309719fd5 100644 --- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java +++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java @@ -17,11 +17,13 @@ package org.apache.arrow.adapter.avro; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.Arrays; import java.util.List; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -1437,4 +1439,84 @@ public void testWriteDictEnumEncoded() { assertEquals("banana", enumField.schema().getEnumSymbols().get(1)); assertEquals("cherry", enumField.schema().getEnumSymbols().get(2)); } + + @Test + public void testWriteDictEnumInvalid() { + + BufferAllocator allocator = new RootAllocator(); + + // Create a dictionary + FieldType dictionaryField = new FieldType(false, new ArrowType.Utf8(), null); + VarCharVector 
dictionaryVector =
+        new VarCharVector(new Field("dictionary", dictionaryField, null), allocator);
+
+    dictionaryVector.allocateNew(3);
+    dictionaryVector.set(0, "passion fruit".getBytes());
+    dictionaryVector.set(1, "banana".getBytes());
+    dictionaryVector.set(2, "cherry".getBytes());
+    dictionaryVector.setValueCount(3);
+
+    Dictionary dictionary =
+        new Dictionary(
+            dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+    DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+    List fields =
+        Arrays.asList(
+            new Field(
+                "enumField",
+                new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
+                null));
+
+    // Dictionary field contains values that are not valid enums
+    // Should be decoded and output as a string field
+
+    Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
+
+    assertEquals(Schema.Type.RECORD, schema.getType());
+    assertEquals(1, schema.getFields().size());
+
+    Schema.Field enumField = schema.getField("enumField");
+    assertEquals(Schema.Type.STRING, enumField.schema().getType());
+  }
+
+  @Test
+  public void testWriteDictEnumInvalid2() {
+
+    BufferAllocator allocator = new RootAllocator();
+
+    // Create a dictionary
+    FieldType dictionaryField = new FieldType(false, new ArrowType.Int(64, true), null);
+    BigIntVector dictionaryVector =
+        new BigIntVector(new Field("dictionary", dictionaryField, null), allocator);
+
+    dictionaryVector.allocateNew(3);
+    dictionaryVector.set(0, 123L);
+    dictionaryVector.set(1, 456L);
+    dictionaryVector.set(2, 789L);
+    dictionaryVector.setValueCount(3);
+
+    Dictionary dictionary =
+        new Dictionary(
+            dictionaryVector, new DictionaryEncoding(0L, false, new ArrowType.Int(8, true)));
+    DictionaryProvider dictionaries = new DictionaryProvider.MapDictionaryProvider(dictionary);
+
+    List fields =
+        Arrays.asList(
+            new Field(
+                "enumField",
+                new FieldType(false, new ArrowType.Int(8, true), dictionary.getEncoding(), null),
+                null));
+
+    // Dictionary field encodes LONG values rather than STRING
+    // Should be decoded and output as a LONG field
+
+    Schema schema = ArrowToAvroUtils.createAvroSchema(fields, "TestRecord", null, dictionaries);
+
+    assertEquals(Schema.Type.RECORD, schema.getType());
+    assertEquals(1, schema.getFields().size());
+
+    Schema.Field enumField = schema.getField("enumField");
+    assertEquals(Schema.Type.LONG, enumField.schema().getType());
+  }
 }

From 8a69d1291d005e5358be1beb1f92cca5920716b3 Mon Sep 17 00:00:00 2001
From: Martin Traverse
Date: Sat, 7 Jun 2025 21:48:43 +0100
Subject: [PATCH 10/11] Apply spotless

---
 .../arrow/adapter/avro/producers/AvroEnumProducer.java     | 4 ++--
 .../adapter/avro/producers/DictionaryDecodingProducer.java | 6 ++++--
 .../org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java | 3 ++-
 .../apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java   | 1 -
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
index 2f63654c63..eebfb7d241 100644
--- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/AvroEnumProducer.java
@@ -21,8 +21,8 @@
 import org.apache.avro.io.Encoder;
 
 /**
- * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data to an
- * Avro encoder.
+ * Producer that produces enum values from a dictionary-encoded {@link BaseIntVector}, writes data
+ * to an Avro encoder.
  */
 public class AvroEnumProducer extends BaseAvroProducer {
 
diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
index 07415ebd1b..828ecf241d 100644
--- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
@@ -22,8 +22,10 @@
 import org.apache.avro.io.Encoder;
 
 /**
- * Producer that produces decoded values from a dictionary-encoded {@link BaseIntVector}, writes data to an
- * Avro encoder.
+ * Producer that produces decoded values from a dictionary-encoded {@link BaseIntVector}, writes
+ * data to an Avro encoder.
+ *
+ * @param <T> Type of the underlying dictionary vector
  */
 public class DictionaryDecodingProducer extends BaseAvroProducer {
 
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
index 704edb85e1..6d66ee9d45 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroDataTest.java
@@ -2872,7 +2872,8 @@ public void testWriteDictEnumEncoded() throws Exception {
       // Write an AVRO block using the producer classes
       try (FileOutputStream fos = new FileOutputStream(dataFile)) {
         BinaryEncoder encoder = new EncoderFactory().directBinaryEncoder(fos, null);
-        CompositeAvroProducer producer = ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
+        CompositeAvroProducer producer =
+            ArrowToAvroUtils.createCompositeProducer(vectors, dictionaries);
         for (int row = 0; row < rowCount; row++) {
           producer.produce(encoder);
         }
diff --git a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
index 1309719fd5..d5e0357a8c 100644
--- a/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
+++ b/adapter/avro/src/test/java/org/apache/arrow/adapter/avro/ArrowToAvroSchemaTest.java
@@ -17,7 +17,6 @@ package org.apache.arrow.adapter.avro;
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
 
 import java.util.Arrays;
 import java.util.List;

From a870663e1d41f60fc99575db6aac59ce539ee301 Mon Sep 17 00:00:00 2001
From: Martin Traverse
Date: Mon, 7 Jul 2025 18:59:05 +0100
Subject: [PATCH 11/11] Update doc comment for DictionaryDecodingProducer

---
 .../adapter/avro/producers/DictionaryDecodingProducer.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
index 828ecf241d..afeba08511 100644
--- a/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
+++ b/adapter/avro/src/main/java/org/apache/arrow/adapter/avro/producers/DictionaryDecodingProducer.java
@@ -22,8 +22,8 @@
 import org.apache.avro.io.Encoder;
 
 /**
- * Producer that produces decoded values from a dictionary-encoded {@link BaseIntVector}, writes
- * data to an Avro encoder.
+ * Producer that decodes values from a dictionary-encoded {@link FieldVector}, writes the resulting
+ * values to an Avro encoder.
  *
  * @param <T> Type of the underlying dictionary vector
  */