Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated heuristic to remove non diverse edges keeping overall graph c… #12783

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 249 additions & 4 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static java.lang.Math.log;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
Expand Down Expand Up @@ -252,7 +253,20 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try

// then do connections from bottom up
for (int i = 0; i < scratchPerLevel.length; i++) {
addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]);
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); // baseline : similar to false, false, false below.
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, false); // baseline_equivalent
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false, false, true); // baseline with remove other half
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false, false, false); // exp-1 extendCandidates = true
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, false, false); // exp-2 with keep-pruned =true
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, false); // exp-3 with keep pruned till half max-conn
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], false, true, true, true); // exp-3 with keep pruned till half max-conn and remove other half
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, false, false);
// addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true, true, false);

// new heuristic
// addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], false, false); // new heuristic without remove other half
// addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, false); // new heuristic with remove other half
addDiverseNeighborsNewHeuristic(i + lowestUnsetLevel, node, scratchPerLevel[i], true, true); // new heuristic with honour-max-conn true
}
lowestUnsetLevel += scratchPerLevel.length;
assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1;
Expand Down Expand Up @@ -290,6 +304,233 @@ private long printGraphBuildStatus(int node, long start, long t) {
return now;
}

/**
* Find first non-diverse neighbour among the list of neighbors starting from the most distant
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors, int level, RandomVectorScorer scorer, boolean removeRandom) throws IOException {
int[] uncheckedIndexes = neighbors.sort(scorer); // what does 'unchecked node' means? What check? What
int maxCommonConnectionCount = 0;
int maxcccIndex = -1;
int maxConnectionIndex = -1;
int maxConnectionCount = 2; // has atleast 2 connections
for (int i = neighbors.size() - 1; i >= 0 ; i--) {
int currNode = neighbors.nodes()[i];
float currNodeScore = neighbors.scores()[i];
NeighborArray currNodeNeighbours = hnsw.getNeighbors(level, currNode);
NeighborArray commonNeighbours = findCommon(neighbors, currNodeNeighbours);

if (commonNeighbours.size() > maxCommonConnectionCount) {
maxCommonConnectionCount = commonNeighbours.size();
maxcccIndex = i;
}
if (currNodeNeighbours.size() >= maxConnectionCount) {
maxConnectionCount = currNodeNeighbours.size();
maxConnectionIndex = i;
}

for (int j = 0; j < commonNeighbours.size(); j++) {
if (commonNeighbours.scores()[j] > currNodeScore) {
return i; // currNode is non-diverse as it's score with another neighbour is higher
}
}
}
if (maxcccIndex != -1) {
return maxcccIndex;
} else if (maxConnectionIndex != -1) {
return maxConnectionIndex;
} else {
if (removeRandom == false) {
return -1;
} else {
return random.nextInt(neighbors.size());
}
}
}

public int addAndEnsureConnectedDiversity(int newNode, float newScore, int nodeId, NeighborArray neighbours,
int level, int maxConnOnLevel, boolean removeRandom)
throws IOException {
neighbours.addOutOfOrder(newNode, newScore);
if (neighbours.size() <= maxConnOnLevel) {
return -1; // none removed
}
RandomVectorScorer scorer = scorerSupplier.scorer(nodeId);
int indexToRemove = findWorstNonDiverse(neighbours, level, scorer, removeRandom);
if (indexToRemove != -1) {
int nodeRemoved = neighbours.nodes()[indexToRemove];
neighbours.removeIndex(indexToRemove);
return nodeRemoved;
} else {
if (removeRandom == true) {
throw new IllegalStateException("If remove random is set we should not have got -1 for remove index.");
} else {
return -1; // no diverse found hence not removed
}
}
}

private NeighborArray findCommon(NeighborArray neighbors, NeighborArray currNodeNeighbours) {
NeighborArray common = new NeighborArray(Math.max(currNodeNeighbours.size(), neighbors.size()), true);
for (int i = 0; i < neighbors.size(); i++) {
for (int j = 0; j < currNodeNeighbours.size(); j++) {
if (neighbors.nodes()[i] == currNodeNeighbours.nodes()[j]) {
common.addOutOfOrder(currNodeNeighbours.nodes()[j], currNodeNeighbours.scores()[j]);
}
}
}

return common;
}


private void addDiverseNeighborsNewHeuristic(int level, int incomingNode, NeighborArray candidates, boolean removeOtherHalf,
boolean honourMaxConn)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
* is closer to target than it is to any of the already-selected neighbors (ie selected in this method,
* since the node is new and has no prior neighbors).
*/
NeighborArray neighbors = hnsw.getNeighbors(level, incomingNode);
assert neighbors.size() == 0; // new node

