diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
index b794244105e8..1ff593b6616f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
@@ -20,6 +20,7 @@
 import static java.lang.Math.log;
 
 import java.io.IOException;
+import java.util.HashSet;
 import java.util.Locale;
 import java.util.Objects;
 import java.util.SplittableRandom;
@@ -252,7 +253,20 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try
 
       // then do connections from bottom up
       for (int i = 0; i < scratchPerLevel.length; i++) {
-        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); // baseline : similar to false, false, false below.
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, false); // baseline_equivalent
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, true); // baseline with remove other half
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false, false, false); // exp-1 extendCandidates = true
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, false, false); // exp-2 with keep-pruned =true
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, false); // exp-3 with keep pruned till half max-conn
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, true); // exp-3 with keep pruned till half max-conn and remove other half
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, false, false);
+//        addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, true, false);
+
+        // new heuristic
+//        addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false); // new heuristic without remove other half
+//        addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false); // new heuristic with remove other half
+        addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true); // new heuristic with honour-max-conn true
       }
       lowestUnsetLevel += scratchPerLevel.length;
       assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
@@ -290,6 +304,316 @@ private long printGraphBuildStatus(int node, long start, long t) {
     return now;
   }
 
+  /**
+   * Picks the index of the neighbour that should be evicted from {@code neighbors}.
+   *
+   * <p>After {@code neighbors.sort(scorer)} the array is scanned from the last index to the
+   * first. The first entry that has a common neighbour scoring higher than the entry itself is
+   * returned. If there is none, the fallbacks are, in order: the entry sharing the most
+   * neighbours with {@code neighbors}, the entry with the most connections (at least two), and
+   * finally a random index when {@code removeRandom} is set.
+   *
+   * @param neighbors neighbour array to inspect
+   * @param level graph level used to look up the neighbours of each entry
+   * @param scorer scorer handed to {@code neighbors.sort}
+   * @param removeRandom whether to fall back to a random index when no other rule applies
+   * @return index into {@code neighbors} of the entry to remove, or -1 if none was chosen
+   */
+  private int findWorstNonDiverse(
+      NeighborArray neighbors, int level, RandomVectorScorer scorer, boolean removeRandom)
+      throws IOException {
+    // The indexes returned by sort() are not needed here, only the call itself.
+    neighbors.sort(scorer);
+    int maxCommonConnectionCount = 0;
+    int maxcccIndex = -1;
+    int maxConnectionIndex = -1;
+    int maxConnectionCount = 2; // only neighbours with at least 2 connections qualify
+    for (int i = neighbors.size() - 1; i >= 0; i--) {
+      int currNode = neighbors.nodes()[i];
+      float currNodeScore = neighbors.scores()[i];
+      NeighborArray currNodeNeighbours = hnsw.getNeighbors(level, currNode);
+      NeighborArray commonNeighbours = findCommon(neighbors, currNodeNeighbours);
+
+      if (commonNeighbours.size() > maxCommonConnectionCount) {
+        maxCommonConnectionCount = commonNeighbours.size();
+        maxcccIndex = i;
+      }
+      if (currNodeNeighbours.size() >= maxConnectionCount) {
+        maxConnectionCount = currNodeNeighbours.size();
+        maxConnectionIndex = i;
+      }
+
+      for (int j = 0; j < commonNeighbours.size(); j++) {
+        if (commonNeighbours.scores()[j] > currNodeScore) {
+          return i; // currNode is non-diverse as its score with another neighbour is higher
+        }
+      }
+    }
+    if (maxcccIndex != -1) {
+      return maxcccIndex;
+    }
+    if (maxConnectionIndex != -1) {
+      return maxConnectionIndex;
+    }
+    return removeRandom ? random.nextInt(neighbors.size()) : -1;
+  }
+
+  /**
+   * Adds {@code newNode} to {@code neighbours} and, if the array then holds more than {@code
+   * maxConnOnLevel} entries, removes the entry chosen by {@code findWorstNonDiverse}.
+   *
+   * @param newNode node to add
+   * @param newScore score stored alongside {@code newNode}
+   * @param nodeId node id used to obtain the scorer from {@code scorerSupplier}
+   * @param neighbours neighbour array to modify
+   * @param level graph level of {@code neighbours}
+   * @param maxConnOnLevel maximum number of entries to keep without attempting a removal
+   * @param removeRandom forwarded to {@code findWorstNonDiverse}
+   * @return id of the removed node, or -1 if nothing was removed
+   * @throws IllegalStateException if {@code removeRandom} is set and no index was chosen
+   */
+  public int addAndEnsureConnectedDiversity(
+      int newNode,
+      float newScore,
+      int nodeId,
+      NeighborArray neighbours,
+      int level,
+      int maxConnOnLevel,
+      boolean removeRandom)
+      throws IOException {
+    neighbours.addOutOfOrder(newNode, newScore);
+    if (neighbours.size() <= maxConnOnLevel) {
+      return -1; // none removed
+    }
+    RandomVectorScorer scorer = scorerSupplier.scorer(nodeId);
+    int indexToRemove = findWorstNonDiverse(neighbours, level, scorer, removeRandom);
+    if (indexToRemove == -1) {
+      if (removeRandom) {
+        throw new IllegalStateException(
+            "If remove random is set we should not have got -1 for remove index.");
+      }
+      return -1; // nothing qualified for removal
+    }
+    int nodeRemoved = neighbours.nodes()[indexToRemove];
+    neighbours.removeIndex(indexToRemove);
+    return nodeRemoved;
+  }
+
+  /**
+   * Collects the entries of {@code currNodeNeighbours} whose node also appears in {@code
+   * neighbors}. The scores kept are the ones stored in {@code currNodeNeighbours}.
+   */
+  private NeighborArray findCommon(NeighborArray neighbors, NeighborArray currNodeNeighbours) {
+    NeighborArray common =
+        new NeighborArray(Math.max(currNodeNeighbours.size(), neighbors.size()), true);
+    for (int i = 0; i < neighbors.size(); i++) {
+      for (int j = 0; j < currNodeNeighbours.size(); j++) {
+        if (neighbors.nodes()[i] == currNodeNeighbours.nodes()[j]) {
+          common.addOutOfOrder(currNodeNeighbours.nodes()[j], currNodeNeighbours.scores()[j]);
+        }
+      }
+    }
+    return common;
+  }
+
+  /**
+   * Links {@code incomingNode} to the candidates picked by {@code selectAll}, which applies no
+   * diversity check, and adds the reverse links through {@code
+   * addAndEnsureConnectedDiversity}.
+   *
+   * @param level graph level to link on
+   * @param incomingNode the new node; it must not have neighbours on {@code level} yet
+   * @param candidates candidate neighbours of {@code incomingNode}
+   * @param removeOtherHalf if set, when a node is evicted from a neighbour's list the
+   *     neighbour is also removed from the evicted node's list
+   * @param honourMaxConn forwarded as {@code removeRandom} to {@code
+   *     addAndEnsureConnectedDiversity}
+   */
+  private void addDiverseNeighborsNewHeuristic(
+      int level,
+      int incomingNode,
+      NeighborArray candidates,
+      boolean removeOtherHalf,
+      boolean honourMaxConn)
+      throws IOException {
+    NeighborArray neighbors = hnsw.getNeighbors(level, incomingNode);
+    assert neighbors.size() == 0; // new node
+
+    int maxConnOnLevel = level == 0 ? M * 2 : M;
+
+    boolean[] mask = selectAll(neighbors, candidates, maxConnOnLevel);
+    // Link the selected nodes to the new node, and the new node to the selected nodes (applying
+    // the connected-diversity heuristic on the reverse link)
+    int size = candidates.size();
+    for (int i = 0; i < size; i++) {
+      if (mask[i] == false) {
+        continue;
+      }
+
+      int nbr = candidates.nodes()[i];
+      int nodeRemoved;
+      NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+      nbrsOfNbr.rwlock.writeLock().lock();
+      try {
+        nodeRemoved =
+            addAndEnsureConnectedDiversity(
+                incomingNode,
+                candidates.scores()[i],
+                nbr,
+                nbrsOfNbr,
+                level,
+                maxConnOnLevel,
+                honourMaxConn);
+      } finally {
+        nbrsOfNbr.rwlock.writeLock().unlock();
+      }
+
+      if (removeOtherHalf && nodeRemoved != -1) {
+        // It is fine to remove it even from the current node, as we are iterating over the
+        // candidates and not over the neighbours. There is also no chance of deadlock as we never
+        // hold more than one lock at a time.
+        removeOtherHalf(nodeRemoved, nbr, level);
+      }
+    }
+  }
+
+  /**
+   * Adds candidates to {@code neighbors}, walking from the last candidate to the first, until
+   * {@code maxConnOnLevel} neighbours are present. No diversity check is applied.
+   *
+   * @return mask over {@code candidates} marking the entries that were added
+   */
+  private boolean[] selectAll(
+      NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) {
+    boolean[] selected = new boolean[candidates.size()];
+    for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
+      selected[i] = true;
+      neighbors.addInOrder(candidates.nodes()[i], candidates.scores()[i]);
+    }
+    return selected;
+  }
+
+  /**
+   * Removes {@code nbr} from the neighbour list of {@code fromNode} on {@code level}, holding
+   * that list's write lock for the duration of the removal.
+   */
+  private void removeOtherHalf(int fromNode, int nbr, int level) {
+    NeighborArray neighbors = hnsw.getNeighbors(level, fromNode);
+    neighbors.rwlock.writeLock().lock();
+    try {
+      neighbors.removeNode(nbr);
+    } finally {
+      // removeNode throws when nbr is absent; always release the lock
+      neighbors.rwlock.writeLock().unlock();
+    }
+  }
+
+  /**
+   * Experimental variant of {@code addDiverseNeighbors} whose behaviour is tuned by flags.
+   *
+   * @param level graph level to link on
+   * @param node the new node; it must not have neighbours on {@code level} yet
+   * @param candidates candidate neighbours of {@code node}
+   * @param extendCandidates if set, the neighbours of every candidate become candidates too
+   * @param keepPrunedConnections if set, candidates left unselected by {@code
+   *     selectAndLinkDiverse} are added back until the connection limit is reached
+   * @param keepHalfPrunedConnection if set, the limit used for adding back is halved
+   * @param removeOtherHalf if set, when a node is evicted from a neighbour's list the
+   *     neighbour is also removed from the evicted node's list
+   */
+  private void addDiverseNeighbors(
+      int level,
+      int node,
+      NeighborArray candidates,
+      boolean extendCandidates,
+      boolean keepPrunedConnections,
+      boolean keepHalfPrunedConnection,
+      boolean removeOtherHalf)
+      throws IOException {
+    NeighborArray neighbors = hnsw.getNeighbors(level, node);
+    assert neighbors.size() == 0; // new node
+
+    int maxConnOnLevel = level == 0 ? M * 2 : M;
+    NeighborArray originalCandidates = candidates;
+    RandomVectorScorer scorer = scorerSupplier.scorer(node);
+
+    if (extendCandidates) {
+      HashSet<Integer> addedNodes = new HashSet<>();
+
+      for (int i = 0; i < originalCandidates.size(); i++) {
+        int cand = originalCandidates.nodes()[i];
+        addedNodes.add(cand);
+        addAllNeighbours(cand, level, addedNodes);
+      }
+
+      candidates = new NeighborArray(addedNodes.size(), originalCandidates.isScoresDescOrder());
+      for (int cand : addedNodes) {
+        candidates.addOutOfOrder(cand, Float.NaN);
+      }
+
+      candidates.sort(scorer);
+    }
+
+    boolean[] selected = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);
+
+    if (keepPrunedConnections) {
+      // start from the end as the highest scoring candidates are at the end
+      int maskedIndex = selected.length - 1;
+      int maxConn = keepHalfPrunedConnection ? maxConnOnLevel / 2 : maxConnOnLevel;
+      while (neighbors.size() < maxConn && maskedIndex >= 0) {
+        if (selected[maskedIndex] == false) {
+          neighbors.addOutOfOrder(
+              candidates.nodes()[maskedIndex], candidates.scores()[maskedIndex]);
+          selected[maskedIndex] = true;
+        }
+        maskedIndex--;
+      }
+      neighbors.sort(scorer);
+    }
+
+    // Link the selected nodes to the new node, and the new node to the selected nodes (again
+    // applying diversity heuristic)
+    // NOTE: here we're using candidates and mask but not the neighbour array because once we have
+    // added incoming link there will be possibilities of this node being discovered and neighbour
+    // array being modified. So using local candidates and mask is a safer option.
+    for (int i = 0; i < candidates.size(); i++) {
+      if (selected[i] == false) {
+        continue;
+      }
+      int nbr = candidates.nodes()[i];
+      int nodeRemoved;
+      NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
+      nbrsOfNbr.rwlock.writeLock().lock();
+      try {
+        nodeRemoved =
+            nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier);
+      } finally {
+        nbrsOfNbr.rwlock.writeLock().unlock();
+      }
+
+      if (removeOtherHalf && nodeRemoved != -1) {
+        removeOtherHalf(nodeRemoved, nbr, level);
+      }
+    }
+  }
+
+  /**
+   * Adds all neighbours of {@code node} on {@code level} into the {@code addedNodes} set.
+   *
+   * @param node node whose neighbours are collected
+   * @param level graph level to read the neighbours from
+   * @param addedNodes set receiving the neighbour ids
+   */
+  private void addAllNeighbours(int node, int level, HashSet<Integer> addedNodes) {
+    NeighborArray neighbors = hnsw.getNeighbors(level, node);
+
+    // TODO: I think we need to guard with lock here as the `neighbors` array might get modified
+    // while we are iterating.
+    for (int i = 0; i < neighbors.size(); i++) {
+      addedNodes.add(neighbors.nodes()[i]);
+    }
+  }
+
   private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
       throws IOException {
     /* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
@@ -327,7 +651,18 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
    */
   private boolean[] selectAndLinkDiverse(
       NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
-    boolean[] mask = new boolean[candidates.size()];
+    boolean[] selected = new boolean[candidates.size()];
+    return selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, selected);
+  }
+
+  /**
+   * Same as {@link #selectAndLinkDiverse(NeighborArray, NeighborArray, int)} but records the
+   * selection in the caller-provided {@code selected} mask, which must have at least {@code
+   * candidates.size()} entries.
+   */
+  private boolean[] selectAndLinkDiverse(
+      NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel, boolean[] selected)
+      throws IOException {
     // Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
     for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
       // compare each neighbor (in distance order) against the closer neighbors selected so far,
@@ -336,13 +671,13 @@ private boolean[] selectAndLinkDiverse(
       float cScore = candidates.scores()[i];
       assert cNode <= hnsw.maxNodeId();
       if (diversityCheck(cNode, cScore, neighbors)) {
-        mask[i] = true;
+        selected[i] = true;
         // here we don't need to lock, because there's no incoming link so no others is able to
         // discover this node such that no others will modify this neighbor array as well
         neighbors.addInOrder(cNode, cScore);
       }
     }
-    return mask;
+    return selected;
   }
 
   private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
