Skip to content

Commit 549f52e

Browse files
sgcowelllriggs
authored and committed
DX-85876: Failure in UnionReader.read after DecimalVector promotion to UnionVector (#61)
When a DecimalVector is promoted to a UnionVector via a PromotableWriter, the UnionVector will have the decimal vector in its internal struct vector, but the decimalVector field will not be set. If UnionReader.read is then used to read from the UnionVector, it will fail when it tries to read one of the promoted decimal values, because decimalVector is null and the exact decimal type was not provided. This failure is unnecessary, though: a decimal vector already exists, and the caller simply does not know its exact type - nor should it be required to. The change here is to check for a pre-existing decimal vector in the internal struct when getDecimalVector() is called. If one exists, set the decimalVector field and return it. Otherwise, if none exists, throw the exception.
1 parent d304da5 commit 549f52e

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

vector/src/main/codegen/templates/UnionVector.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.arrow.util.Preconditions;
2424
import org.apache.arrow.vector.BaseValueVector;
2525
import org.apache.arrow.vector.BitVectorHelper;
26+
import org.apache.arrow.vector.DecimalVector;
2627
import org.apache.arrow.vector.FieldVector;
2728
import org.apache.arrow.vector.ValueVector;
2829
import org.apache.arrow.vector.complex.AbstractStructVector;
@@ -306,6 +307,9 @@ public StructVector getStruct() {
306307
307308
public ${name}Vector get${name}Vector(String name) {
308309
if (${uncappedName}Vector == null) {
310+
${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class);
311+
if (${uncappedName}Vector == null) {
312+
throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector");
309313
int vectorCount = internalStruct.size();
310314
${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class);
311315
if (internalStruct.size() > vectorCount) {

vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
import static org.junit.jupiter.api.Assertions.assertNull;
2222
import static org.junit.jupiter.api.Assertions.assertThrows;
2323

24+
25+
import java.math.BigDecimal;
2426
import java.nio.ByteBuffer;
2527
import java.nio.ByteOrder;
2628
import java.nio.charset.StandardCharsets;
2729
import java.util.Objects;
2830
import org.apache.arrow.memory.ArrowBuf;
2931
import org.apache.arrow.memory.BufferAllocator;
32+
import org.apache.arrow.vector.DecimalVector;
3033
import org.apache.arrow.vector.DirtyRootAllocator;
3134
import org.apache.arrow.vector.LargeVarBinaryVector;
3235
import org.apache.arrow.vector.LargeVarCharVector;
@@ -39,14 +42,18 @@
3942
import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter;
4043
import org.apache.arrow.vector.holders.DurationHolder;
4144
import org.apache.arrow.vector.holders.FixedSizeBinaryHolder;
45+
import org.apache.arrow.vector.holders.NullableDecimalHolder;
46+
import org.apache.arrow.vector.holders.NullableIntHolder;
4247
import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder;
4348
import org.apache.arrow.vector.holders.TimeStampMilliTZHolder;
49+
import org.apache.arrow.vector.holders.UnionHolder;
4450
import org.apache.arrow.vector.types.TimeUnit;
4551
import org.apache.arrow.vector.types.Types;
4652
import org.apache.arrow.vector.types.pojo.ArrowType;
4753
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID;
4854
import org.apache.arrow.vector.types.pojo.Field;
4955
import org.apache.arrow.vector.types.pojo.FieldType;
56+
import org.apache.arrow.vector.util.DecimalUtility;
5057
import org.apache.arrow.vector.util.Text;
5158
import org.junit.jupiter.api.AfterEach;
5259
import org.junit.jupiter.api.BeforeEach;
@@ -728,5 +735,41 @@ public void testPromoteLargeVarBinaryHelpersDirect() throws Exception {
728735
assertEquals("row3", new String(Objects.requireNonNull(uv.get(2)), StandardCharsets.UTF_8));
729736
assertEquals("row4", new String(Objects.requireNonNull(uv.get(3)), StandardCharsets.UTF_8));
730737
}
738+
739+
/**
 * Regression test for DX-85876: a value written as a decimal must still be readable through
 * the union reader after an int is written to the same "dec" child via the PromotableWriter,
 * without the caller supplying the exact decimal type.
 */
@Test
740+
public void testPromoteToUnionFromDecimal() throws Exception {
741+
try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator);
742+
final DecimalVector v = container.addOrGet("dec",
743+
FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class);
744+
final PromotableWriter writer = new PromotableWriter(v, container)) {
745+
746+
container.allocateNew();
747+
container.setValueCount(1);
748+
749+
// Position 0 gets a decimal, matching the type the "dec" child was created with.
writer.setPosition(0);
750+
writer.writeDecimal(new BigDecimal("0.1"));
751+
// Position 1 gets an int through the same writer; the "dec" child is read back as a
// UnionVector below, i.e. this write is expected to trigger the promotion under test.
writer.setPosition(1);
752+
writer.writeInt(1);
753+
754+
// NOTE(review): only positions 0 and 1 are written, yet the value count is set to 3 — confirm intent.
container.setValueCount(3);
755+
756+
UnionVector unionVector = (UnionVector) container.getChild("dec");
757+
UnionHolder holder = new UnionHolder();
758+
759+
// Read the promoted decimal value via the holder's reader, passing no decimal type information.
unionVector.get(0, holder);
760+
NullableDecimalHolder decimalHolder = new NullableDecimalHolder();
761+
holder.reader.read(decimalHolder);
762+
763+
assertEquals(1, decimalHolder.isSet);
764+
// The trailing 128 mirrors the third argument of ArrowType.Decimal above (presumably the bit width — confirm).
assertEquals(new BigDecimal("0.1"),
765+
DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128));
766+
767+
// The int written at position 1 must be readable from the same union vector.
unionVector.get(1, holder);
768+
NullableIntHolder intHolder = new NullableIntHolder();
769+
holder.reader.read(intHolder);
770+
771+
assertEquals(1, intHolder.isSet);
772+
assertEquals(1, intHolder.value);
773+
}
731774
}
732775
}

0 commit comments

Comments
 (0)