int maxConnOnLevel = level == 0 ? M * 2 : M;

boolean[] mask = selectAll(neighbors, candidates, maxConnOnLevel);
// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
int size = candidates.size();
for (int i = 0; i < size; i++) {
if (mask[i] == false) {
continue;
}

int nbr = candidates.nodes()[i];
int nodeRemoved = -1;
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.rwlock.writeLock().lock();
try {
nodeRemoved = addAndEnsureConnectedDiversity(incomingNode, candidates.scores()[i], nbr, nbrsOfNbr, level, maxConnOnLevel, honourMaxConn);
} finally {
nbrsOfNbr.rwlock.writeLock().unlock();
}

if (removeOtherHalf && nodeRemoved != -1) {
// it is fine to remove it even from current node as we are iterating over candidates and not over neighbours
// Also, there is no chance of deadlock as we are not holding more than one lock at a time.
removeOtherHalf(nodeRemoved, nbr, level);
}
}

}

private boolean[] selectAll( NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) {
boolean[] selected = new boolean[candidates.size()];

for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
int cNode = candidates.nodes()[i];
float cScore = candidates.scores()[i];
selected[i] = true;
neighbors.addInOrder(cNode, cScore);
}

return selected;
}

private void removeOtherHalf(int fromNode, int nbr, int level) {
NeighborArray neighbors = hnsw.getNeighbors(level, fromNode);
neighbors.rwlock.writeLock().lock();
neighbors.removeNode(nbr);
neighbors.rwlock.writeLock().unlock();
}


private void addDiverseNeighbors(int level, int node, NeighborArray candidates,
boolean extendCandidates, boolean keepPrunedConnections,
boolean keepHalfPrunedConnection, boolean removeOtherHalf) throws IOException {
NeighborArray neighbors = hnsw.getNeighbors(level, node);
assert neighbors.size() == 0; // new node

int maxConnOnLevel = level == 0 ? M * 2 : M;
NeighborArray originalCandidates = candidates;
RandomVectorScorer scorer = scorerSupplier.scorer(node);

if (extendCandidates) {
HashSet<Integer> addedNodes = new HashSet<>();

for (int i = 0; i < originalCandidates.size(); i++) {
int cand = originalCandidates.nodes()[i];
addedNodes.add(cand);
addAllNeighbours(cand, level, addedNodes);
}

candidates = new NeighborArray(addedNodes.size(), originalCandidates.isScoresDescOrder());
for (int cand : addedNodes) {
candidates.addOutOfOrder(cand, Float.NaN);
}

candidates.sort(scorer);
}


boolean[] selected = selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel);

if (keepPrunedConnections) {
int maskedIndex = selected.length - 1; // start from end as the highest scoring candidates are at the end
int maxConn = maxConnOnLevel;
if (keepHalfPrunedConnection) {
maxConn = maxConnOnLevel/2;
}
while (neighbors.size() < maxConn && maskedIndex >= 0) {
if (selected[maskedIndex] == false) {
neighbors.addOutOfOrder(candidates.nodes()[maskedIndex], candidates.scores()[maskedIndex]);
selected[maskedIndex] = true;
}
maskedIndex--;
}
neighbors.sort(scorer);
}

// Link the selected nodes to the new node, and the new node to the selected nodes (again
// applying diversity heuristic)
// NOTE: here we're using candidates and mask but not the neighbour array because once we have
// added incoming link there will be possibilities of this node being discovered and neighbour
// array being modified. So using local candidates and mask is a safer option.
int nodeRemoved = -1;
for (int i = 0; i < candidates.size(); i++) {
if (selected[i] == false) {
continue;
}
int nbr = candidates.nodes()[i];
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.rwlock.writeLock().lock();
try {
nodeRemoved = nbrsOfNbr.addAndEnsureDiversity(node, candidates.scores()[i], nbr, scorerSupplier);
} finally {
nbrsOfNbr.rwlock.writeLock().unlock();
}

if (removeOtherHalf && nodeRemoved != -1) {
removeOtherHalf(nodeRemoved, nbr, level);
}
}
}

