From 2d3f3df930c1245d1654f9d8506dfb306076be52 Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 29 Sep 2023 06:11:12 -0700 Subject: [PATCH] GH-37702: [Java] Add vector validation consistent with C++ (#37942) ### Rationale for this change Make vector validation code more consistent with C++. Add missing checks and have the entry point be the same so that the code is easier to read/write when working with both languages. ### What changes are included in this PR? Make vector validation more consistent with Array::Validate() in C++: * Add validate() and validateFull() instance methods to vectors. * Validate that VarCharVector and LargeVarCharVector contents are valid UTF-8. * Validate that DecimalVector and Decimal256Vector contents fit within the supplied precision and scale. * Validate that NullVectors contain only nulls. * Validate that FixedSizeBinaryVector values have the correct length. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * Closes: #37702 Authored-by: James Duong Signed-off-by: David Li --- .../arrow/vector/BaseFixedWidthVector.java | 7 ++ .../vector/BaseLargeVariableWidthVector.java | 7 ++ .../arrow/vector/BaseVariableWidthVector.java | 7 ++ .../apache/arrow/vector/Decimal256Vector.java | 13 ++++ .../apache/arrow/vector/DecimalVector.java | 13 ++++ .../arrow/vector/FixedSizeBinaryVector.java | 13 ++++ .../arrow/vector/LargeVarCharVector.java | 12 +++ .../org/apache/arrow/vector/ValueVector.java | 9 +++ .../apache/arrow/vector/VarCharVector.java | 12 +++ .../arrow/vector/util/DecimalUtility.java | 10 ++- .../org/apache/arrow/vector/util/Text.java | 50 ++++++++++--- .../validate/ValidateVectorDataVisitor.java | 5 ++ .../vector/validate/TestValidateVector.java | 14 ++++ .../validate/TestValidateVectorFull.java | 74 +++++++++++++++++++ 14 files changed, 233 insertions(+), 13 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java 
b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index 223ae9aa8cb1c..04a038a0b5dfd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java @@ -550,6 +550,13 @@ private void setReaderAndWriterIndex() { } } + /** + * Validate the scalar values held by this vector. + */ + public void validateScalars() { + // No validation by default. + } + /** * Construct a transfer pair of this vector and another vector of same type. * @param ref name of the target vector diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java index 90694db830cd6..4d5a8a5119c53 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseLargeVariableWidthVector.java @@ -643,6 +643,13 @@ public ArrowBuf[] getBuffers(boolean clear) { return buffers; } + /** + * Validate the scalar values held by this vector. + */ + public void validateScalars() { + // No validation by default. + } + /** * Construct a transfer pair of this vector and another vector of same type. * @param ref name of the target vector diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index 2a89590bf8440..d7f5ff05a935d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -685,6 +685,13 @@ public ArrowBuf[] getBuffers(boolean clear) { return buffers; } + /** + * Validate the scalar values held by this vector. + */ + public void validateScalars() { + // No validation by default. 
+ } + /** * Construct a transfer pair of this vector and another vector of same type. * @param ref name of the target vector diff --git a/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java index 70a895ff40496..79a9badc3955d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/Decimal256Vector.java @@ -35,6 +35,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.validate.ValidateUtil; /** @@ -527,6 +528,18 @@ public void setSafe(int index, int isSet, long start, ArrowBuf buffer) { set(index, isSet, start, buffer); } + @Override + public void validateScalars() { + for (int i = 0; i < getValueCount(); ++i) { + BigDecimal value = getObject(i); + if (value != null) { + ValidateUtil.validateOrThrow(DecimalUtility.checkPrecisionAndScaleNoThrow(value, getPrecision(), getScale()), + "Invalid value for Decimal256Vector at position " + i + ". 
Value does not fit in precision " + + getPrecision() + " and scale " + getScale() + "."); + } + } + } + /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 6a3ec60afc52e..d1a3bfc3afb10 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -35,6 +35,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.validate.ValidateUtil; /** * DecimalVector implements a fixed width vector (16 bytes) of @@ -526,6 +527,18 @@ public void setSafe(int index, int isSet, long start, ArrowBuf buffer) { set(index, isSet, start, buffer); } + @Override + public void validateScalars() { + for (int i = 0; i < getValueCount(); ++i) { + BigDecimal value = getObject(i); + if (value != null) { + ValidateUtil.validateOrThrow(DecimalUtility.checkPrecisionAndScaleNoThrow(value, getPrecision(), getScale()), + "Invalid value for DecimalVector at position " + i + ". 
Value does not fit in precision " + + getPrecision() + " and scale " + getScale() + "."); + } + } + } + /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java index 3ce2bb77ccc55..967d560d78dea 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java @@ -31,6 +31,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.validate.ValidateUtil; /** * FixedSizeBinaryVector implements a fixed width vector of @@ -320,6 +321,18 @@ public static byte[] get(final ArrowBuf buffer, final int index, final int byteW return dst; } + @Override + public void validateScalars() { + for (int i = 0; i < getValueCount(); ++i) { + byte[] value = get(i); + if (value != null) { + ValidateUtil.validateOrThrow(value.length == byteWidth, + "Invalid value for FixedSizeBinaryVector at position " + i + ". 
The length was " + + value.length + " but the length of each element should be " + byteWidth + "."); + } + } + } + /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java index 1f8d9b7d3a85c..e9472c9f2c71e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/LargeVarCharVector.java @@ -27,6 +27,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.Text; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.validate.ValidateUtil; /** * LargeVarCharVector implements a variable width vector of VARCHAR @@ -261,6 +262,17 @@ public void setSafe(int index, Text text) { setSafe(index, text.getBytes(), 0, text.getLength()); } + @Override + public void validateScalars() { + for (int i = 0; i < getValueCount(); ++i) { + byte[] value = get(i); + if (value != null) { + ValidateUtil.validateOrThrow(Text.validateUTF8NoThrow(value), + "Non-UTF-8 data in LargeVarCharVector at position " + i + "."); + } + } + } + /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index aa29c29314e33..462b512c65436 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.util.ValueVectorUtility; /** * An abstraction that is used to store a sequence of values in an individual column.
@@ -282,4 +283,12 @@ public interface ValueVector extends Closeable, Iterable { * @return the name of the vector. */ String getName(); + + default void validate() { + ValueVectorUtility.validate(this); + } + + default void validateFull() { + ValueVectorUtility.validateFull(this); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java index bc5c68b29f310..2c83893819a1e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarCharVector.java @@ -29,6 +29,7 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.Text; import org.apache.arrow.vector.util.TransferPair; +import org.apache.arrow.vector.validate.ValidateUtil; /** * VarCharVector implements a variable width vector of VARCHAR @@ -261,6 +262,17 @@ public void setSafe(int index, Text text) { setSafe(index, text.getBytes(), 0, text.getLength()); } + @Override + public void validateScalars() { + for (int i = 0; i < getValueCount(); ++i) { + byte[] value = get(i); + if (value != null) { + ValidateUtil.validateOrThrow(Text.validateUTF8NoThrow(value), + "Non-UTF-8 data in VarCharVector at position " + i + "."); + } + } + } + /*----------------------------------------------------------------* | | | vector transfer | diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 137ac746f4aee..a81169b8f7d73 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -95,11 +95,19 @@ public static boolean checkPrecisionAndScale(BigDecimal value, int vectorPrecisi } if (value.precision() > vectorPrecision) { throw new UnsupportedOperationException("BigDecimal precision can not be 
greater than that in the Arrow " + - "vector: " + value.precision() + " > " + vectorPrecision); + "vector: " + value.precision() + " > " + vectorPrecision); } return true; } + /** + * Check that the BigDecimal scale equals the vectorScale and that the BigDecimal precision is + * less than or equal to the vectorPrecision. Return true if so, otherwise return false. + */ + public static boolean checkPrecisionAndScaleNoThrow(BigDecimal value, int vectorPrecision, int vectorScale) { + return value.scale() == vectorScale && value.precision() <= vectorPrecision; + } + /** * Check that the decimal scale equals the vectorScale and that the decimal precision is * less than or equal to the vectorPrecision. If not, then an UnsupportedOperationException is diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java index b479305c6e39b..778af0ca956df 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Text.java @@ -30,6 +30,7 @@ import java.text.CharacterIterator; import java.text.StringCharacterIterator; import java.util.Arrays; +import java.util.Optional; import com.fasterxml.jackson.core.JsonGenerationException; import com.fasterxml.jackson.core.JsonGenerator; @@ -466,6 +467,16 @@ public static ByteBuffer encode(String string, boolean replace) private static final int TRAIL_BYTE = 2; + /** + * Check if a byte array contains valid utf-8. + * + * @param utf8 byte array + * @return true if the input is valid UTF-8. False otherwise. + */ + public static boolean validateUTF8NoThrow(byte[] utf8) { + return !validateUTF8Internal(utf8, 0, utf8.length).isPresent(); + } + /** * Check if a byte array contains valid utf-8.
* @@ -484,8 +495,22 @@ public static void validateUTF8(byte[] utf8) throws MalformedInputException { * @param len the length of the byte sequence * @throws MalformedInputException if the byte array contains invalid bytes */ - public static void validateUTF8(byte[] utf8, int start, int len) - throws MalformedInputException { + public static void validateUTF8(byte[] utf8, int start, int len) throws MalformedInputException { + Optional<Integer> result = validateUTF8Internal(utf8, start, len); + if (result.isPresent()) { + throw new MalformedInputException(result.get()); + } + } + + /** + * Check to see if a byte array is valid utf-8. + * + * @param utf8 the array of bytes + * @param start the offset of the first byte in the array + * @param len the length of the byte sequence + * @return the position where a malformed byte occurred or Optional.empty() if the byte array was valid UTF-8. + */ + private static Optional<Integer> validateUTF8Internal(byte[] utf8, int start, int len) { int count = start; int leadByte = 0; int length = 0; @@ -501,51 +526,51 @@ public static void validateUTF8(byte[] utf8, int start, int len) switch (length) { case 0: // check for ASCII if (leadByte > 0x7F) { - throw new MalformedInputException(count); + return Optional.of(count); } break; case 1: if (leadByte < 0xC2 || leadByte > 0xDF) { - throw new MalformedInputException(count); + return Optional.of(count); } state = TRAIL_BYTE_1; break; case 2: if (leadByte < 0xE0 || leadByte > 0xEF) { - throw new MalformedInputException(count); + return Optional.of(count); } state = TRAIL_BYTE_1; break; case 3: if (leadByte < 0xF0 || leadByte > 0xF4) { - throw new MalformedInputException(count); + return Optional.of(count); } state = TRAIL_BYTE_1; break; default: // too long!
Longest valid UTF-8 is 4 bytes (lead + three) // or if < 0 we got a trail byte in the lead byte position - throw new MalformedInputException(count); + return Optional.of(count); } // switch (length) break; case TRAIL_BYTE_1: if (leadByte == 0xF0 && aByte < 0x90) { - throw new MalformedInputException(count); + return Optional.of(count); } if (leadByte == 0xF4 && aByte > 0x8F) { - throw new MalformedInputException(count); + return Optional.of(count); } if (leadByte == 0xE0 && aByte < 0xA0) { - throw new MalformedInputException(count); + return Optional.of(count); } if (leadByte == 0xED && aByte > 0x9F) { - throw new MalformedInputException(count); + return Optional.of(count); } // falls through to regular trail-byte test!! case TRAIL_BYTE: if (aByte < 0x80 || aByte > 0xBF) { - throw new MalformedInputException(count); + return Optional.of(count); } if (--length == 0) { state = LEAD_BYTE; @@ -558,6 +583,7 @@ public static void validateUTF8(byte[] utf8, int start, int len) } // switch (state) count++; } + return Optional.empty(); } /** diff --git a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorDataVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorDataVisitor.java index cdeb4f1eaa1ca..6d33be7a0dbac 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorDataVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/validate/ValidateVectorDataVisitor.java @@ -85,18 +85,21 @@ private void validateTypeBuffer(ArrowBuf typeBuf, int valueCount) { @Override public Void visit(BaseFixedWidthVector vector, Void value) { + vector.validateScalars(); return null; } @Override public Void visit(BaseVariableWidthVector vector, Void value) { validateOffsetBuffer(vector, vector.getValueCount()); + vector.validateScalars(); return null; } @Override public Void visit(BaseLargeVariableWidthVector vector, Void value) { validateLargeOffsetBuffer(vector, vector.getValueCount()); + 
vector.validateScalars(); return null; } @@ -169,6 +172,8 @@ public Void visit(DenseUnionVector vector, Void value) { @Override public Void visit(NullVector vector, Void value) { + ValidateUtil.validateOrThrow(vector.getNullCount() == vector.getValueCount(), + "NullVector should have only null entries."); return null; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVector.java b/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVector.java index 2354b281ed41d..20492036dab99 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVector.java @@ -251,6 +251,20 @@ public void testDenseUnionVector() { } } + @Test + public void testBaseFixedWidthVectorInstanceMethod() { + try (final IntVector vector = new IntVector("v", allocator)) { + vector.validate(); + setVector(vector, 1, 2, 3); + vector.validate(); + + vector.getDataBuffer().capacity(0); + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + () -> vector.validate()); + assertTrue(e.getMessage().contains("Not enough capacity for fixed width data buffer")); + } + } + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { writer.start(); writer.integer("f0").writeInt(value1); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVectorFull.java b/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVectorFull.java index 4241a0d9cff93..ca71a622bb8ea 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVectorFull.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/validate/TestValidateVectorFull.java @@ -23,11 +23,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.nio.charset.StandardCharsets; import 
java.util.Arrays; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.LargeVarCharVector; @@ -231,4 +234,75 @@ public void testDenseUnionVector() { assertTrue(e.getMessage().contains("Dense union vector offset exceeds sub-vector boundary")); } } + + @Test + public void testBaseVariableWidthVectorInstanceMethod() { + try (final VarCharVector vector = new VarCharVector("v", allocator)) { + vector.validateFull(); + setVector(vector, "aaa", "bbb", "ccc"); + vector.validateFull(); + + ArrowBuf offsetBuf = vector.getOffsetBuffer(); + offsetBuf.setInt(0, 100); + offsetBuf.setInt(4, 50); + + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + vector::validateFull); + assertTrue(e.getMessage().contains("The values in positions 0 and 1 of the offset buffer are decreasing")); + } + } + + @Test + public void testValidateVarCharUTF8() { + try (final VarCharVector vector = new VarCharVector("v", allocator)) { + vector.validateFull(); + setVector(vector, "aaa".getBytes(StandardCharsets.UTF_8), "bbb".getBytes(StandardCharsets.UTF_8), + new byte[] {(byte) 0xFF, (byte) 0xFE}); + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + vector::validateFull); + assertTrue(e.getMessage().contains("UTF")); + } + } + + @Test + public void testValidateLargeVarCharUTF8() { + try (final LargeVarCharVector vector = new LargeVarCharVector("v", allocator)) { + vector.validateFull(); + setVector(vector, "aaa".getBytes(StandardCharsets.UTF_8), "bbb".getBytes(StandardCharsets.UTF_8), + new byte[] {(byte) 0xFF, (byte) 0xFE}); + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + vector::validateFull); + 
assertTrue(e.getMessage().contains("UTF")); + } + } + + @Test + public void testValidateDecimal() { + try (final DecimalVector vector = new DecimalVector(Field.nullable("v", + new ArrowType.Decimal(2, 0, DecimalVector.TYPE_WIDTH * 8)), allocator)) { + vector.validateFull(); + setVector(vector, 1L); + vector.validateFull(); + vector.clear(); + setVector(vector, Long.MAX_VALUE); + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + vector::validateFull); + assertTrue(e.getMessage().contains("Decimal")); + } + } + + @Test + public void testValidateDecimal256() { + try (final Decimal256Vector vector = new Decimal256Vector(Field.nullable("v", + new ArrowType.Decimal(2, 0, Decimal256Vector.TYPE_WIDTH * 8)), allocator)) { + vector.validateFull(); + setVector(vector, 1L); + vector.validateFull(); + vector.clear(); + setVector(vector, Long.MAX_VALUE); + ValidateUtil.ValidateException e = assertThrows(ValidateUtil.ValidateException.class, + vector::validateFull); + assertTrue(e.getMessage().contains("Decimal")); + } + } }