From 549f52edfcc6bda2ac9f8a4a3226a8ccd532fe13 Mon Sep 17 00:00:00 2001 From: Scott Cowell Date: Wed, 10 Jan 2024 15:56:08 -0800 Subject: [PATCH 1/4] DX-85876: Failure in UnionReader.read after DecimalVector promotion to UnionVector (#61) When a DecimalVector is promoted to a UnionVector via a PromotableWriter, the UnionVector will have the decimal vector in it's internal struct vector, but the decimalVector field will not be set. If UnionReader.read is then used to read from the UnionVector, it will fail when it tries to read one of the promoted decimal values, due to decimalVector being null, and the exact decimal type not being provided. This failure is unnecessary though as we have a pre-existing decimal vector, the caller just does not know the exact type - and it shouldn't be required to. The change here is to check for a pre-existing decimal vector in the internal struct when getDecimalVector() is called. If one exists, set the decimalVector field and return. Otherwise, if none exists, throw the exception. --- .../main/codegen/templates/UnionVector.java | 4 ++ .../complex/impl/TestPromotableWriter.java | 43 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/vector/src/main/codegen/templates/UnionVector.java b/vector/src/main/codegen/templates/UnionVector.java index e0fd0e4644..aaf28c0d5d 100644 --- a/vector/src/main/codegen/templates/UnionVector.java +++ b/vector/src/main/codegen/templates/UnionVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseValueVector; import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -306,6 +307,9 @@ public StructVector getStruct() { public ${name}Vector get${name}Vector(String name) { if (${uncappedName}Vector == null) { + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); int vectorCount = internalStruct.size(); ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class); if (internalStruct.size() > vectorCount) { diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 19b26b6d0e..d18d1b0549 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -21,12 +21,15 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.util.Objects; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.LargeVarBinaryVector; import org.apache.arrow.vector.LargeVarCharVector; @@ -39,14 +42,18 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableDecimalHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.UnionHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.Text; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -728,5 +735,41 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } + + @Test + public void testPromoteToUnionFromDecimal() throws Exception { + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = container.addOrGet("dec", + FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + container.setValueCount(1); + + writer.setPosition(0); + writer.writeDecimal(new BigDecimal("0.1")); + writer.setPosition(1); + writer.writeInt(1); + + container.setValueCount(3); + + UnionVector unionVector = (UnionVector) container.getChild("dec"); + UnionHolder holder = new UnionHolder(); + + unionVector.get(0, holder); + NullableDecimalHolder decimalHolder = new NullableDecimalHolder(); + holder.reader.read(decimalHolder); + + assertEquals(1, decimalHolder.isSet); + assertEquals(new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + + unionVector.get(1, holder); + NullableIntHolder intHolder = new NullableIntHolder(); + holder.reader.read(intHolder); + + assertEquals(1, intHolder.isSet); + assertEquals(1, intHolder.value); + } } } From 69d8bd0f19879bcfacdaeb4315ca1d7c13594c88 Mon Sep 17 00:00:00 2001 From: Logan Riggs Date: Thu, 14 Mar 2024 12:54:53 -0700 Subject: [PATCH 2/4] Fix jar build --- vector/src/main/codegen/templates/UnionVector.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vector/src/main/codegen/templates/UnionVector.java b/vector/src/main/codegen/templates/UnionVector.java index aaf28c0d5d..67efdf60f7 100644 --- a/vector/src/main/codegen/templates/UnionVector.java +++ b/vector/src/main/codegen/templates/UnionVector.java @@ -280,7 +280,10 @@ public StructVector getStruct() { <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> public ${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + } } return ${uncappedName}Vector; } @@ -307,9 +310,6 @@ public StructVector getStruct() { public ${name}Vector get${name}Vector(String name) { if (${uncappedName}Vector == null) { - ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); - if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); int vectorCount = internalStruct.size(); ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class); if (internalStruct.size() > vectorCount) { From 7d5e3ced50aeaf47f38a694bb4628b4c4d8f22e7 Mon Sep 17 00:00:00 2001 From: Logan Riggs Date: Wed, 12 Mar 2025 14:03:56 -0700 Subject: [PATCH 3/4] Fix typo. --- .../apache/arrow/vector/complex/impl/TestPromotableWriter.java | 1 + 1 file changed, 1 insertion(+) diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index d18d1b0549..46ffc07340 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -735,6 +735,7 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8)); assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8)); } + } @Test public void testPromoteToUnionFromDecimal() throws Exception { From 766ba3ab0af3b91fdf48b173ac7625d447896c81 Mon Sep 17 00:00:00 2001 From: Logan Riggs Date: Wed, 12 Mar 2025 16:26:23 -0700 Subject: [PATCH 4/4] Run spotless --- .../complex/impl/TestPromotableWriter.java | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 46ffc07340..a791e55135 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -21,7 +21,6 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; - import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -739,10 +738,12 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception { @Test public void testPromoteToUnionFromDecimal() throws Exception { - try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); - final DecimalVector v = container.addOrGet("dec", - FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); - final PromotableWriter writer = new PromotableWriter(v, container)) { + try (final NonNullableStructVector container = + NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = + container.addOrGet( + "dec", FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { container.allocateNew(); container.setValueCount(1); @@ -762,8 +763,10 @@ public void testPromoteToUnionFromDecimal() throws Exception { holder.reader.read(decimalHolder); assertEquals(1, decimalHolder.isSet); - assertEquals(new BigDecimal("0.1"), - DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + assertEquals( + new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf( + decimalHolder.buffer, 0, decimalHolder.scale, 128)); unionVector.get(1, holder); NullableIntHolder intHolder = new NullableIntHolder();