Skip to content

Commit

Permalink
Print count of reachable nodes, and score of unreachable nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaival Parikh committed Dec 23, 2023
1 parent dc9f154 commit edbbfd6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Locale;
import java.util.Queue;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -103,6 +106,40 @@ private static void search(
int ep = graphSearcher.findBestEntryPoint(scorer, graph, knnCollector);
if (ep != -1) {
graphSearcher.searchLevel(knnCollector, scorer, 0, new int[] {ep}, graph, acceptOrds);

// knnCollector.k() is 1 for a VectorSimilarityCollector, do not run this snippet for other
// collectors
if (knnCollector.k() == 1) {
FixedBitSet visited = new FixedBitSet(graph.size());
Queue<Integer> candidates = new ArrayDeque<>();
candidates.add(graph.entryNode());

while (!candidates.isEmpty()) {
int node = candidates.remove();
if (visited.getAndSet(node)) {
continue;
}
graph.seek(0, node);

int neighbor;
while ((neighbor = graph.nextNeighbor()) != NO_MORE_DOCS) {
candidates.add(neighbor);
}
}

System.out.printf(
Locale.ROOT,
"Total Nodes = %3d, Reachable Nodes = %3d\n\nUnreachable Nodes:\n",
visited.length(),
visited.cardinality());
for (int i = 0; i < visited.length(); i++) {
if (!visited.get(i)) {
System.out.printf(
Locale.ROOT, "Doc = %d, Score = %f\n", scorer.ordToDoc(i), scorer.score(i));
}
}
System.out.println();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ public void testVectorsAboveSimilarity() throws IOException {

// Find score above which we get (at least) numAccepted vectors
float resultSimilarity = getSimilarity(vectors, queryVector, numAccepted);
System.out.printf(Locale.ROOT, "Similarity = %f\n", resultSimilarity);

// Cache scores of vectors
Map<Integer, Float> scores = new HashMap<>();
Expand All @@ -383,6 +384,7 @@ public void testVectorsAboveSimilarity() throws IOException {
scores.put(i, score);
}
}
Map<Integer, Float> scoresCopy = new HashMap<>(scores);

try (Directory indexStore = getIndexStore(vectors);
IndexReader reader = DirectoryReader.open(indexStore)) {
Expand All @@ -397,11 +399,14 @@ public void testVectorsAboveSimilarity() throws IOException {

// Check that the collected result is above accepted similarity
assertTrue(scores.containsKey(id));
scoresCopy.remove(id);

// Check that the score is correct
assertEquals(scores.get(id), scoreDoc.score, delta);
}

System.out.printf(Locale.ROOT, "Missed Docs = %s\n\n", scoresCopy);

// Check that all results are collected
assertEquals(scores.size(), scoreDocs.length);
}
Expand Down

0 comments on commit edbbfd6

Please sign in to comment.