Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize HNSW diversity calculation #12235

Merged
merged 11 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Optimizations

* GITHUB#12286 Toposort use iterator to avoid stackoverflow. (Tang Donghai)

* GITHUB#12235: Optimize HNSW diversity calculation. (Patrick Zhai)

Bug Fixes
---------------------
(No changes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public List<TermAndBoost> getSynonyms(
SIMILARITY_FUNCTION,
hnswGraph,
null,
word2VecModel.size());
Integer.MAX_VALUE);

int size = synonyms.size();
for (int i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ private void addDiverseNeighbors(int node, NeighborQueue candidates) throws IOEx
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node()[i];
Lucene90NeighborArray nbrNbr = hnsw.getNeighbors(nbr);
nbrNbr.add(node, neighbors.score()[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
Lucene90NeighborArray nbrsOfNbr = hnsw.getNeighbors(nbr);
nbrsOfNbr.add(node, neighbors.score()[i]);
if (nbrsOfNbr.size() > maxConn) {
diversityUpdate(nbrsOfNbr);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
Lucene91NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.add(node, neighbors.score[i]);
if (nbrNbr.size() > maxConn) {
diversityUpdate(nbrNbr);
Lucene91NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.add(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConn) {
diversityUpdate(nbrsOfNbr);
}
}
}
Expand Down
124 changes: 96 additions & 28 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ private void initializeFromGraph(
case BYTE -> this.similarityFunction.compare(
binaryValue, (byte[]) vectorsCopy.vectorValue(newNeighbor));
};
newNeighbors.insertSorted(newNeighbor, score);
// we are not sure whether the previous graph contains
// unchecked nodes, so we have to assume they're all unchecked
newNeighbors.addOutOfOrder(newNeighbor, score);
}
}
}
Expand Down Expand Up @@ -314,11 +316,11 @@ private void addDiverseNeighbors(int level, int node, NeighborQueue candidates)
int size = neighbors.size();
for (int i = 0; i < size; i++) {
int nbr = neighbors.node[i];
NeighborArray nbrNbr = hnsw.getNeighbors(level, nbr);
nbrNbr.insertSorted(node, neighbors.score[i]);
if (nbrNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrNbr);
nbrNbr.removeIndex(indexToRemove);
NeighborArray nbrsOfNbr = hnsw.getNeighbors(level, nbr);
nbrsOfNbr.addOutOfOrder(node, neighbors.score[i]);
if (nbrsOfNbr.size() > maxConnOnLevel) {
int indexToRemove = findWorstNonDiverse(nbrsOfNbr);
nbrsOfNbr.removeIndex(indexToRemove);
}
}
}
Expand All @@ -333,7 +335,7 @@ private void selectAndLinkDiverse(
float cScore = candidates.score[i];
assert cNode < hnsw.size();
if (diversityCheck(cNode, cScore, neighbors)) {
neighbors.add(cNode, cScore);
neighbors.addInOrder(cNode, cScore);
}
}
}
Expand All @@ -345,7 +347,7 @@ private void popToScratch(NeighborQueue candidates) {
// sorted from worst to best
for (int i = 0; i < candidateCount; i++) {
float maxSimilarity = candidates.topScore();
scratch.add(candidates.pop(), maxSimilarity);
scratch.addInOrder(candidates.pop(), maxSimilarity);
}
}

