Skip to content

Commit

Permalink
Optimize HNSW diversity calculation (#12235)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaih authored May 16, 2023
1 parent 0e172b0 commit 8af3058
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 86 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Optimizations

* GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai)

* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public List<TermAndBoost> getSynonyms(
SIMILARITY_FUNCTION,
hnswGraph,
null,
word2VecModel.size());
Integer.MAX_VALUE);

int size = synonyms.size();
for (int i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOEx
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node()[i];
Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
nbrNbr.add(node, neighbors.score()[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr);
nbrsOfNbr.add(node, neighbors.score()[i]);
if (nbrsOfNbr.size() > maxConn) {
diversityUpdate(nbrsOfNbr);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.add(node, neighbors.score[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.add(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConn) {
diversityUpdate(nbrsOfNbr);
}
}
}
Expand Down
124 changes: 96 additions & 28 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ private void initializeFromGraph(
case BYTE -> this.similarityFunction.compare(
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
};
newNeighbors.insertSorted(newNeighbor, score);
// we are not sure whether the previous graph contains
// unchecked nodes, so we have to assume they're all unchecked
newNeighbors.addOutOfOrder(newNeighbor, score);
}
}
}
Expand Down Expand Up @@ -314,11 +316,11 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.insertSorted(node, neighbors.score[i]);
if (nbrNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrNbr);
nbrNbr.removeIndex(indexToRemove);
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
nbrsOfNbr.removeIndex(indexToRemove);
}
}
}
Expand All @@ -333,7 +335,7 @@ private void selectAndLinkDiverse(
float cScore = candidates.score[i];
assert cNode < hnsw.size();
if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.add(cNode, cScore);
neighbors.addInOrder(cNode, cScore);
}
}
}
Expand All @@ -345,7 +347,7 @@ private void popToScratch(NeighborQueue candidates) {
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float maxSimilarity = candidates.topScore();
scratch.add(candidates.pop(), maxSimilarity);
scratch.addInOrder(candidates.pop(), maxSimilarity);
}
}

Expand Down Expand Up @@ -400,50 +402,116 @@ private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
int[] uncheckedIndexes = neighbors.sort();
if (uncheckedIndexes == null) {
// all nodes are checked, we will directly return the most distant one
return neighbors.size() - 1;
}
int uncheckedCursor = uncheckedIndexes.length - 1;
for (int i = neighbors.size() - 1; i > 0; i--) {
if (isWorstNonDiverse(i, neighbors)) {
if (uncheckedCursor < 0) {
// no unchecked node left
break;
}
if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
return i;
}
if (i == uncheckedIndexes[uncheckedCursor]) {
uncheckedCursor--;
}
}
return neighbors.size() - 1;
}

private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
private boolean isWorstNonDiverse(
int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
candidateIndex,
(byte[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
case FLOAT32 -> isWorstNonDiverse(
candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
candidateIndex,
(float[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
};
}

private boolean isWorstNonDiverse(
int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
int candidateIndex,
float[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
} else {
// else we just need to make sure candidate does not violate diversity with the (newly
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
}
return false;
}

private boolean isWorstNonDiverse(
int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
int candidateIndex,
byte[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
} else {
// else we just need to make sure candidate does not violate diversity with the (newly
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
}
return false;
Expand Down
88 changes: 72 additions & 16 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class NeighborArray {

float[] score;
int[] node;
private int sortedNodeSize;

public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
Expand All @@ -43,9 +44,10 @@ public NeighborArray(int maxSize, boolean descOrder) {

/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
* nodes.
* nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
*/
public void add(int newNode, float newScore) {
public void addInOrder(int newNode, float newScore) {
assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder";
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
Expand All @@ -59,23 +61,72 @@ public void add(int newNode, float newScore) {
node[size] = newNode;
score[size] = newScore;
++size;
++sortedNodeSize;
}

/** Add a new node to the NeighborArray into a correct sort position according to its score. */
public void insertSorted(int newNode, float newScore) {
/** Add node and score but do not insert as sorted */
public void addOutOfOrder(int newNode, float newScore) {
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
}
node[size] = newNode;
score[size] = newScore;
size++;
}

/**
* Sort the array according to scores, and return the sorted indexes of previous unsorted nodes
* (unchecked nodes)
*
* @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
* already fully sorted
*/
public int[] sort() {
if (size == sortedNodeSize) {
// all nodes checked and sorted
return null;
}
assert sortedNodeSize < size;
int[] uncheckedIndexes = new int[size - sortedNodeSize];
int count = 0;
while (sortedNodeSize != size) {
uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
for (int i = 0; i < count; i++) {
if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
// the previous inserted nodes has been shifted
uncheckedIndexes[i]++;
}
}
count++;
}
Arrays.sort(uncheckedIndexes);
return uncheckedIndexes;
}

/** insert the first unsorted node into its sorted position */
private int insertSortedInternal() {
assert sortedNodeSize < size : "Call this method only when there's unsorted node";
int tmpNode = node[sortedNodeSize];
float tmpScore = score[sortedNodeSize];
int insertionPoint =
scoresDescOrder
? descSortFindRightMostInsertionPoint(newScore)
: ascSortFindRightMostInsertionPoint(newScore);
System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
node[insertionPoint] = newNode;
score[insertionPoint] = newScore;
++size;
? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
: ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize);
System.arraycopy(
node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint);
System.arraycopy(
score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint);
node[insertionPoint] = tmpNode;
score[insertionPoint] = tmpScore;
++sortedNodeSize;
return insertionPoint;
}

/** This method is for test only. */
void insertSorted(int newNode, float newScore) {
addOutOfOrder(newNode, newScore);
insertSortedInternal();
}

public int size() {
Expand All @@ -97,15 +148,20 @@ public float[] score() {

public void clear() {
size = 0;
sortedNodeSize = 0;
}

public void removeLast() {
size--;
sortedNodeSize = Math.min(sortedNodeSize, size);
}

public void removeIndex(int idx) {
System.arraycopy(node, idx + 1, node, idx, size - idx - 1);
System.arraycopy(score, idx + 1, score, idx, size - idx - 1);
if (idx < sortedNodeSize) {
sortedNodeSize--;
}
size--;
}

Expand All @@ -114,11 +170,11 @@ public String toString() {
return "NeighborArray[" + size + "]";
}

private int ascSortFindRightMostInsertionPoint(float newScore) {
int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore);
if (insertionPoint >= 0) {
// find the right most position with the same score
while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
insertionPoint++;
}
insertionPoint++;
Expand All @@ -128,9 +184,9 @@ private int ascSortFindRightMostInsertionPoint(float newScore) {
return insertionPoint;
}

private int descSortFindRightMostInsertionPoint(float newScore) {
private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
int start = 0;
int end = size - 1;
int end = bound - 1;
while (start <= end) {
int mid = (start + end) / 2;
if (score[mid] < newScore) end = mid - 1;
Expand Down
Loading

0 comments on commit 8af3058

Please sign in to comment.