From 01cdcc33f03a5ca57e429e18f60eab1b64c15fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Finn=20V=C3=B6lkel?= Date: Mon, 20 May 2024 10:56:30 +0200 Subject: [PATCH 1/2] Adding ExtensionWriter and fix for #41728 --- .../templates/AbstractFieldWriter.java | 12 +++++ .../AbstractPromotableFieldWriter.java | 11 ++++ .../main/codegen/templates/BaseWriter.java | 5 ++ .../main/codegen/templates/StructWriters.java | 26 +++++++++ .../templates/UnionFixedSizeListWriter.java | 12 +++++ .../codegen/templates/UnionListWriter.java | 12 +++++ .../main/codegen/templates/UnionWriter.java | 20 +++++++ .../vector/complex/impl/PromotableWriter.java | 5 ++ .../complex/impl/UnionExtensionWriter.java | 54 +++++++++++++++++++ .../vector/complex/writer/FieldWriter.java | 3 +- 10 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index 6c2368117f7..f47dec1a522 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -214,6 +214,18 @@ public MapWriter map(boolean keysSorted) { return null; } + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + fail("Extension"); + return null; + } + + @Override + public ExtensionWriter extension(ArrowType arrowType) { + fail("Extension"); + return null; + } + @Override public MapWriter map(String name, boolean keysSorted) { fail("Map"); diff --git a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java index 59f9fb5b809..5bc956de12c 100644 --- a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java @@ -272,6 +272,11 @@ public MapWriter map() { return getWriter(MinorType.LIST).map(); } + @Override + public ExtensionWriter extension(ArrowType arrowType) { + return getWriter(MinorType.EXTENSIONTYPE).extension(arrowType); + } + @Override public MapWriter map(boolean keysSorted) { return getWriter(MinorType.MAP, new ArrowType.Map(keysSorted)); @@ -292,6 +297,12 @@ public MapWriter map(String name) { return getWriter(MinorType.STRUCT).map(name); } + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + return getWriter(MinorType.EXTENSIONTYPE).extension(name, arrowType); + } + + @Override public MapWriter map(String name, boolean keysSorted) { return getWriter(MinorType.STRUCT).map(name, keysSorted); diff --git a/java/vector/src/main/codegen/templates/BaseWriter.java b/java/vector/src/main/codegen/templates/BaseWriter.java index 35df256b324..20f876b3a27 100644 --- a/java/vector/src/main/codegen/templates/BaseWriter.java +++ b/java/vector/src/main/codegen/templates/BaseWriter.java @@ -61,6 +61,7 @@ public interface StructWriter extends BaseWriter { void copyReaderToField(String name, FieldReader reader); StructWriter struct(String name); + ExtensionWriter extension(String name, ArrowType arrowType); ListWriter list(String name); MapWriter map(String name); MapWriter map(String name, boolean keysSorted); @@ -68,6 +69,9 @@ public interface StructWriter extends BaseWriter { void end(); } + public interface ExtensionWriter extends StructWriter { + } + public interface ListWriter extends BaseWriter { void startList(); void endList(); @@ -75,6 +79,7 @@ public interface ListWriter extends BaseWriter { ListWriter list(); MapWriter map(); MapWriter map(boolean keysSorted); + ExtensionWriter extension(ArrowType arrowType); void copyReader(FieldReader reader); <#list vv.types as type><#list type.minor as minor> diff --git a/java/vector/src/main/codegen/templates/StructWriters.java b/java/vector/src/main/codegen/templates/StructWriters.java index b6dd2b75c52..3cce8ef1ff7 100644 --- a/java/vector/src/main/codegen/templates/StructWriters.java +++ b/java/vector/src/main/codegen/templates/StructWriters.java @@ -73,6 +73,9 @@ public class ${mode}StructWriter extends AbstractFieldWriter { map(child.getName(), arrowType.getKeysSorted()); break; } + case EXTENSIONTYPE: + extension(child.getName(), child.getType()); + break; case DENSEUNION: { FieldType fieldType = new FieldType(addVectorAsNullable, MinorType.DENSEUNION.getType(), null, null); DenseUnionWriter writer = new DenseUnionWriter(container.addOrGet(child.getName(), fieldType, DenseUnionVector.class), getNullableStructWriterFactory()); @@ -132,6 +135,29 @@ public Field getField() { return container.getField(); } + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + String finalName = handleCase(name); + FieldWriter writer = fields.get(finalName); + if(writer == null){ + int vectorCount=container.size(); + FieldType fieldType = new FieldType(addVectorAsNullable, arrowType, null, null); + ExtensionTypeVector vector = container.addOrGet(name, fieldType, ExtensionTypeVector.class); + writer = new PromotableWriter(vector, container, getNullableStructWriterFactory()); + if(vectorCount != container.size()) { + writer.allocate(); + } + writer.setPosition(idx()); + fields.put(finalName, writer); + } else { + if (writer instanceof PromotableWriter) { + // ensure writers are initialized + ((PromotableWriter)writer).getWriter(MinorType.EXTENSIONTYPE, arrowType); + } + } + return (ExtensionWriter) writer; + } + @Override public StructWriter struct(String name) { String finalName = handleCase(name); diff --git a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java index 3436e3a9676..a79512ad299 100644 --- a/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionFixedSizeListWriter.java @@ -192,6 +192,18 @@ public MapWriter map(String name, boolean keysSorted) { return mapWriter; } + @Override + public ExtensionWriter extension(ArrowType arrowType) { + writer.extension(arrowType); + return writer; + } + + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + ExtensionWriter extensionWriter = writer.extension(name, arrowType); + return extensionWriter; + } + @Override public void startList() { int start = vector.startNewValue(idx()); diff --git a/java/vector/src/main/codegen/templates/UnionListWriter.java b/java/vector/src/main/codegen/templates/UnionListWriter.java index 5c0565ee271..807d2a0938b 100644 --- a/java/vector/src/main/codegen/templates/UnionListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionListWriter.java @@ -179,6 +179,18 @@ public MapWriter map(String name, boolean keysSorted) { return mapWriter; } + @Override + public ExtensionWriter extension(ArrowType arrowType) { + writer.extension(arrowType); + return writer; + } + + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + ExtensionWriter extensionWriter = writer.extension(name, arrowType); + return extensionWriter; + } + <#if listName == "LargeList"> @Override public void startList() { diff --git a/java/vector/src/main/codegen/templates/UnionWriter.java b/java/vector/src/main/codegen/templates/UnionWriter.java index 08dbf24324b..606d827c67d 100644 --- a/java/vector/src/main/codegen/templates/UnionWriter.java +++ b/java/vector/src/main/codegen/templates/UnionWriter.java @@ -165,6 +165,10 @@ private MapWriter getMapWriter(ArrowType arrowType) { } return mapWriter; } + + private ExtensionWriter getExtensionWriter(ArrowType arrowType) { + throw new UnsupportedOperationException("ExtensionTypes are not supported yet."); + } public MapWriter asMap(ArrowType arrowType) { data.setType(idx(), MinorType.MAP); @@ -183,6 +187,8 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { return getListWriter(); case MAP: return getMapWriter(arrowType); + case EXTENSIONTYPE: + return getExtensionWriter(arrowType); <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> @@ -402,6 +408,20 @@ public MapWriter map(String name, boolean keysSorted) { return getStructWriter().map(name, keysSorted); } + @Override + public ExtensionWriter extension(ArrowType arrowType) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + getListWriter().setPosition(idx()); + return getListWriter().extension(arrowType); + } + + @Override + public ExtensionWriter extension(String name, ArrowType arrowType) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + getStructWriter().setPosition(idx()); + return getStructWriter().extension(name, arrowType); + } + <#list vv.types as type><#list type.minor as minor> <#assign lowerName = minor.class?uncap_first /> <#if lowerName == "int" ><#assign lowerName = "integer" /> diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 7f724829ef1..139296579dc 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -22,6 +22,7 @@ import java.util.Locale; import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.ValueVector; @@ -224,6 +225,9 @@ private void setWriter(ValueVector v) { case UNION: writer = new UnionWriter((UnionVector) vector, nullableStructWriterFactory); break; + case EXTENSIONTYPE: + writer = new UnionExtensionWriter((ExtensionTypeVector) vector); + break; default: writer = type.getNewFieldWriter(vector); break; @@ -255,6 +259,7 @@ private boolean requiresArrowType(MinorType type) { type == MinorType.MAP || type == MinorType.DURATION || type == MinorType.FIXEDSIZEBINARY || + type == MinorType.EXTENSIONTYPE || (type.name().startsWith("TIMESTAMP") && type.name().endsWith("TZ")); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java new file mode 100644 index 00000000000..9e7fac599be --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.complex.impl; + +import org.apache.arrow.vector.ExtensionTypeVector; +import org.apache.arrow.vector.types.pojo.Field; + +public class UnionExtensionWriter extends AbstractFieldWriter { + protected ExtensionTypeVector vector; + + public UnionExtensionWriter(ExtensionTypeVector vector) { + this.vector = vector; + } + + @Override + public void allocate() { + vector.allocateNew(); + } + + @Override + public void clear() { + vector.clear(); + } + + @Override + public int getValueCapacity() { + return vector.getValueCapacity(); + } + + @Override + public Field getField() { + return vector.getField(); + } + + @Override + public void close() throws Exception { + vector.close(); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/writer/FieldWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/writer/FieldWriter.java index a3cb7108a11..f4392635892 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/writer/FieldWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/writer/FieldWriter.java @@ -17,6 +17,7 @@ package org.apache.arrow.vector.complex.writer; +import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ScalarWriter; @@ -26,7 +27,7 @@ * Composite of all writer types. Writers are convenience classes for incrementally * adding values to {@linkplain org.apache.arrow.vector.ValueVector}s. */ -public interface FieldWriter extends StructWriter, ListWriter, MapWriter, ScalarWriter { +public interface FieldWriter extends StructWriter, ListWriter, MapWriter, ScalarWriter, ExtensionWriter { void allocate(); void clear(); From 7cb7fd526a85350dee237067cb660f2e64eb1dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Finn=20V=C3=B6lkel?= Date: Fri, 17 May 2024 22:22:16 +0200 Subject: [PATCH 2/2] Adding test for #41728 --- .../apache/arrow/vector/TestStructVector.java | 47 +++++++++- .../vector/types/pojo/TestExtensionType.java | 90 +++++++++++++------ 2 files changed, 110 insertions(+), 27 deletions(-) diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java index 68f5e14dabb..10c47d06b16 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java @@ -17,12 +17,17 @@ package org.apache.arrow.vector; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -35,9 +40,11 @@ import org.apache.arrow.vector.holders.ComplexHolder; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.Struct; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.TestExtensionType; import org.apache.arrow.vector.util.TransferPair; import org.junit.After; import org.junit.Assert; @@ -82,6 +89,44 @@ public void testMakeTransferPair() { } } + @Test + public void testStructVectorWithExtensionTypes() { + TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType(); + Field uuidField = new Field("struct_child", FieldType.nullable(uuidType), null); + Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField)); + StructVector s1 = new StructVector(structField, allocator, null); + StructVector s2 = (StructVector) structField.createVector(allocator); + s1.close(); + s2.close(); + } + + @Test + public void testStructVectorTransferPairWithExtensionType() { + TestExtensionType.UuidType uuidType = new TestExtensionType.UuidType(); + Field uuidField = new Field("uuid_child", FieldType.nullable(uuidType), null); + Field structField = new Field("struct", FieldType.nullable(new ArrowType.Struct()), List.of(uuidField)); + + StructVector s1 = (StructVector) structField.createVector(allocator); + TestExtensionType.UuidVector uuidVector = + s1.addOrGet("uuid_child", FieldType.nullable(uuidType), TestExtensionType.UuidVector.class); + s1.setValueCount(1); + uuidVector.set(0, new UUID(1, 2)); + s1.setIndexDefined(0); + + TransferPair tp = s1.getTransferPair(structField, allocator); + final StructVector toVector = (StructVector) tp.getTo(); + assertEquals(s1.getField(), toVector.getField()); + assertEquals(s1.getField().getChildren().get(0), toVector.getField().getChildren().get(0)); + // also fails but probably another issue + // assertEquals(s1.getValueCount(), toVector.getValueCount()); + // assertEquals(s1, toVector); + + s1.close(); + toVector.close(); + } + + + @Test public void testAllocateAfterReAlloc() throws Exception { Map metadata = new HashMap<>(); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index 872b2f3934b..46fddbbc6ba 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -32,6 +32,7 @@ import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.Collections; +import java.util.Map; import java.util.UUID; import org.apache.arrow.memory.BufferAllocator; @@ -41,6 +42,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.FixedSizeBinaryVector; import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; @@ -49,6 +51,7 @@ import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; +import org.apache.arrow.vector.util.TransferPair; import org.apache.arrow.vector.util.VectorBatchAppender; import org.apache.arrow.vector.validate.ValidateVectorVisitor; import org.junit.Assert; @@ -85,21 +88,21 @@ public void roundtripUuid() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(root.getSchema(), readerRoot.getSchema()); + assertEquals(root.getSchema(), readerRoot.getSchema()); final Field field = readerRoot.getSchema().getFields().get(0); final UuidType expectedType = new UuidType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { - Assert.assertEquals(vector.getObject(i), deserialized.getObject(i)); + assertEquals(vector.getObject(i), deserialized.getObject(i)); } } } @@ -138,23 +141,23 @@ public void readUnderlyingType() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(1, readerRoot.getSchema().getFields().size()); - Assert.assertEquals("a", readerRoot.getSchema().getFields().get(0).getName()); - Assert.assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary); - Assert.assertEquals(16, + assertEquals(1, readerRoot.getSchema().getFields().size()); + assertEquals("a", readerRoot.getSchema().getFields().get(0).getName()); + assertTrue(readerRoot.getSchema().getFields().get(0).getType() instanceof ArrowType.FixedSizeBinary); + assertEquals(16, ((ArrowType.FixedSizeBinary) readerRoot.getSchema().getFields().get(0).getType()).getByteWidth()); final Field field = readerRoot.getSchema().getFields().get(0); final UuidType expectedType = new UuidType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final FixedSizeBinaryVector deserialized = (FixedSizeBinaryVector) readerRoot.getFieldVectors().get(0); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { final UUID uuid = vector.getObject(i); final ByteBuffer bb = ByteBuffer.allocate(16); @@ -210,26 +213,26 @@ public void roundtripLocation() throws IOException { final ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { reader.loadNextBatch(); final VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); - Assert.assertEquals(root.getSchema(), readerRoot.getSchema()); + assertEquals(root.getSchema(), readerRoot.getSchema()); final Field field = readerRoot.getSchema().getFields().get(0); final LocationType expectedType = new LocationType(); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME), expectedType.extensionName()); - Assert.assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), + assertEquals(field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA), expectedType.serialize()); final ExtensionTypeVector deserialized = (ExtensionTypeVector) readerRoot.getFieldVectors().get(0); - Assert.assertTrue(deserialized instanceof LocationVector); - Assert.assertEquals("location", deserialized.getName()); + assertTrue(deserialized instanceof LocationVector); + assertEquals("location", deserialized.getName()); StructVector deserStruct = (StructVector) deserialized.getUnderlyingVector(); Assert.assertNotNull(deserStruct.getChild("Latitude")); Assert.assertNotNull(deserStruct.getChild("Longitude")); - Assert.assertEquals(vector.getValueCount(), deserialized.getValueCount()); + assertEquals(vector.getValueCount(), deserialized.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { - Assert.assertEquals(vector.isNull(i), deserialized.isNull(i)); + assertEquals(vector.isNull(i), deserialized.isNull(i)); if (!vector.isNull(i)) { - Assert.assertEquals(vector.getObject(i), deserialized.getObject(i)); + assertEquals(vector.getObject(i), deserialized.getObject(i)); } } } @@ -278,11 +281,11 @@ public void testVectorCompare() { } } - static class UuidType extends ExtensionType { + public static class UuidType extends ExtensionType { @Override public ArrowType storageType() { - return new ArrowType.FixedSizeBinary(16); + return new FixedSizeBinary(16); } @Override @@ -314,10 +317,17 @@ public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocato } } - static class UuidVector extends ExtensionTypeVector { + public static class UuidVector extends ExtensionTypeVector { + private final Field field; public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) { super(name, allocator, underlyingVector); + this.field = new Field(name, FieldType.nullable(new UuidType()), null); + } + + @Override + public Field getField() { + return field; } @Override @@ -342,6 +352,34 @@ public void set(int index, UUID uuid) { bb.putLong(uuid.getLeastSignificantBits()); getUnderlyingVector().set(index, bb.array()); } + + @Override + public TransferPair makeTransferPair(ValueVector to) { + ValueVector targetUnderlyingVector = ((UuidVector) to).getUnderlyingVector(); + TransferPair tp = getUnderlyingVector().makeTransferPair(targetUnderlyingVector); + + return new TransferPair() { + @Override + public void transfer() { + tp.transfer(); + } + + @Override + public void splitAndTransfer(int startIndex, int length) { + tp.splitAndTransfer(startIndex, length); + } + + @Override + public ValueVector getTo() { + return to; + } + + @Override + public void copyValueSafe(int fromIndex, int toIndex) { + tp.copyValueSafe(fromIndex, toIndex); + } + }; + } } static class LocationType extends ExtensionType { @@ -407,7 +445,7 @@ public int hashCode(int index, ArrowBufHasher hasher) { } @Override - public java.util.Map getObject(int index) { + public Map getObject(int index) { return getUnderlyingVector().getObject(index); }