/**
* Adds all neighbours of `node` into the `addedNodes` set
*
* @param node
* @param addedNodes
*/
private void addAllNeighbours(int node, int level, HashSet<Integer> addedNodes) {
NeighborArray neighbors = hnsw.getNeighbors(level, node);

// TODO: I think we need to guard with lock here as the `neighbors` array might get modified while we are iterating.
for (int i = 0; i < neighbors.size(); i++) {
addedNodes.add(neighbors.nodes()[i]);
}
}

private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
throws IOException {
/* For each of the beamWidth nearest candidates (going from best to worst), select it only if it
Expand Down Expand Up @@ -327,7 +568,11 @@ private void addDiverseNeighbors(int level, int node, NeighborArray candidates)
*/
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel) throws IOException {
boolean[] mask = new boolean[candidates.size()];
boolean[] selected = new boolean[candidates.size()];
return selectAndLinkDiverse(neighbors, candidates, maxConnOnLevel, selected);
}
private boolean[] selectAndLinkDiverse(
NeighborArray neighbors, NeighborArray candidates, int maxConnOnLevel, boolean[] selected) throws IOException {
// Select the best maxConnOnLevel neighbors of the new node, applying the diversity heuristic
for (int i = candidates.size() - 1; neighbors.size() < maxConnOnLevel && i >= 0; i--) {
// compare each neighbor (in distance order) against the closer neighbors selected so far,
Expand All @@ -336,13 +581,13 @@ private boolean[] selectAndLinkDiverse(
float cScore = candidates.scores()[i];
assert cNode <= hnsw.maxNodeId();
if (diversityCheck(cNode, cScore, neighbors)) {
mask[i] = true;
selected[i] = true;
// here we don't need to lock, because there's no incoming link so no others is able to
// discover this node such that no others will modify this neighbor array as well
neighbors.addInOrder(cNode, cScore);
}
}
return mask;
return selected;
}

private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborArray scratch) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
* @lucene.internal
*/
public class NeighborArray {

private final boolean scoresDescOrder;
private int size;
private final float[] scores;
Expand All @@ -45,6 +46,13 @@ public NeighborArray(int maxSize, boolean descOrder) {
this.scoresDescOrder = descOrder;
}

public boolean isScoresDescOrder() {
return scoresDescOrder;
}

public int getMaxSize() {
return nodes.length;
}
/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
* nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
Expand Down Expand Up @@ -80,6 +88,7 @@ public void addOutOfOrder(int newNode, float newScore) {
size++;
}


/**
* In addition to {@link #addOutOfOrder(int, float)}, this function will also remove the
* least-diverse node if the node array is full after insertion
Expand All @@ -88,17 +97,21 @@ public void addOutOfOrder(int newNode, float newScore) {
* multiple threads while other add method is only supposed to be called by one thread.
*
* @param nodeId node Id of the owner of this NeighbourArray
* @return node Id of removed node or -1 if no node was removed.
*/
public void addAndEnsureDiversity(
public int addAndEnsureDiversity(
int newNode, float newScore, int nodeId, RandomVectorScorerSupplier scorerSupplier)
throws IOException {
addOutOfOrder(newNode, newScore);
if (size < nodes.length) {
return;
return -1;
}
// we're oversize, need to do diversity check and pop out the least diverse neighbour
removeIndex(findWorstNonDiverse(nodeId, scorerSupplier));
int indexToRemove = findWorstNonDiverse(nodeId, scorerSupplier);
int nodeRemoved = nodes[indexToRemove];
removeIndex(indexToRemove);
assert size == nodes.length - 1;
return nodeRemoved;
}

/**
Expand Down Expand Up @@ -290,4 +303,24 @@ private boolean isWorstNonDiverse(
}
return false;
}

public void removeNode(int nodeId) {
// System.out.println("size = " + this.size() + " node.length = " + node.length);
int indexToRemove = -1;
for (int i = 0; i < this.size(); i++) {
if (nodes[i] == nodeId) {
indexToRemove = i;
break;
}
}

// assert indexToRemove != -1;
if (indexToRemove == -1) {
// System.out.println("ERRORR1 : did not find nodeId = " + nodeId );
throw new IllegalStateException("Did not find the nodeId : " + nodeId);
} else {
removeIndex(indexToRemove);
}
}

}
Loading