From 715cb5767e18cee1a0e15a4f43916c9710c3633d Mon Sep 17 00:00:00 2001 From: liyafan82 Date: Thu, 15 Apr 2021 20:21:03 -0700 Subject: [PATCH] ARROW-11999: [Java] Support parallel vector element search with user-specified comparator This is in response to the discussion in https://github.com/apache/arrow/pull/5631#discussion_r339110228 Currently, we only support parallel search with `RangeEqualsVisitor`, which does not support user-specified comparators. We want to provide the functionality in this issue to support wider range of use cases. Closes #9736 from liyafan82/fly_0317_par Authored-by: liyafan82 Signed-off-by: Micah Kornfield --- .../algorithm/search/ParallelSearcher.java | 77 +++++++++++++++++-- .../search/TestParallelSearcher.java | 52 +++++++++++-- 2 files changed, 116 insertions(+), 13 deletions(-) diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java index 39678a17686d2..e93eb2c3dea6a 100644 --- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java +++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/search/ParallelSearcher.java @@ -66,8 +66,19 @@ public ParallelSearcher(V vector, ExecutorService threadPool, int numThreads) { this.numThreads = numThreads; } + private CompletableFuture[] initSearch() { + keyPosition = -1; + final CompletableFuture[] futures = new CompletableFuture[numThreads]; + for (int i = 0; i < futures.length; i++) { + futures[i] = new CompletableFuture<>(); + } + return futures; + } + /** - * Search for the key in the target vector. + * Search for the key in the target vector. The element-wise comparison is based on + * {@link RangeEqualsVisitor}, so there are two possible results for each element-wise + * comparison: equal and un-equal. * @param keyVector the vector containing the search key. * @param keyIndex the index of the search key in the key vector. 
* @return the position of a matched value in the target vector, @@ -80,13 +91,8 @@ public ParallelSearcher(V vector, ExecutorService threadPool, int numThreads) { * @throws InterruptedException if a thread is interrupted. */ public int search(V keyVector, int keyIndex) throws ExecutionException, InterruptedException { - keyPosition = -1; + final CompletableFuture[] futures = initSearch(); final int valueCount = vector.getValueCount(); - final CompletableFuture[] futures = new CompletableFuture[numThreads]; - for (int i = 0; i < futures.length; i++) { - futures[i] = new CompletableFuture<>(); - } - for (int i = 0; i < numThreads; i++) { final int tid = i; threadPool.submit(() -> { @@ -124,4 +130,61 @@ public int search(V keyVector, int keyIndex) throws ExecutionException, Interrup CompletableFuture.allOf(futures).get(); return keyPosition; } + + /** + * Search for the key in the target vector. The element-wise comparison is based on + * {@link VectorValueComparator}, so there are three possible results for each element-wise + * comparison: less than, equal to and greater than. + * @param keyVector the vector containing the search key. + * @param keyIndex the index of the search key in the key vector. + * @param comparator the comparator for comparing the key against vector elements. + * @return the position of a matched value in the target vector, + * or -1 if none is found. Please note that if there are multiple + * matches of the key in the target vector, this method makes no + * guarantees about which instance is returned. + * For an alternative search implementation that always finds the first match of the key, + * see {@link VectorSearcher#linearSearch(ValueVector, VectorValueComparator, ValueVector, int)}. + * @throws ExecutionException if an exception occurs in a thread. + * @throws InterruptedException if a thread is interrupted. 
+   */
+  public int search(
+      V keyVector, int keyIndex, VectorValueComparator comparator) throws ExecutionException, InterruptedException {
+    final CompletableFuture[] futures = initSearch();
+    final int valueCount = vector.getValueCount();
+    for (int i = 0; i < numThreads; i++) {
+      final int tid = i;
+      threadPool.submit(() -> {
+        // multiply in long and narrow only the final quotient, so valueCount * (tid + 1) cannot overflow int
+        int start = (int) (((long) valueCount) * tid / numThreads);
+        int end = (int) (((long) valueCount) * (tid + 1) / numThreads);
+
+        if (start >= end) {
+          // no data assigned to this task.
+          futures[tid].complete(false);
+          return;
+        }
+
+        VectorValueComparator localComparator = comparator.createNew();
+        localComparator.attachVectors(vector, keyVector);
+        for (int pos = start; pos < end; pos++) {
+          if (keyPosition != -1) {
+            // the key has been found by another task
+            futures[tid].complete(false);
+            return;
+          }
+          if (localComparator.compare(pos, keyIndex) == 0) {
+            keyPosition = pos;
+            futures[tid].complete(true);
+            return;
+          }
+        }
+
+        // no match value is found.
+        futures[tid].complete(false);
+      });
+    }
+
+    CompletableFuture.allOf(futures).get();
+    return keyPosition;
+  }
 }
diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java index a01cc1af3bb29..767935aaa4bae 100644 --- a/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java +++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/search/TestParallelSearcher.java @@ -19,10 +19,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import org.apache.arrow.algorithm.sort.DefaultVectorComparators; +import org.apache.arrow.algorithm.sort.VectorValueComparator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -30,24 +35,51 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; /** * Test cases for {@link ParallelSearcher}.
*/ +@RunWith(Parameterized.class) public class TestParallelSearcher { - private static final int THREAD_COUNT = 10; + private enum ComparatorType { + EqualityComparator, + OrderingComparator; + } private static final int VECTOR_LENGTH = 10000; + private final int threadCount; + private BufferAllocator allocator; private ExecutorService threadPool; + private final ComparatorType comparatorType; + + public TestParallelSearcher(ComparatorType comparatorType, int threadCount) { + this.comparatorType = comparatorType; + this.threadCount = threadCount; + } + + @Parameterized.Parameters(name = "comparator type = {0}, thread count = {1}") + public static Collection getComparatorName() { + List params = new ArrayList<>(); + int[] threadCounts = {1, 2, 5, 10, 20, 50}; + for (ComparatorType type : ComparatorType.values()) { + for (int count : threadCounts) { + params.add(new Object[] {type, count}); + } + } + return params; + } + @Before public void prepare() { allocator = new RootAllocator(1024 * 1024); - threadPool = Executors.newFixedThreadPool(THREAD_COUNT); + threadPool = Executors.newFixedThreadPool(threadCount); } @After @@ -63,6 +95,10 @@ public void testParallelIntSearch() throws ExecutionException, InterruptedExcept targetVector.allocateNew(VECTOR_LENGTH); keyVector.allocateNew(VECTOR_LENGTH); + // if we are comparing elements using equality semantics, we do not need a comparator here. + VectorValueComparator comparator = comparatorType == ComparatorType.EqualityComparator ? 
null + : DefaultVectorComparators.createDefaultComparator(targetVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { targetVector.set(i, i); keyVector.set(i, i * 2); @@ -70,9 +106,9 @@ public void testParallelIntSearch() throws ExecutionException, InterruptedExcept targetVector.setValueCount(VECTOR_LENGTH); keyVector.setValueCount(VECTOR_LENGTH); - ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, THREAD_COUNT); + ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount); for (int i = 0; i < VECTOR_LENGTH; i++) { - int pos = searcher.search(keyVector, i); + int pos = comparator == null ? searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator); if (i * 2 < VECTOR_LENGTH) { assertEquals(i * 2, pos); } else { @@ -89,6 +125,10 @@ public void testParallelStringSearch() throws ExecutionException, InterruptedExc targetVector.allocateNew(VECTOR_LENGTH); keyVector.allocateNew(VECTOR_LENGTH); + // if we are comparing elements using equality semantics, we do not need a comparator here. + VectorValueComparator comparator = comparatorType == ComparatorType.EqualityComparator ? null + : DefaultVectorComparators.createDefaultComparator(targetVector); + for (int i = 0; i < VECTOR_LENGTH; i++) { targetVector.setSafe(i, String.valueOf(i).getBytes()); keyVector.setSafe(i, String.valueOf(i * 2).getBytes()); @@ -96,9 +136,9 @@ public void testParallelStringSearch() throws ExecutionException, InterruptedExc targetVector.setValueCount(VECTOR_LENGTH); keyVector.setValueCount(VECTOR_LENGTH); - ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, THREAD_COUNT); + ParallelSearcher searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount); for (int i = 0; i < VECTOR_LENGTH; i++) { - int pos = searcher.search(keyVector, i); + int pos = comparator == null ? 
searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator); if (i * 2 < VECTOR_LENGTH) { assertEquals(i * 2, pos); } else {