diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseBinaryVector.java new file mode 100644 index 0000000000000..9ddbd10970fd9 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseBinaryVector.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import org.apache.arrow.vector.dictionary.ByteArrayWrapper; + +/** + * Interface for VarBinaryVector and FixedSizeBinaryVector, exposing values as {@link ByteArrayWrapper} keys. + */ +public interface BaseBinaryVector { + + /** + * Get the {@link ByteArrayWrapper} wrapping the byte array at the given index, or null if that slot is null. 
+ * @param index index of object to get + */ + ByteArrayWrapper getByteArrayWrapper(int index); +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java index 61bd57c135697..862336f90acd1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java @@ -22,6 +22,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.impl.FixedSizeBinaryReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.dictionary.ByteArrayWrapper; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder; import org.apache.arrow.vector.types.Types.MinorType; @@ -37,7 +38,7 @@ * binary values which could be null. A validity buffer (bit vector) is * maintained to track which elements in the vector are null. 
 */ -public class FixedSizeBinaryVector extends BaseFixedWidthVector { +public class FixedSizeBinaryVector extends BaseFixedWidthVector implements BaseBinaryVector { private final int byteWidth; private final FieldReader reader; @@ -363,6 +364,15 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((FixedSizeBinaryVector) to); } + @Override + public ByteArrayWrapper getByteArrayWrapper(int index) { + if (isNull(index)) { + return null; + } else { + return new ByteArrayWrapper(getObject(index)); + } + } + private class TransferImpl implements TransferPair { FixedSizeBinaryVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java index bd76f3cc03ff7..c8232cfe4df0b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VarBinaryVector.java @@ -22,6 +22,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.impl.VarBinaryReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.dictionary.ByteArrayWrapper; import org.apache.arrow.vector.holders.NullableVarBinaryHolder; import org.apache.arrow.vector.holders.VarBinaryHolder; import org.apache.arrow.vector.types.Types.MinorType; @@ -34,7 +35,7 @@ * values which could be NULL. A validity buffer (bit vector) is maintained * to track which elements in the vector are null. 
 */ -public class VarBinaryVector extends BaseVariableWidthVector { +public class VarBinaryVector extends BaseVariableWidthVector implements BaseBinaryVector { private final FieldReader reader; /** @@ -279,6 +280,15 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((VarBinaryVector) to); } + @Override + public ByteArrayWrapper getByteArrayWrapper(int index) { + if (isNull(index)) { + return null; + } else { + return new ByteArrayWrapper(getObject(index)); + } + } + private class TransferImpl implements TransferPair { VarBinaryVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ByteArrayWrapper.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ByteArrayWrapper.java new file mode 100644 index 0000000000000..bcfac3983f331 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/ByteArrayWrapper.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.dictionary; + +import java.util.Arrays; + +/** + * Wrapper class for byte array providing content-based equals and hashCode, so it can serve as a hash map key. + */ +public class ByteArrayWrapper { + private final byte[] data; + + /** + * Constructs a new instance around data (stored without copying); throws NullPointerException if data is null. 
+ */ + public ByteArrayWrapper(byte[] data) { + if (data == null) { + throw new NullPointerException(); + } + + this.data = data; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ByteArrayWrapper)) { + return false; + } + + return Arrays.equals(data, ((ByteArrayWrapper)other).data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java index a28ea5b954d05..2655718164481 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java @@ -17,6 +17,7 @@ package org.apache.arrow.vector.dictionary; +import org.apache.arrow.vector.BaseBinaryVector; import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; @@ -45,9 +46,13 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { validateType(vector.getMinorType()); // load dictionary values into a hashmap for lookup DictionaryEncodeHashMap lookUps = new DictionaryEncodeHashMap<>(dictionary.getVector().getValueCount()); + + boolean binaryType = isBinaryType(vector.getMinorType()); + for (int i = 0; i < dictionary.getVector().getValueCount(); i++) { - // for primitive array types we need a wrapper that implements equals and hashcode appropriately - lookUps.put(dictionary.getVector().getObject(i), i); + Object key = binaryType ? 
((BaseBinaryVector) dictionary.getVector()).getByteArrayWrapper(i) : + dictionary.getVector().getObject(i); + lookUps.put(key, i); } Field valueField = vector.getField(); @@ -68,7 +73,7 @@ public static ValueVector encode(ValueVector vector, Dictionary dictionary) { int count = vector.getValueCount(); for (int i = 0; i < count; i++) { - Object value = vector.getObject(i); + Object value = binaryType ? ((BaseBinaryVector) vector).getByteArrayWrapper(i) : vector.getObject(i); if (value != null) { // if it's null leave it null // note: this may fail if value was not included in the dictionary int encoded = lookUps.get(value); @@ -114,11 +119,18 @@ public static ValueVector decode(ValueVector indices, Dictionary dictionary) { return decoded; } - private static void validateType(MinorType type) { - // byte arrays don't work as keys in our dictionary map - we could wrap them with something to - // implement equals and hashcode if we want that functionality - if (type == MinorType.VARBINARY || type == MinorType.FIXEDSIZEBINARY || type == MinorType.UNION) { - throw new IllegalArgumentException("Dictionary encoding for complex types not implemented: type " + type); + private static boolean isBinaryType(MinorType type) { + if (type == MinorType.VARBINARY || type == MinorType.FIXEDSIZEBINARY) { + return true; } + return false; + } + + private static void validateType(MinorType type) { + if (type == MinorType.UNION) { + throw new IllegalArgumentException( + "Dictionary encoding for complex types not implemented: type " + type); + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index 36e763c9e74ed..0d2bce9f3f163 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -17,10 +17,13 @@ package org.apache.arrow.vector; +import static 
org.apache.arrow.vector.TestUtils.newVarBinaryVector; import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.complex.ListVector; @@ -278,4 +281,51 @@ public void testEncodeStruct() { } } + @Test + public void testEncodeBinaryVector() { + // Create the vector to encode and the vector that backs the dictionary + try (final VarBinaryVector vector = newVarBinaryVector("foo", allocator); + final VarBinaryVector dictionaryVector = newVarBinaryVector("dict", allocator);) { + vector.allocateNew(512, 5); + + // set some values + vector.setSafe(0, zero, 0, zero.length); + vector.setSafe(1, one, 0, one.length); + vector.setSafe(2, one, 0, one.length); + vector.setSafe(3, two, 0, two.length); + vector.setSafe(4, zero, 0, zero.length); + vector.setValueCount(5); + + // set some dictionary values + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, zero.length); + dictionaryVector.setSafe(1, one, 0, one.length); + dictionaryVector.setSafe(2, two, 0, two.length); + dictionaryVector.setValueCount(3); + + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + + try (final ValueVector encoded = DictionaryEncoder.encode(vector, dictionary)) { + // verify indices + assertEquals(IntVector.class, encoded.getClass()); + + IntVector index = ((IntVector)encoded); + assertEquals(5, index.getValueCount()); + assertEquals(0, index.get(0)); + assertEquals(1, index.get(1)); + assertEquals(1, index.get(2)); + assertEquals(2, index.get(3)); + assertEquals(0, index.get(4)); + + // now run through the decoder and verify we get the original back + try (VarBinaryVector decoded = (VarBinaryVector) DictionaryEncoder.decode(encoded, dictionary)) { + assertEquals(vector.getClass(), decoded.getClass()); + 
assertEquals(vector.getValueCount(), decoded.getValueCount()); + for (int i = 0; i < 5; i++) { + assertTrue(Arrays.equals(vector.getObject(i), decoded.getObject(i))); + } + } + } + } + } }