Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-thread searchability to OnHeapHnswGraph #12257

Merged
merged 1 commit into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#12257: Create OnHeapHnswGraphSearcher to let OnHeapHnswGraph to be searched in a thread-safety manner. (Patrick Zhai)

Improvements
---------------------
Expand Down
145 changes: 122 additions & 23 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,31 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
NeighborQueue results;
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

int initialEp = graph.entryNode();
if (initialEp == -1) {
return new NeighborQueue(1, true);
}
int[] eps = new int[] {initialEp};
int numVisited = 0;
for (int level = graph.numLevels() - 1; level >= 1; level--) {
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
numVisited += results.visitedCount();
visitedLimit -= results.visitedCount();
if (results.incomplete()) {
results.setVisitedCount(numVisited);
return results;
}
eps[0] = results.pop();
}
results =
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
results.setVisitedCount(results.visitedCount() + numVisited);
return results;
/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(float[], int, RandomAccessVectorValues, VectorEncoding,
* VectorSimilarityFunction, HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
float[] query,
int topK,
RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<float[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
Expand Down Expand Up @@ -161,6 +164,46 @@ public static NeighborQueue search(
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

/**
* Search {@link OnHeapHnswGraph}, this method is thread safe, for parameters please refer to
* {@link #search(byte[], int, RandomAccessVectorValues, VectorEncoding, VectorSimilarityFunction,
* HnswGraph, Bits, int)}
*/
public static NeighborQueue search(
byte[] query,
int topK,
RandomAccessVectorValues<byte[]> vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
OnHeapHnswGraph graph,
Bits acceptOrds,
int visitedLimit)
throws IOException {
OnHeapHnswGraphSearcher<byte[]> graphSearcher =
new OnHeapHnswGraphSearcher<>(
vectorEncoding,
similarityFunction,
new NeighborQueue(topK, true),
new SparseFixedBitSet(vectors.size()));
return search(query, topK, vectors, graph, graphSearcher, acceptOrds, visitedLimit);
}

private static <T> NeighborQueue search(
T query,
int topK,
RandomAccessVectorValues<T> vectors,
HnswGraph graph,
HnswGraphSearcher<T> graphSearcher,
Bits acceptOrds,
int visitedLimit)
throws IOException {
int initialEp = graph.entryNode();
if (initialEp == -1) {
return new NeighborQueue(1, true);
}
NeighborQueue results;
int[] eps = new int[] {graph.entryNode()};
int numVisited = 0;
Expand Down Expand Up @@ -252,9 +295,9 @@ private NeighborQueue searchLevel(
}

int topCandidateNode = candidates.pop();
graph.seek(level, topCandidateNode);
graphSeek(graph, level, topCandidateNode);
int friendOrd;
while ((friendOrd = graph.nextNeighbor()) != NO_MORE_DOCS) {
while ((friendOrd = graphNextNeighbor(graph)) != NO_MORE_DOCS) {
assert friendOrd < size : "friendOrd=" + friendOrd + "; size=" + size;
if (visited.getAndSet(friendOrd)) {
continue;
Expand Down Expand Up @@ -298,4 +341,60 @@ private void prepareScratchState(int capacity) {
}
visited.clear(0, visited.length());
}

/**
* Seek a specific node in the given graph. The default implementation will just call {@link
* HnswGraph#seek(int, int)}
*
* @throws IOException when seeking the graph
*/
void graphSeek(HnswGraph graph, int level, int targetNode) throws IOException {
graph.seek(level, targetNode);
}

/**
* Get the next neighbor from the graph, you must call {@link #graphSeek(HnswGraph, int, int)}
* before calling this method. The default implementation will just call {@link
* HnswGraph#nextNeighbor()}
*
* @return see {@link HnswGraph#nextNeighbor()}
* @throws IOException when advance neighbors
*/
int graphNextNeighbor(HnswGraph graph) throws IOException {
return graph.nextNeighbor();
}

/**
* This class allow {@link OnHeapHnswGraph} to be searched in a thread-safe manner.
*
* <p>Note the class itself is NOT thread safe, but since each search will create one new graph
* searcher the search method is thread safe.
*/
private static class OnHeapHnswGraphSearcher<C> extends HnswGraphSearcher<C> {

private NeighborArray cur;
private int upto;

private OnHeapHnswGraphSearcher(
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
NeighborQueue candidates,
BitSet visited) {
super(vectorEncoding, similarityFunction, candidates, visited);
}

@Override
void graphSeek(HnswGraph graph, int level, int targetNode) {
cur = ((OnHeapHnswGraph) graph).getNeighbors(level, targetNode);
upto = -1;
}

@Override
int graphNextNeighbor(HnswGraph graph) {
if (++upto < cur.size()) {
return cur.node[upto];
}
return NO_MORE_DOCS;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95Codec;
Expand Down Expand Up @@ -67,6 +73,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.NamedThreadFactory;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
Expand Down Expand Up @@ -991,6 +998,105 @@ public void testRandom() throws IOException {
assertTrue("overlap=" + overlap, overlap > 0.9);
}

/* test thread-safety of searching OnHeapHnswGraph */
@SuppressWarnings("unchecked")
public void testOnHeapHnswGraphSearch()
throws IOException, ExecutionException, InterruptedException, TimeoutException {
int size = atLeast(100);
int dim = atLeast(10);
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
int topK = 5;
HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);

List<T> queries = new ArrayList<>();
List<NeighborQueue> expects = new ArrayList<>();
for (int i = 0; i < 100; i++) {
NeighborQueue expect;
T query = randomVector(dim);
queries.add(query);
expect =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};

while (expect.size() > topK) {
expect.pop();
}
expects.add(expect);
}

ExecutorService exec =
Executors.newFixedThreadPool(4, new NamedThreadFactory("onHeapHnswSearch"));
List<Future<NeighborQueue>> futures = new ArrayList<>();
for (T query : queries) {
futures.add(
exec.submit(
() -> {
NeighborQueue actual;
try {
actual =
switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search(
(byte[]) query,
100,
(RandomAccessVectorValues<byte[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search(
(float[]) query,
100,
(RandomAccessVectorValues<float[]>) vectors,
getVectorEncoding(),
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
};
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
while (actual.size() > topK) {
actual.pop();
}
return actual;
}));
}
List<NeighborQueue> actuals = new ArrayList<>();
for (Future<NeighborQueue> future : futures) {
actuals.add(future.get(10, TimeUnit.SECONDS));
}
exec.shutdownNow();
for (int i = 0; i < expects.size(); i++) {
NeighborQueue expect = expects.get(i);
NeighborQueue actual = actuals.get(i);
assertArrayEquals(expect.nodes(), actual.nodes());
}
}

private int computeOverlap(int[] a, int[] b) {
Arrays.sort(a);
Arrays.sort(b);
Expand Down