Skip to content

Commit

Permalink
gh-12627: HnswGraphBuilder connects disconnected HNSW graph components (
Browse files Browse the repository at this point in the history
#13566)

* gh-12627: HnswGraphBuilder connects disconnected HNSW graph components
  • Loading branch information
msokolov authored Aug 8, 2024
1 parent d26b152 commit 2178287
Show file tree
Hide file tree
Showing 11 changed files with 679 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ public void seek(int level, int targetOrd) throws IOException {
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
assert arcCount <= currentNeighborsBuffer.length : "too many neighbors: " + arcCount;
if (arcCount > 0) {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ public T copyValue(T vectorValue) {
throw new UnsupportedOperationException();
}

OnHeapHnswGraph getGraph() {
OnHeapHnswGraph getGraph() throws IOException {
assert flatFieldVectorsWriter.isFinished();
if (node > 0) {
return hnswGraphBuilder.getCompletedGraph();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ public interface HnswBuilder {
* components, re-ordering node ids for better delta compression) may be triggered, so callers
* should expect this call to take some time.
*/
OnHeapHnswGraph getCompletedGraph();
OnHeapHnswGraph getCompletedGraph() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public OnHeapHnswGraph build(int maxOrd) throws IOException {
});
}
taskExecutor.invokeAll(futures);
finish();
frozen = true;
return workers[0].getCompletedGraph();
}
Expand All @@ -109,11 +110,19 @@ public void setInfoStream(InfoStream infoStream) {
}

@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
public OnHeapHnswGraph getCompletedGraph() throws IOException {
// Run the finishing pass at most once: frozen is set right after finish(), so any later call
// skips straight to returning the graph.
if (frozen == false) {
// should already have been called in build(), but just in case
finish();
frozen = true;
}
return getGraph();
}

// Delegates the finishing step to the first worker only.
// NOTE(review): presumably all workers build into the same shared graph, since getGraph() also
// reads it from workers[0] -- confirm.
private void finish() throws IOException {
workers[0].finish();
}

@Override
public OnHeapHnswGraph getGraph() {
// Returns the graph as seen by the first worker; no finishing step is run here.
return workers[0].getGraph();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
package org.apache.lucene.util.hnsw;

import static java.lang.Math.log;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
Expand All @@ -28,6 +31,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.HnswUtil.Component;

/**
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
Expand Down Expand Up @@ -137,7 +141,7 @@ protected HnswGraphBuilder(
HnswGraphSearcher graphSearcher)
throws IOException {
if (M <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
throw new IllegalArgumentException("M (max connections) must be positive");
}
if (beamWidth <= 0) {
throw new IllegalArgumentException("beamWidth must be positive");
Expand Down Expand Up @@ -173,8 +177,11 @@ public void setInfoStream(InfoStream infoStream) {
}

@Override
public OnHeapHnswGraph getCompletedGraph() {
frozen = true;
public OnHeapHnswGraph getCompletedGraph() throws IOException {
// finish() is run only on the first call; the frozen flag is set immediately afterwards so
// repeated calls just return the graph.
if (!frozen) {
finish();
frozen = true;
}
return getGraph();
}

Expand Down Expand Up @@ -405,6 +412,93 @@ private static int getRandomGraphLevel(double ml, SplittableRandom random) {
return ((int) (-log(randDouble) * ml));
}

// Finishing step for graph construction; currently this only runs connectComponents().
void finish() throws IOException {
connectComponents();
}

/**
 * Runs the per-level component connection pass over every level of the graph. When the
 * HNSW_COMPONENT info stream is enabled, a message is emitted for each level on which the pass
 * reports failure, followed by the total elapsed time in milliseconds.
 *
 * @throws IOException if connecting a level fails with an I/O error
 */
private void connectComponents() throws IOException {
  final long startNanos = System.nanoTime();
  int level = 0;
  while (level < hnsw.numLevels()) {
    boolean connected = connectComponents(level);
    // only consult the info stream when there is a failure to report
    if (!connected && infoStream.isEnabled(HNSW_COMPONENT)) {
      infoStream.message(HNSW_COMPONENT, "connectComponents failed on level " + level);
    }
    level++;
  }
  if (infoStream.isEnabled(HNSW_COMPONENT)) {
    long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
    infoStream.message(HNSW_COMPONENT, "connectComponents " + elapsedMillis + " ms");
  }
}

/**
 * Tries to join every disconnected component of the graph on {@code level} to the largest
 * component on that level, adding a single link per component.
 *
 * @param level the graph level whose components should be connected
 * @return {@code true} if there was at most one component, or if every other component was linked
 *     to the largest one; {@code false} if the largest component has no node with room for a new
 *     connection, or if at least one component could not be linked
 * @throws IOException if creating a scorer or searching the graph fails
 */
private boolean connectComponents(int level) throws IOException {
  // NOTE(review): presumably filled in by HnswUtil.components with the nodes that still have
  // room for another neighbor; it is used below as the accept-filter for the search -- confirm.
  FixedBitSet notFullyConnected = new FixedBitSet(hnsw.size());
  // level 0 is given twice the connection budget of the other levels
  int maxConn = M;
  if (level == 0) {
    maxConn *= 2;
  }
  List<Component> components = HnswUtil.components(hnsw, level, notFullyConnected, maxConn);
  if (components.size() <= 1) {
    // nothing to connect
    return true;
  }
  // connect other components to the largest one
  Component c0 = components.stream().max(Comparator.comparingInt(Component::size)).get();
  if (c0.start() == NO_MORE_DOCS) {
    // the component is already fully connected - no room for new connections
    return false;
  }
  // try for more connections? We only do one since otherwise they may become full
  // while linking
  GraphBuilderKnnCollector beam = new GraphBuilderKnnCollector(1);
  int[] eps = new int[1];
  boolean result = true;
  for (Component c : components) {
    if (c == c0) {
      continue;
    }
    if (c.start() == NO_MORE_DOCS) {
      // this component has no node with room for a connection; NO_MORE_DOCS is a sentinel, not
      // a valid ordinal, so it must not be handed to the scorer supplier
      result = false;
      continue;
    }
    beam.clear();
    eps[0] = c0.start();
    RandomVectorScorer scorer = scorerSupplier.scorer(c.start());
    // find the closest node in the largest component to the lowest-numbered node in this
    // component that has room to make a connection. The search must run on the level being
    // connected: nodes found on another level may not exist on this one.
    graphSearcher.searchLevel(beam, scorer, level, eps, hnsw, notFullyConnected);
    boolean linked = false;
    while (beam.size() > 0) {
      float score = beam.minimumScore();
      int c0node = beam.popNode();
      if (c0node == c.start()) {
        // never link a node to itself
        continue;
      }
      assert notFullyConnected.get(c0node);
      // link the nodes
      link(level, c0node, c.start(), score, notFullyConnected);
      linked = true;
    }
    if (!linked) {
      result = false;
    }
  }
  return result;
}

/**
 * Links two nodes on the given level. The connection from {@code n0} to {@code n1} is always
 * added; the reverse connection is added only when {@code n1} still has room. Whenever a node
 * reaches its connection limit its bit is cleared in {@code notFullyConnected}.
 *
 * @param level the graph level on which to link
 * @param n0 the node that receives the forward connection; must have room
 * @param n1 the node being connected to
 * @param score the score recorded with the new neighbor entries
 * @param notFullyConnected tracks nodes that can still accept connections; updated in place
 */
private void link(int level, int n0, int n1, float score, FixedBitSet notFullyConnected) {
  NeighborArray forward = hnsw.getNeighbors(level, n0);
  NeighborArray backward = hnsw.getNeighbors(level, n1);
  // The nodes array is one larger than the configured max neighbors (M / 2M), hence the
  // subtraction. Callers are expected to have searched only among not-full nodes already.
  final int limit = forward.nodes().length - 1;
  assert notFullyConnected.get(n0);
  assert forward.size() < limit : "node " + n0 + " is full, has " + forward.size() + " friends";
  forward.addOutOfOrder(n1, score);
  if (forward.size() == limit) {
    notFullyConnected.clear(n0);
  }
  if (backward.size() >= limit) {
    // no room for the reverse connection
    return;
  }
  backward.addOutOfOrder(n0, score);
  if (backward.size() == limit) {
    notFullyConnected.clear(n1);
  }
}

/**
* A restricted, specialized knnCollector that can be used when building a graph.
*
Expand Down
Loading

0 comments on commit 2178287

Please sign in to comment.