From edbbfd6d55fd1cd1cfb14ed558945660fa4c0690 Mon Sep 17 00:00:00 2001 From: Kaival Parikh Date: Sat, 23 Dec 2023 14:01:38 +0000 Subject: [PATCH] Print count of reachable nodes, and score of unreachable nodes --- .../lucene/util/hnsw/HnswGraphSearcher.java | 37 +++++++++++++++++++ .../BaseVectorSimilarityQueryTestCase.java | 5 +++ 2 files changed, 42 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46d6c93d52c3..2dd1c2d3ca59 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,6 +20,9 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Locale; +import java.util.Queue; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSet; @@ -103,6 +106,40 @@ private static void search( int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector); if (ep != -1) { graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds); + + // knnCollector.k() is 1 for a VectorSimilarityCollector, do not run this snippet for other + // collectors + if (knnCollector.k() == 1) { + FixedBitSet visited = new FixedBitSet(graph.size()); + Queue candidates = new ArrayDeque<>(); + candidates.add(graph.entryNode()); + + while (!candidates.isEmpty()) { + int node = candidates.remove(); + if (visited.getAndSet(node)) { + continue; + } + graph.seek(0, node); + + int neighbor; + while ((neighbor = graph.nextNeighbor()) != NO_MORE_DOCS) { + candidates.add(neighbor); + } + } + + System.out.printf( + Locale.ROOT, + "Total Nodes = %3d, Reachable Nodes = %3d\n\nUnreachable Nodes:\n", + visited.length(), + visited.cardinality()); + for (int i = 0; i < visited.length(); i++) { + if (!visited.get(i)) { + System.out.printf( + Locale.ROOT, "Doc = %d, Score = %f\n", scorer.ordToDoc(i), scorer.score(i)); + } + } + System.out.println(); + } } } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java index c4f624bae741..a42a90af1e68 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java @@ -374,6 +374,7 @@ public void testVectorsAboveSimilarity() throws IOException { // Find score above which we get (at least) numAccepted vectors float resultSimilarity = getSimilarity(vectors, queryVector, numAccepted); + System.out.printf(Locale.ROOT, "Similarity = %f\n", resultSimilarity); // Cache scores of vectors Map scores = new HashMap<>(); @@ -383,6 +384,7 @@ public void testVectorsAboveSimilarity() throws IOException { scores.put(i, score); } } + Map scoresCopy = new HashMap<>(scores); try (Directory indexStore = getIndexStore(vectors); IndexReader reader = DirectoryReader.open(indexStore)) { @@ -397,11 +399,14 @@ public void testVectorsAboveSimilarity() throws IOException { // Check that the collected result is above accepted similarity assertTrue(scores.containsKey(id)); + scoresCopy.remove(id); // Check that the score is correct assertEquals(scores.get(id), scoreDoc.score, delta); } + System.out.printf(Locale.ROOT, "Missed Docs = %s\n\n", scoresCopy); + // Check that all results are collected assertEquals(scores.size(), scoreDocs.length); }