From c0ad6488817a5c6aa2bf7e3ea2db6ef2b1b67b51 Mon Sep 17 00:00:00 2001
From: Jonathan Ellis
Date: Thu, 19 Dec 2024 10:32:59 -0600
Subject: [PATCH] add GraphIndexBuilder.rescore()

---
 .../jvector/graph/GraphIndexBuilder.java      | 46 +++++++++++++++
 .../jvector/graph/GraphIndexBuilderTest.java  | 55 ++++++++++++++++++-
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
index 045a7b413..3068aaf7d 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -185,6 +185,52 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
         this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
     }
 
+    /**
+     * Creates a new builder whose graph contains the same nodes, the same edges, and the same
+     * entry node as {@code other}'s graph, with every edge score recomputed by {@code newProvider}.
+     * The new builder is constructed with {@code other}'s dimension, max degree, beam width,
+     * neighbor overflow, alpha, and executors.
+     *
+     * @param other       the builder whose graph is copied
+     * @param newProvider the score provider used to rescore each edge, and used by the returned builder
+     * @return a new builder over the rescored copy of {@code other}'s graph
+     */
+    public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
+        var newBuilder = new GraphIndexBuilder(newProvider,
+                other.dimension,
+                other.graph.maxDegree(),
+                other.beamWidth,
+                other.neighborOverflow,
+                other.alpha,
+                other.simdExecutor,
+                other.parallelExecutor);
+
+        // Copy each node and its neighbors from the old graph to the new one
+        for (int i = 0; i < other.graph.getIdUpperBound(); i++) {
+            if (!other.graph.containsNode(i)) {
+                continue;
+            }
+
+            var neighbors = other.graph.getNeighbors(i);
+            var sf = newProvider.searchProviderFor(i).scoreFunction();
+            var newNeighbors = new NodeArray(neighbors.size());
+
+            // Copy neighbors with new scores
+            for (var it = neighbors.iterator(); it.hasNext(); ) {
+                int neighbor = it.nextInt();
+                // since we're using a different score provider, use insertSorted instead of addInOrder
+                newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor));
+            }
+
+            newBuilder.graph.addNode(i, newNeighbors);
+        }
+
+        // Set the entry node
+        newBuilder.graph.updateEntryNode(other.graph.entry());
+
+        return newBuilder;
+    }
+
     public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
         var vv = ravv.threadLocalSupplier();
         int size = ravv.size();
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
index dcec56ca8..bb4b3c159 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/GraphIndexBuilderTest.java
@@ -20,15 +20,19 @@
 import io.github.jbellis.jvector.LuceneTestCase;
 import io.github.jbellis.jvector.TestUtil;
 import io.github.jbellis.jvector.disk.SimpleMappedReader;
+import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
 import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.io.RandomAccessFile;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.function.Supplier;
 
 import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors;
@@ -37,6 +41,7 @@
 
 @ThreadLeakScope(ThreadLeakScope.Scope.NONE)
 public class GraphIndexBuilderTest extends LuceneTestCase {
+    private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
 
     private Path testDirectory;
 
@@ -50,6 +55,54 @@ public void tearDown() {
         TestUtil.deleteQuietly(testDirectory);
     }
 
+    @Test
+    public void testRescore() {
+        // Three 2-d vectors: node 1 is at squared distance 1 from node 0, node 2 at squared distance 4
+        var vectors = new ArrayList<VectorFloat<?>>();
+        vectors.add(vts.createFloatVector(new float[] {0, 0}));
+        vectors.add(vts.createFloatVector(new float[] {0, 1}));
+        vectors.add(vts.createFloatVector(new float[] {2, 0}));
+        var ravv = new ListRandomAccessVectorValues(vectors, 2);
+
+        // Euclidean scoring; the expected scores asserted below equal 1 / (1 + squared distance)
+        var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
+        var builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f);
+
+        // Add 3 nodes
+        builder.addGraphNode(0, ravv.getVector(0));
+        builder.addGraphNode(1, ravv.getVector(1));
+        builder.addGraphNode(2, ravv.getVector(2));
+        var neighbors = builder.graph.getNeighbors(0);
+        assertEquals(1, neighbors.getNode(0));
+        assertEquals(2, neighbors.getNode(1));
+        assertEquals(0.5f, neighbors.getScore(0), 1E-6);
+        assertEquals(0.2f, neighbors.getScore(1), 1E-6);
+
+        // Replace the list contents in place: node 1 moves to [0, 4], now farther from node 0 than node 2
+        vectors.clear();
+        vectors.add(vts.createFloatVector(new float[] {0, 0}));
+        vectors.add(vts.createFloatVector(new float[] {0, 4}));
+        vectors.add(vts.createFloatVector(new float[] {2, 0}));
+
+        // Rescore the graph
+        // (The score provider didn't change, but the vectors did, which provides the same effect)
+        var rescored = GraphIndexBuilder.rescore(builder, bsp);
+
+        // Verify all nodes were copied
+        var newGraph = rescored.getGraph();
+        assertTrue(newGraph.containsNode(0));
+        assertTrue(newGraph.containsNode(1));
+        assertTrue(newGraph.containsNode(2));
+
+        // Check node 0's neighbors: same edges, but scores and order follow the new vectors
+        var newNeighbors = newGraph.getNeighbors(0);
+        assertEquals(2, newNeighbors.getNode(0));
+        assertEquals(1, newNeighbors.getNode(1));
+        assertEquals(0.2f, newNeighbors.getScore(0), 1E-6);
+        assertEquals(0.05882353f, newNeighbors.getScore(1), 1E-6);
+
+    }
+
     @Test
     public void testSaveAndLoad() throws IOException {
         int dimension = randomIntBetween(2, 32);