Skip to content

Commit

Permalink
add GraphIndexBuilder.rescore()
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Dec 19, 2024
1 parent d6f6d2a commit c0ad648
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,42 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
}

/**
 * Creates a new builder whose graph has the same nodes, edges, and entry node as
 * {@code other}'s graph, but with every edge score recomputed using {@code newProvider}.
 * <p>
 * The new builder is constructed with {@code other}'s dimension, max degree, beam width,
 * neighbor overflow, alpha, and executors; only the score provider differs.
 *
 * @param other the builder whose graph is copied
 * @param newProvider the score provider used both to configure the new builder and to
 *                    recompute each edge's score
 * @return a new builder holding the rescored copy of the graph
 */
public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
    var rescored = new GraphIndexBuilder(newProvider,
                                         other.dimension,
                                         other.graph.maxDegree(),
                                         other.beamWidth,
                                         other.neighborOverflow,
                                         other.alpha,
                                         other.simdExecutor,
                                         other.parallelExecutor);

    // Walk the whole id range of the source graph; ids with no node are simply skipped
    for (int node = 0; node < other.graph.getIdUpperBound(); node++) {
        if (other.graph.containsNode(node)) {
            var oldEdges = other.graph.getNeighbors(node);
            var scorer = newProvider.searchProviderFor(node).scoreFunction();
            var rescoredEdges = new NodeArray(oldEdges.size());

            // Scores come from a different provider than the one that ordered oldEdges,
            // so each edge goes through insertSorted rather than addInOrder
            var edgeIds = oldEdges.iterator();
            while (edgeIds.hasNext()) {
                int target = edgeIds.nextInt();
                rescoredEdges.insertSorted(target, scorer.similarityTo(target));
            }

            rescored.graph.addNode(node, rescoredEdges);
        }
    }

    // Carry over the source graph's entry node
    rescored.graph.updateEntryNode(other.graph.entry());

    return rescored;
}

public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
var vv = ravv.threadLocalSupplier();
int size = ravv.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@
import io.github.jbellis.jvector.LuceneTestCase;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.function.Supplier;

import static io.github.jbellis.jvector.graph.TestVectorGraph.createRandomFloatVectors;
Expand All @@ -37,6 +41,7 @@

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class GraphIndexBuilderTest extends LuceneTestCase {
// Vector type support obtained from the active VectorizationProvider; tests use it to create float vectors
private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();

private Path testDirectory;

Expand All @@ -50,6 +55,54 @@ public void tearDown() {
TestUtil.deleteQuietly(testDirectory);
}

/**
 * Checks that {@code GraphIndexBuilder.rescore} keeps every node of the original graph
 * while recomputing neighbor scores, so that a node's neighbors are reordered when
 * the underlying vectors (and therefore the scores) change.
 */
@Test
public void testRescore() {
// Create three 2-d test vectors; node 0 is at the origin, node 1 at [0, 1], node 2 at [2, 0]
var vectors = new ArrayList<VectorFloat<?>>();
vectors.add(vts.createFloatVector(new float[] {0, 0}));
vectors.add(vts.createFloatVector(new float[] {0, 1}));
vectors.add(vts.createFloatVector(new float[] {2, 0}));
var ravv = new ListRandomAccessVectorValues(vectors, 2);

// Score provider uses EUCLIDEAN similarity over ravv; the scores asserted below
// (0.5 at distance 1, 0.2 at distance 2) are consistent with 1 / (1 + squared distance)
var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
var builder = new GraphIndexBuilder(bsp, 2, 2, 10, 1.0f, 1.0f);

// Add 3 nodes
builder.addGraphNode(0, ravv.getVector(0));
builder.addGraphNode(1, ravv.getVector(1));
builder.addGraphNode(2, ravv.getVector(2));
// Node 0's neighbors before rescoring: node 1 (closer) first, then node 2
var neighbors = builder.graph.getNeighbors(0);
assertEquals(1, neighbors.getNode(0));
assertEquals(2, neighbors.getNode(1));
assertEquals(0.5f, neighbors.getScore(0), 1E-6);
assertEquals(0.2f, neighbors.getScore(1), 1E-6);

// Replace the list contents in place: node 1 moves from [0, 1] to [0, 4], which makes it
// farther from node 0 than node 2 is; nodes 0 and 2 keep their coordinates
vectors.clear();
vectors.add(vts.createFloatVector(new float[] {0, 0}));
vectors.add(vts.createFloatVector(new float[] {0, 4}));
vectors.add(vts.createFloatVector(new float[] {2, 0}));

// Rescore the graph
// (The score provider instance is unchanged, but the vectors behind it changed,
// which has the same effect as supplying a different provider)
var rescored = GraphIndexBuilder.rescore(builder, bsp);

// Verify all nodes are still present in the rescored graph
var newGraph = rescored.getGraph();
assertTrue(newGraph.containsNode(0));
assertTrue(newGraph.containsNode(1));
assertTrue(newGraph.containsNode(2));

// Check node 0's neighbors: order is now reversed (node 2 before node 1),
// node 2's score is unchanged at 0.2 and node 1's score drops to 1/17 (squared distance 16)
var newNeighbors = newGraph.getNeighbors(0);
assertEquals(2, newNeighbors.getNode(0));
assertEquals(1, newNeighbors.getNode(1));
assertEquals(0.2f, newNeighbors.getScore(0), 1E-6);
assertEquals(0.05882353f, newNeighbors.getScore(1), 1E-6);

}

@Test
public void testSaveAndLoad() throws IOException {
int dimension = randomIntBetween(2, 32);
Expand Down

0 comments on commit c0ad648

Please sign in to comment.