Skip to content

Commit

Permalink
ARROW-11999: [Java] Support parallel vector element search with user-…
Browse files Browse the repository at this point in the history
…specified comparator

This is in response to the discussion in #5631 (comment)

Currently, we only support parallel search with `RangeEqualsVisitor`, which does not support user-specified comparators.
We want to provide the functionality in this issue to support wider range of use cases.

Closes #9736 from liyafan82/fly_0317_par

Authored-by: liyafan82 <fan_li_ya@foxmail.com>
Signed-off-by: Micah Kornfield <emkornfield@gmail.com>
  • Loading branch information
liyafan82 authored and emkornfield committed Apr 16, 2021
1 parent b2ceb8f commit 715cb57
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,19 @@ public ParallelSearcher(V vector, ExecutorService threadPool, int numThreads) {
this.numThreads = numThreads;
}

private CompletableFuture<Boolean>[] initSearch() {
keyPosition = -1;
final CompletableFuture<Boolean>[] futures = new CompletableFuture[numThreads];
for (int i = 0; i < futures.length; i++) {
futures[i] = new CompletableFuture<>();
}
return futures;
}

/**
* Search for the key in the target vector.
* Search for the key in the target vector. The element-wise comparison is based on
* {@link RangeEqualsVisitor}, so there are two possible results for each element-wise
* comparison: equal and un-equal.
* @param keyVector the vector containing the search key.
* @param keyIndex the index of the search key in the key vector.
* @return the position of a matched value in the target vector,
Expand All @@ -80,13 +91,8 @@ public ParallelSearcher(V vector, ExecutorService threadPool, int numThreads) {
* @throws InterruptedException if a thread is interrupted.
*/
public int search(V keyVector, int keyIndex) throws ExecutionException, InterruptedException {
keyPosition = -1;
final CompletableFuture<Boolean>[] futures = initSearch();
final int valueCount = vector.getValueCount();
final CompletableFuture<Boolean>[] futures = new CompletableFuture[numThreads];
for (int i = 0; i < futures.length; i++) {
futures[i] = new CompletableFuture<>();
}

for (int i = 0; i < numThreads; i++) {
final int tid = i;
threadPool.submit(() -> {
Expand Down Expand Up @@ -124,4 +130,61 @@ public int search(V keyVector, int keyIndex) throws ExecutionException, Interrup
CompletableFuture.allOf(futures).get();
return keyPosition;
}

/**
* Search for the key in the target vector. The element-wise comparison is based on
* {@link VectorValueComparator}, so there are three possible results for each element-wise
* comparison: less than, equal to and greater than.
* @param keyVector the vector containing the search key.
* @param keyIndex the index of the search key in the key vector.
* @param comparator the comparator for comparing the key against vector elements.
* @return the position of a matched value in the target vector,
* or -1 if none is found. Please note that if there are multiple
* matches of the key in the target vector, this method makes no
* guarantees about which instance is returned.
* For an alternative search implementation that always finds the first match of the key,
* see {@link VectorSearcher#linearSearch(ValueVector, VectorValueComparator, ValueVector, int)}.
* @throws ExecutionException if an exception occurs in a thread.
* @throws InterruptedException if a thread is interrupted.
*/
public int search(
V keyVector, int keyIndex, VectorValueComparator<V> comparator) throws ExecutionException, InterruptedException {
final CompletableFuture<Boolean>[] futures = initSearch();
final int valueCount = vector.getValueCount();
for (int i = 0; i < numThreads; i++) {
final int tid = i;
threadPool.submit(() -> {
// convert to long to avoid overflow
int start = (int) (((long) valueCount) * tid / numThreads);
int end = (int) ((long) valueCount) * (tid + 1) / numThreads;

if (start >= end) {
// no data assigned to this task.
futures[tid].complete(false);
return;
}

VectorValueComparator<V> localComparator = comparator.createNew();
localComparator.attachVectors(vector, keyVector);
for (int pos = start; pos < end; pos++) {
if (keyPosition != -1) {
// the key has been found by another task
futures[tid].complete(false);
return;
}
if (localComparator.compare(pos, keyIndex) == 0) {
keyPosition = pos;
futures[tid].complete(true);
return;
}
}

// no match value is found.
futures[tid].complete(false);
});
}

CompletableFuture.allOf(futures).get();
return keyPosition;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,67 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.arrow.algorithm.sort.DefaultVectorComparators;
import org.apache.arrow.algorithm.sort.VectorValueComparator;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
* Test cases for {@link ParallelSearcher}.
*/
@RunWith(Parameterized.class)
public class TestParallelSearcher {

private static final int THREAD_COUNT = 10;
private enum ComparatorType {
EqualityComparator,
OrderingComparator;
}

private static final int VECTOR_LENGTH = 10000;

private final int threadCount;

private BufferAllocator allocator;

private ExecutorService threadPool;

private final ComparatorType comparatorType;

public TestParallelSearcher(ComparatorType comparatorType, int threadCount) {
this.comparatorType = comparatorType;
this.threadCount = threadCount;
}

@Parameterized.Parameters(name = "comparator type = {0}, thread count = {1}")
public static Collection<Object[]> getComparatorName() {
List<Object[]> params = new ArrayList<>();
int[] threadCounts = {1, 2, 5, 10, 20, 50};
for (ComparatorType type : ComparatorType.values()) {
for (int count : threadCounts) {
params.add(new Object[] {type, count});
}
}
return params;
}

@Before
public void prepare() {
allocator = new RootAllocator(1024 * 1024);
threadPool = Executors.newFixedThreadPool(THREAD_COUNT);
threadPool = Executors.newFixedThreadPool(threadCount);
}

@After
Expand All @@ -63,16 +95,20 @@ public void testParallelIntSearch() throws ExecutionException, InterruptedExcept
targetVector.allocateNew(VECTOR_LENGTH);
keyVector.allocateNew(VECTOR_LENGTH);

// if we are comparing elements using equality semantics, we do not need a comparator here.
VectorValueComparator<IntVector> comparator = comparatorType == ComparatorType.EqualityComparator ? null
: DefaultVectorComparators.createDefaultComparator(targetVector);

for (int i = 0; i < VECTOR_LENGTH; i++) {
targetVector.set(i, i);
keyVector.set(i, i * 2);
}
targetVector.setValueCount(VECTOR_LENGTH);
keyVector.setValueCount(VECTOR_LENGTH);

ParallelSearcher<IntVector> searcher = new ParallelSearcher<>(targetVector, threadPool, THREAD_COUNT);
ParallelSearcher<IntVector> searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount);
for (int i = 0; i < VECTOR_LENGTH; i++) {
int pos = searcher.search(keyVector, i);
int pos = comparator == null ? searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator);
if (i * 2 < VECTOR_LENGTH) {
assertEquals(i * 2, pos);
} else {
Expand All @@ -89,16 +125,20 @@ public void testParallelStringSearch() throws ExecutionException, InterruptedExc
targetVector.allocateNew(VECTOR_LENGTH);
keyVector.allocateNew(VECTOR_LENGTH);

// if we are comparing elements using equality semantics, we do not need a comparator here.
VectorValueComparator<VarCharVector> comparator = comparatorType == ComparatorType.EqualityComparator ? null
: DefaultVectorComparators.createDefaultComparator(targetVector);

for (int i = 0; i < VECTOR_LENGTH; i++) {
targetVector.setSafe(i, String.valueOf(i).getBytes());
keyVector.setSafe(i, String.valueOf(i * 2).getBytes());
}
targetVector.setValueCount(VECTOR_LENGTH);
keyVector.setValueCount(VECTOR_LENGTH);

ParallelSearcher<VarCharVector> searcher = new ParallelSearcher<>(targetVector, threadPool, THREAD_COUNT);
ParallelSearcher<VarCharVector> searcher = new ParallelSearcher<>(targetVector, threadPool, threadCount);
for (int i = 0; i < VECTOR_LENGTH; i++) {
int pos = searcher.search(keyVector, i);
int pos = comparator == null ? searcher.search(keyVector, i) : searcher.search(keyVector, i, comparator);
if (i * 2 < VECTOR_LENGTH) {
assertEquals(i * 2, pos);
} else {
Expand Down

0 comments on commit 715cb57

Please sign in to comment.