Skip to content

Commit

Permalink
Adding heuristics in hnsw paper and improving proposed heuristic for …
Browse files Browse the repository at this point in the history
…experimentations.
  • Loading branch information
Nitiraj Rathore authored and nitirajrathore committed Feb 1, 2024
1 parent 78b4f75 commit d43b15e
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 7 deletions.
253 changes: 249 additions & 4 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static java.lang.Math.log;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
Expand Down Expand Up @@ -252,7 +253,20 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try

// then do connections from bottom up
for (int i = 0; i < scratchPerLevel.length; i++) {
addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); // baseline : similar to false, false, false below.
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, false); // baseline_equivalent
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, true); // baseline with remove other half
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false, false, false); // exp-1 extendCandidates = true
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, false, false); // exp-2 with keep-pruned =true
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, false); // exp-3 with keep pruned till half max-conn
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, true); // exp-3 with keep pruned till half max-conn and remove other half
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, false, false);
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, true, false);

// new heuristic
// addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false); // new heuristic without remove other half
// addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false); // new heuristic with remove other half
addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true); // new heuristic with honour-max-conn true
}
lowestUnsetLevel += scratchPerLevel.length;
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
Expand Down Expand Up @@ -290,6 +304,233 @@ private long printGraphBuildStatus(int node, long start, long t) {
return now;
}

/**
 * Chooses the index (into {@code neighbors}) of the entry that should be evicted.
 *
 * <p>The array is first sorted with {@code scorer}, then scanned from the last index down to the
 * first. For every entry the neighbours it shares with {@code neighbors} are collected; the scan
 * stops at the first entry that has a shared neighbour whose score (as stored in that entry's own
 * neighbour list) is higher than the entry's score in {@code neighbors}.
 *
 * <p>If no such entry exists the result falls back, in order, to: the entry with the most shared
 * neighbours, then the entry with the most connections (at least 2), then either a random index
 * ({@code removeRandom == true}) or {@code -1}.
 *
 * @param neighbors neighbour list to pick a victim from; it is sorted as a side effect
 * @param level graph level whose neighbour lists are consulted
 * @param scorer scorer passed to {@link NeighborArray#sort}
 * @param removeRandom whether a random index is returned when no other rule picks one
 * @return index of the entry to remove, or {@code -1} if nothing was chosen
 */
private int findWorstNonDiverse(NeighborArray neighbors, int level, RandomVectorScorer scorer, boolean removeRandom) throws IOException {
  // Called for its effect on the array only; the returned indexes are not used here.
  neighbors.sort(scorer);

  int mostSharedCount = 0;
  int mostSharedIdx = -1;
  int mostLinkedIdx = -1;
  int mostLinkedCount = 2; // an entry needs at least 2 connections to qualify

  int idx = neighbors.size();
  while (--idx >= 0) {
    int member = neighbors.nodes()[idx];
    float memberScore = neighbors.scores()[idx];
    NeighborArray memberLinks = hnsw.getNeighbors(level, member);
    NeighborArray shared = findCommon(neighbors, memberLinks);

    if (shared.size() > mostSharedCount) {
      mostSharedCount = shared.size();
      mostSharedIdx = idx;
    }
    if (memberLinks.size() >= mostLinkedCount) {
      mostLinkedCount = memberLinks.size();
      mostLinkedIdx = idx;
    }

    for (int k = 0; k < shared.size(); k++) {
      if (shared.scores()[k] > memberScore) {
        // member scores higher against a shared neighbour than against the owner: non-diverse
        return idx;
      }
    }
  }

  if (mostSharedIdx != -1) {
    return mostSharedIdx;
  }
  if (mostLinkedIdx != -1) {
    return mostLinkedIdx;
  }
  return removeRandom ? random.nextInt(neighbors.size()) : -1;
}

/**
 * Appends {@code newNode} to {@code neighbours} and, if that pushes the list past
 * {@code maxConnOnLevel}, evicts the entry chosen by {@code findWorstNonDiverse}.
 *
 * @param newNode node to add
 * @param newScore score stored alongside {@code newNode}
 * @param nodeId node a scorer is requested for; callers in this file pass the owner of
 *     {@code neighbours}
 * @param neighbours neighbour list that is modified
 * @param level graph level the list belongs to
 * @param maxConnOnLevel size the list may reach before an eviction is attempted
 * @param removeRandom forwarded to {@code findWorstNonDiverse}; when true an eviction is mandatory
 * @return the evicted node, or {@code -1} when nothing was removed
 * @throws IllegalStateException if {@code removeRandom} is true but no index was chosen
 */