index 7d1ed069c298..a82eeb6f0523 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
@@ -45,6 +45,16 @@ public NeighborArray(int maxSize, boolean descOrder) {
     this.scoresDescOrder = descOrder;
   }
 
+  /** Returns {@code true} if this array was constructed with {@code descOrder} set. */
+  public boolean isScoresDescOrder() {
+    return scoresDescOrder;
+  }
+
+  /** Returns the capacity of this array, i.e. the length of the backing node array. */
+  public int getMaxSize() {
+    return nodes.length;
+  }
+
   /**
    * Add a new node to the NeighborArray. The new node must be worse than all previously stored
    * nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
@@ -88,17 +98,21 @@ public void addOutOfOrder(int newNode, float newScore) {
    * multiple threads while other add method is only supposed to be called by one thread.
    *
    * @param nodeId node Id of the owner of this NeighbourArray
+   * @return node Id of removed node or -1 if no node was removed.
    */
-  public void addAndEnsureDiversity(
+  public int addAndEnsureDiversity(
       int newNode, float newScore, int nodeId, RandomVectorScorerSupplier scorerSupplier)
       throws IOException {
     addOutOfOrder(newNode, newScore);
     if (size < nodes.length) {
-      return;
+      return -1;
     }
     // we're oversize, need to do diversity check and pop out the least diverse neighbour
-    removeIndex(findWorstNonDiverse(nodeId, scorerSupplier));
+    int indexToRemove = findWorstNonDiverse(nodeId, scorerSupplier);
+    int nodeRemoved = nodes[indexToRemove];
+    removeIndex(indexToRemove);
     assert size == nodes.length - 1;
+    return nodeRemoved;
   }
 
   /**
@@ -290,4 +304,20 @@ private boolean isWorstNonDiverse(
     }
     return false;
   }
+
+  /**
+   * Removes the first entry whose node equals {@code nodeId}.
+   *
+   * @param nodeId id of the neighbour to remove
+   * @throws IllegalStateException if {@code nodeId} is not present in this array
+   */
+  public void removeNode(int nodeId) {
+    for (int i = 0; i < this.size(); i++) {
+      if (nodes[i] == nodeId) {
+        removeIndex(i);
+        return;
+      }
+    }
+    throw new IllegalStateException("Did not find the nodeId : " + nodeId);
+  }
 }