Expand Down Expand Up @@ -400,50 +402,116 @@ private boolean isDiverse(byte[] candidate, NeighborArray neighbors, float score
* neighbours
*/
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
int[] uncheckedIndexes = neighbors.sort();
if (uncheckedIndexes == null) {
// all nodes are checked, we will directly return the most distant one
return neighbors.size() - 1;
}
int uncheckedCursor = uncheckedIndexes.length - 1;
for (int i = neighbors.size() - 1; i > 0; i--) {
if (isWorstNonDiverse(i, neighbors)) {
if (uncheckedCursor < 0) {
// no unchecked node left
break;
}
if (isWorstNonDiverse(i, neighbors, uncheckedIndexes, uncheckedCursor)) {
return i;
}
if (i == uncheckedIndexes[uncheckedCursor]) {
uncheckedCursor--;
}
}
return neighbors.size() - 1;
}

private boolean isWorstNonDiverse(int candidateIndex, NeighborArray neighbors)
private boolean isWorstNonDiverse(
int candidateIndex, NeighborArray neighbors, int[] uncheckedIndexes, int uncheckedCursor)
throws IOException {
int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidateIndex, (byte[]) vectors.vectorValue(candidateNode), neighbors);
candidateIndex,
(byte[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
case FLOAT32 -> isWorstNonDiverse(
candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
candidateIndex,
(float[]) vectors.vectorValue(candidateNode),
neighbors,
uncheckedIndexes,
uncheckedCursor);
};
}

private boolean isWorstNonDiverse(
int candidateIndex, float[] candidateVector, NeighborArray neighbors) throws IOException {
int candidateIndex,
float[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
} else {
// else we just need to make sure candidate does not violate diversity with the (newly
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(float[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
}
return false;
}

private boolean isWorstNonDiverse(
int candidateIndex, byte[] candidateVector, NeighborArray neighbors) throws IOException {
int candidateIndex,
byte[] candidateVector,
NeighborArray neighbors,
int[] uncheckedIndexes,
int uncheckedCursor)
throws IOException {
float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
if (candidateIndex == uncheckedIndexes[uncheckedCursor]) {
// the candidate itself is unchecked
for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector, (byte[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
} else {
// else we just need to make sure candidate does not violate diversity with the (newly
// inserted) unchecked nodes
assert candidateIndex > uncheckedIndexes[uncheckedCursor];
for (int i = uncheckedCursor; i >= 0; i--) {
float neighborSimilarity =
similarityFunction.compare(
candidateVector,
(byte[]) vectorsCopy.vectorValue(neighbors.node[uncheckedIndexes[i]]));
// candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return true;
}
}
}
return false;
Expand Down
88 changes: 72 additions & 16 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class NeighborArray {

float[] score;
int[] node;
private int sortedNodeSize;

public NeighborArray(int maxSize, boolean descOrder) {
node = new int[maxSize];
Expand All @@ -43,9 +44,10 @@ public NeighborArray(int maxSize, boolean descOrder) {

/**
* Add a new node to the NeighborArray. The new node must be worse than all previously stored
* nodes.
* nodes. This cannot be called after {@link #addOutOfOrder(int, float)}
*/
public void add(int newNode, float newScore) {
public void addInOrder(int newNode, float newScore) {
assert size == sortedNodeSize : "cannot call addInOrder after addOutOfOrder";
if (size == node.length) {
node = ArrayUtil.grow(node);
score = ArrayUtil.growExact(score, node.length);
Expand All @@ -59,23 +61,72 @@ public void add(int newNode, float newScore) {
node[size] = newNode;
score[size] = newScore;
++size;
++sortedNodeSize;
}

/** Add a new node to the NeighborArray into a correct sort position according to its score. */
public void insertSorted(int newNode, float newScore) {
/**
 * Appends a node and its score at the end of the array without placing it in score order.
 *
 * <p>The sorted-prefix counter is left untouched, so the new entry stays in the unsorted tail
 * until a later {@link #sort()} moves it into position.
 *
 * @param newNode the node to append
 * @param newScore the score associated with {@code newNode}
 */
public void addOutOfOrder(int newNode, float newScore) {
  final int slot = size;
  if (slot == node.length) {
    // out of room: grow the node array, then resize the score array to the same length
    node = ArrayUtil.grow(node);
    score = ArrayUtil.growExact(score, node.length);
  }
  node[slot] = newNode;
  score[slot] = newScore;
  size = slot + 1;
}

/**
* Sort the array according to scores, and return the sorted indexes of previous unsorted nodes
* (unchecked nodes)
*
* @return indexes of newly sorted (unchecked) nodes, in ascending order, or null if the array is
* already fully sorted
*/
public int[] sort() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. After reading I think I see what's going on. The name is deceptively simple; I wonder if it shouldn't hint that there is something tricky going on in here? Like sortAndReturnPreviouslyUnsorted() :) but it's fine to leave as is - the javadoc anyway explains it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, I'll just leave it as is. It calls itself "sort" and it does sort; it just returns something different :)

if (size == sortedNodeSize) {
// all nodes checked and sorted
return null;
}
assert sortedNodeSize < size;
int[] uncheckedIndexes = new int[size - sortedNodeSize];
int count = 0;
while (sortedNodeSize != size) {
uncheckedIndexes[count] = insertSortedInternal(); // sortedNodeSize is increased inside
for (int i = 0; i < count; i++) {
if (uncheckedIndexes[i] >= uncheckedIndexes[count]) {
// the previously inserted nodes have been shifted
uncheckedIndexes[i]++;
}
}
count++;
}
Arrays.sort(uncheckedIndexes);
return uncheckedIndexes;
}

/** insert the first unsorted node into its sorted position */
private int insertSortedInternal() {
assert sortedNodeSize < size : "Call this method only when there's unsorted node";
int tmpNode = node[sortedNodeSize];
float tmpScore = score[sortedNodeSize];
int insertionPoint =
scoresDescOrder
? descSortFindRightMostInsertionPoint(newScore)
: ascSortFindRightMostInsertionPoint(newScore);
System.arraycopy(node, insertionPoint, node, insertionPoint + 1, size - insertionPoint);
System.arraycopy(score, insertionPoint, score, insertionPoint + 1, size - insertionPoint);
node[insertionPoint] = newNode;
score[insertionPoint] = newScore;
++size;
? descSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize)
: ascSortFindRightMostInsertionPoint(tmpScore, sortedNodeSize);
System.arraycopy(
node, insertionPoint, node, insertionPoint + 1, sortedNodeSize - insertionPoint);
System.arraycopy(
score, insertionPoint, score, insertionPoint + 1, sortedNodeSize - insertionPoint);
node[insertionPoint] = tmpNode;
score[insertionPoint] = tmpScore;
++sortedNodeSize;
return insertionPoint;
}

/**
 * This method is for test only.
 *
 * <p>Appends the node with {@link #addOutOfOrder(int, float)} and then calls
 * insertSortedInternal() once. That helper positions the first unsorted entry, so the new node
 * ends up in its sorted position only when the array was fully sorted before this call.
 *
 * @param newNode the node to insert
 * @param newScore the score associated with {@code newNode}
 */
void insertSorted(int newNode, float newScore) {
addOutOfOrder(newNode, newScore);
insertSortedInternal();
}

public int size() {
Expand All @@ -97,15 +148,20 @@ public float[] score() {

/** Empties the array by resetting both the sorted-prefix counter and the element count to zero. */
public void clear() {
  sortedNodeSize = 0;
  size = 0;
}

/**
 * Drops the last entry by decrementing the element count, then clamps the sorted-prefix counter
 * so it never exceeds the new count.
 */
public void removeLast() {
  size--;
  if (sortedNodeSize > size) {
    // the removed entry was part of the sorted prefix
    sortedNodeSize = size;
  }
}

/**
 * Removes the entry at {@code idx} by shifting every later node and score one slot to the left.
 *
 * <p>If the removed entry lay inside the sorted prefix, the sorted-prefix counter shrinks with it.
 *
 * @param idx index of the entry to remove
 */
public void removeIndex(int idx) {
  final int tailLength = size - idx - 1;
  System.arraycopy(node, idx + 1, node, idx, tailLength);
  System.arraycopy(score, idx + 1, score, idx, tailLength);
  if (idx < sortedNodeSize) {
    sortedNodeSize--;
  }
  size--;
}

Expand All @@ -114,11 +170,11 @@ public String toString() {
return "NeighborArray[" + size + "]";
}

private int ascSortFindRightMostInsertionPoint(float newScore) {
int insertionPoint = Arrays.binarySearch(score, 0, size, newScore);
private int ascSortFindRightMostInsertionPoint(float newScore, int bound) {
int insertionPoint = Arrays.binarySearch(score, 0, bound, newScore);
if (insertionPoint >= 0) {
// find the right most position with the same score
while ((insertionPoint < size - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
while ((insertionPoint < bound - 1) && (score[insertionPoint + 1] == score[insertionPoint])) {
insertionPoint++;
}
insertionPoint++;
Expand All @@ -128,9 +184,9 @@ private int ascSortFindRightMostInsertionPoint(float newScore) {
return insertionPoint;
}

private int descSortFindRightMostInsertionPoint(float newScore) {
private int descSortFindRightMostInsertionPoint(float newScore, int bound) {
int start = 0;
int end = size - 1;
int end = bound - 1;
while (start <= end) {
int mid = (start + end) / 2;
if (score[mid] < newScore) end = mid - 1;
Expand Down
Loading