public int addAndEnsureConnectedDiversity(int newNode, float newScore, int nodeId, NeighborArray neighbours,
    int level, int maxConnOnLevel, boolean removeRandom)
    throws IOException {
  neighbours.addOutOfOrder(newNode, newScore);
  if (neighbours.size() <= maxConnOnLevel) {
    return -1; // still within the limit, nothing to evict
  }

  RandomVectorScorer ownerScorer = scorerSupplier.scorer(nodeId);
  int victimIdx = findWorstNonDiverse(neighbours, level, ownerScorer, removeRandom);
  if (victimIdx == -1) {
    if (removeRandom) {
      throw new IllegalStateException("If remove random is set we should not have got -1 for remove index.");
    }
    return -1; // no victim was chosen, so the list is left over the limit
  }

  int victim = neighbours.nodes()[victimIdx];
  neighbours.removeIndex(victimIdx);
  return victim;
}

/**
 * Collects the nodes that occur in both arrays.
 *
 * <p>The result holds each match with the score taken from {@code currNodeNeighbours}, in the
 * order the matches are met while walking {@code neighbors}.
 *
 * @param neighbors first neighbour list (drives the outer iteration)
 * @param currNodeNeighbours second neighbour list (supplies the stored node and score)
 * @return a new array containing the shared nodes
 */
private NeighborArray findCommon(NeighborArray neighbors, NeighborArray currNodeNeighbours) {
  int capacity = Math.max(currNodeNeighbours.size(), neighbors.size());
  NeighborArray shared = new NeighborArray(capacity, true);

  for (int outer = 0; outer < neighbors.size(); outer++) {
    for (int inner = 0; inner < currNodeNeighbours.size(); inner++) {
      if (neighbors.nodes()[outer] != currNodeNeighbours.nodes()[inner]) {
        continue;
      }
      shared.addOutOfOrder(currNodeNeighbours.nodes()[inner], currNodeNeighbours.scores()[inner]);
    }
  }

  return shared;
}


/**
 * Connects a newly inserted node using the experimental "new heuristic".
 *
 * <p>Forward links: {@code selectAll} copies candidates into the new node's neighbour list,
 * walking from the last candidate backwards until the per-level limit is hit (no diversity check
 * is applied at this stage). Backward links: for every selected candidate, the new node is added
 * to that candidate's list under its write lock via {@code addAndEnsureConnectedDiversity}, which
 * may evict one entry.
 *
 * @param level graph level being linked
 * @param incomingNode the new node; its neighbour list on this level must be empty
 * @param candidates candidate neighbours for the new node
 * @param removeOtherHalf if true, when a candidate evicts a node, the evicted node's link back to
 *     that candidate is removed as well
 * @param honourMaxConn passed as the {@code removeRandom} flag of
 *     {@code addAndEnsureConnectedDiversity}
 */
private void addDiverseNeighborsNewHeuristic(int level, int incomingNode, NeighborArray candidates, boolean removeOtherHalf,
    boolean honourMaxConn)
    throws IOException {
  NeighborArray ownLinks = hnsw.getNeighbors(level, incomingNode);
  assert ownLinks.size() == 0; // new node

  final int maxConnOnLevel = (level == 0) ? M * 2 : M;
  final boolean[] chosen = selectAll(ownLinks, candidates, maxConnOnLevel);

  // Iterate over the local candidates/mask rather than the neighbour list itself.
  final int candidateCount = candidates.size();
  for (int idx = 0; idx < candidateCount; idx++) {
    if (chosen[idx]) {
      int target = candidates.nodes()[idx];
      int evicted = -1;
      NeighborArray targetLinks = hnsw.getNeighbors(level, target);
      targetLinks.rwlock.writeLock().lock();
      try {
        evicted = addAndEnsureConnectedDiversity(incomingNode, candidates.scores()[idx], target, targetLinks, level, maxConnOnLevel, honourMaxConn);
      } finally {
        targetLinks.rwlock.writeLock().unlock();
      }

      // The target's lock is already released here, so only one lock is held at any time.
      if (evicted != -1 && removeOtherHalf) {
        removeOtherHalf(evicted, target, level);
      }
    }
  }
}

/**
 * Selects candidates without any diversity check: walks {@code candidates} from the last index
 * backwards and appends each one to {@code neighbors} until {@code neighbors} holds
 * {@code maxConnOnLevel} entries or the candidates run out.
 *
 * @param neighbors neighbour list that receives the selected candidates (via {@code addInOrder})
 * @param candidates candidates to pick from
 * @param maxConnOnLevel maximum size {@code neighbors} may reach
 * @return mask, indexed like {@code candidates}, with {@code true} for every selected candidate
 */
private boolean[] selectAll( NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) {
  boolean[] picked = new boolean[candidates.size()];

  int idx = candidates.size() - 1;
  while (idx >= 0 && neighbors.size() < maxConnOnLevel) {
    picked[idx] = true;
    neighbors.addInOrder(candidates.nodes()[idx], candidates.scores()[idx]);
    idx--;
  }

  return picked;
}

/**
 * Removes the link to {@code nbr} from the neighbour list of {@code fromNode} on {@code level}.
 *
 * <p>The list's write lock is held for the duration of the removal and is always released, even
 * when {@code removeNode} throws because {@code nbr} is not present in the list.
 *
 * @param fromNode node whose neighbour list is modified
 * @param nbr node to drop from that list
 * @param level graph level of the neighbour list
 */
private void removeOtherHalf(int fromNode, int nbr, int level) {
  NeighborArray neighbors = hnsw.getNeighbors(level, fromNode);
  neighbors.rwlock.writeLock().lock();
  try {
    neighbors.removeNode(nbr);
  } finally {
    // Without this, an exception from removeNode would leave the write lock held.
    neighbors.rwlock.writeLock().unlock();
  }
}


/**
 * Links a freshly inserted {@code node} to neighbours chosen from {@code candidates} on
 * {@code level}, with switches for the experimental variants of the selection heuristic.
 *
 * @param level graph level on which the links are created
 * @param node the node being inserted; its neighbour list on this level must still be empty
 * @param candidates candidate neighbours for {@code node}
 * @param extendCandidates if true, the candidate set is widened with every neighbour (on this
 *     level) of each original candidate before selection
 * @param keepPrunedConnections if true, candidates rejected by the diversity selection are added
 *     back, taking them from the end of the candidate array first, until the cap is reached
 * @param keepHalfPrunedConnection if true, the cap used while adding back rejected candidates is
 *     {@code maxConnOnLevel / 2} instead of {@code maxConnOnLevel}; only read when
 *     {@code keepPrunedConnections} is true
 * @param removeOtherHalf if true, whenever a neighbour evicts a node to make room for
 *     {@code node}, the evicted node's link back to that neighbour is removed too
 * @throws IOException propagated from the helpers called here (e.g. selectAndLinkDiverse)
 */
private void addDiverseNeighbors(int level, int node, NeighborArray candidates,
boolean extendCandidates, boolean keepPrunedConnections,
boolean keepHalfPrunedConnection, boolean removeOtherHalf) throws IOException {
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node

// Level 0 allows twice as many connections as the upper levels.
int maxConnOnLevel = level == 0 ? M * 2 : M;
NeighborArray originalCandidates = candidates;
RandomVectorScorer scorer = scorerSupplier.scorer(node);

if (extendCandidates) {
// The set de-duplicates nodes reachable from several candidates.
HashSet<Integer> addedNodes = new HashSet<>();

for (int i = 0; i < originalCandidates.size(); i++) {
int cand = originalCandidates.nodes()[i];
addedNodes.add(cand);
addAllNeighbours(cand, level, addedNodes);
}

// Rebuild the candidate array from the widened set, keeping the original ordering direction.
candidates = new NeighborArray(addedNodes.size(), originalCandidates.isScoresDescOrder());
for (int cand : addedNodes) {
// Scores are not known yet, so NaN placeholders are stored.
// NOTE(review): this relies on sort(scorer) below filling in real scores - confirm in
// NeighborArray.sort.
candidates.addOutOfOrder(cand, Float.NaN);
}

candidates.sort(scorer);
}


// selected[i] is true for every candidate that passed the diversity selection and was added
// to the new node's neighbour list.
boolean[] selected = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);

if (keepPrunedConnections) {
int maskedIndex = selected.length - 1; // start from end as the highest scoring candidates are at the end
int maxConn = maxConnOnLevel;
if (keepHalfPrunedConnection) {
maxConn = maxConnOnLevel/2;
}
// Top up the neighbour list with rejected candidates until the cap is reached.
while (neighbors.size() < maxConn && maskedIndex >= 0) {
if (selected[maskedIndex] == false) {
neighbors.addOutOfOrder(candidates.nodes()[maskedIndex], candidates.scores()[maskedIndex]);
selected[maskedIndex] = true;
}
maskedIndex--;
}
// Entries were appended out of order above, so the list is sorted again.
neighbors.sort(scorer);
}

// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
// NOTE: here we're using candidates and mask but not the neighbour array because once we have
// added incoming link there will be possibilities of this node being discovered and neighbour
// array being modified. So using local candidates and mask is a safer option.
int nodeRemoved = -1;
for (int i = 0; i < candidates.size(); i++) {
if (selected[i] == false) {
continue;
}
int nbr = candidates.nodes()[i];
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
// The neighbour's list is modified under its write lock.
nbrsOfNbr.rwlock.writeLock().lock();
try {
nodeRemoved = nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier);
} finally {
nbrsOfNbr.rwlock.writeLock().unlock();
}

// Called after the lock above is released: also drop the evicted node's link back to nbr.
if (removeOtherHalf && nodeRemoved != -1) {
removeOtherHalf(nodeRemoved, nbr, level);
}
}
}

/**
 * Adds every neighbour of {@code node} on the given level to {@code addedNodes}.
 *
 * @param node node whose neighbour list is read
 * @param level graph level of the neighbour list
 * @param addedNodes set that receives the neighbour ids; existing entries are kept
 */
private void addAllNeighbours(int node, int level, HashSet<Integer> addedNodes) {
  NeighborArray links = hnsw.getNeighbors(level, node);

  // TODO: this read is not guarded by the list's lock, so `links` might get modified while we
  // are iterating.
  int idx = 0;
  while (idx < links.size()) {
    addedNodes.add(links.nodes()[idx]);
    idx++;
  }
}

private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
Expand Down Expand Up @@ -327,7 +568,11 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
*/
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
boolean[] mask = new boolean[candidates.size()];
boolean[] selected = new boolean[candidates.size()];
return selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, selected);
}
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel, boolean[] selected) throws IOException {
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
Expand All @@ -336,13 +581,13 @@ private boolean[] selectAndLinkDiverse(
float cScore = candidates.scores()[i];
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
mask[i] = true;
selected[i] = true;
// here we don't need to lock, because there's no incoming link so no others is able to
// discover this node such that no others will modify this neighbor array as well
neighbors.addInOrder(cNode, cScore);
}
}
return mask;
return selected;
}

private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* @lucene.internal
*/
public class NeighborArray {

private final boolean scoresDescOrder;
private int size;
private final float[] scores;
Expand All @@ -45,6 +46,13 @@ public NeighborArray(int maxSize, boolean descOrder) {
this.scoresDescOrder = descOrder;
}

/**
 * Returns the {@code descOrder} flag this array was constructed with.
 *
 * @return the value of {@code scoresDescOrder}
 */
public boolean isScoresDescOrder() {
  return this.scoresDescOrder;
}

/**
 * Returns the capacity of this array, i.e. the length of the backing node array.
 *
 * @return length of {@code nodes}
 */
public int getMaxSize() {
  return this.nodes.length;
}
/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
* nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
Expand Down Expand Up @@ -80,6 +88,7 @@ public void addOutOfOrder(int newNode, float newScore) {
size++;
}


/**
* In addition to {@link #addOutOfOrder(int, float)}, this function will also remove the
* least-diverse node if the node array is full after insertion
Expand All @@ -88,17 +97,21 @@ public void addOutOfOrder(int newNode, float newScore) {
* multiple threads while other add method is only supposed to be called by one thread.
*
* @param nodeId node Id of the owner of this NeighbourArray
* @return node Id of removed node or -1 if no node was removed.
*/
public void addAndEnsureDiversity(
public int addAndEnsureDiversity(
int newNode, float newScore, int nodeId, RandomVectorScorerSupplier scorerSupplier)
throws IOException {
addOutOfOrder(newNode, newScore);
if (size < nodes.length) {
return;
return -1;
}
// we're oversize, need to do diversity check and pop out the least diverse neighbour
removeIndex(findWorstNonDiverse(nodeId, scorerSupplier));
int indexToRemove = findWorstNonDiverse(nodeId, scorerSupplier);
int nodeRemoved = nodes[indexToRemove];
removeIndex(indexToRemove);
assert size == nodes.length - 1;
return nodeRemoved;
}

/**
Expand Down Expand Up @@ -290,4 +303,24 @@ private boolean isWorstNonDiverse(
}
return false;
}

/**
 * Removes the first entry whose node id equals {@code nodeId}.
 *
 * @param nodeId id of the node to remove
 * @throws IllegalStateException if no entry with that id is present
 */
public void removeNode(int nodeId) {
  // Linear scan for the first matching entry.
  int idx = 0;
  while (idx < this.size() && nodes[idx] != nodeId) {
    idx++;
  }

  if (idx >= this.size()) {
    throw new IllegalStateException("Did not find the nodeId : " + nodeId);
  }
  removeIndex(idx);
}

}

0 comments on commit d43b15e

Please sign in to comment.