Skip to content

Commit

Permalink
make vectorCount atomic in MutablePQ
Browse files: browse the repository at this point in the history
  • Loading branch information
jbellis committed Jan 9, 2025
1 parent bd21ee7 commit 72044bf
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.github.jbellis.jvector.vector.types.ByteSequence;

public class ImmutablePQVectors extends PQVectors {
private final int vectorCount;

/**
* Construct an immutable PQVectors instance with the given ProductQuantization and compressed data chunks.
* @param pq the ProductQuantization to use
Expand All @@ -37,4 +39,9 @@ public ImmutablePQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDa
protected int validChunkCount() {
// Every slot of compressedDataChunks is reported as valid: the array length is returned unchanged.
return compressedDataChunks.length;
}

/**
* Returns the vector count held in this class's {@code vectorCount} field.
* @return the value of {@code vectorCount}
*/
@Override
public int count() {
return vectorCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.github.jbellis.jvector.quantization;

import io.github.jbellis.jvector.vector.VectorizationProvider;
import java.util.concurrent.atomic.AtomicInteger;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
Expand All @@ -30,29 +31,31 @@ public class MutablePQVectors extends PQVectors implements MutableCompressedVect
private static final int INITIAL_CHUNKS = 10;
private static final float GROWTH_FACTOR = 1.5f;

protected AtomicInteger vectorCount;

/**
* Construct a mutable PQVectors instance with the given ProductQuantization.
* The vectors storage will grow dynamically as needed.
* The instance starts with a vector count of zero, {@code VECTORS_PER_CHUNK} vectors per chunk,
* and a chunk array of {@code INITIAL_CHUNKS} slots.
* @param pq the ProductQuantization to use
*/
public MutablePQVectors(ProductQuantization pq) {
super(pq);
this.vectorCount = 0;
this.vectorCount = new AtomicInteger(0);
this.vectorsPerChunk = VECTORS_PER_CHUNK;
this.compressedDataChunks = new ByteSequence<?>[INITIAL_CHUNKS];
}

/**
* Encodes {@code vector} via {@code pq.encodeTo} into the slot returned by {@code get(ordinal)}.
* Before encoding, calls {@code ensureChunkCapacity(ordinal)} and raises the vector count
* to at least {@code ordinal + 1}; the count is never lowered.
* @param ordinal the position to store the encoded vector at
* @param vector the vector to encode
*/
@Override
public void encodeAndSet(int ordinal, VectorFloat<?> vector) {
ensureChunkCapacity(ordinal);
// The count is raised before get(ordinal) is called: PQVectors.get bounds-checks the ordinal
// against the count and would otherwise reject a new highest ordinal.
vectorCount = max(vectorCount, ordinal + 1);
vectorCount.updateAndGet(current -> max(current, ordinal + 1));
pq.encodeTo(vector, get(ordinal));
}

/**
* Calls {@code zero()} on the slot returned by {@code get(ordinal)}.
* Before doing so, calls {@code ensureChunkCapacity(ordinal)} and raises the vector count
* to at least {@code ordinal + 1}; the count is never lowered.
* @param ordinal the position whose slot is zeroed
*/
@Override
public void setZero(int ordinal) {
ensureChunkCapacity(ordinal);
// The count is raised before get(ordinal) is called: PQVectors.get bounds-checks the ordinal
// against the count and would otherwise reject a new highest ordinal.
vectorCount = max(vectorCount, ordinal + 1);
vectorCount.updateAndGet(current -> max(current, ordinal + 1));
get(ordinal).zero();
}

Expand All @@ -78,9 +81,14 @@ private void ensureChunkCapacity(int ordinal) {

/**
* Derives the chunk count from the vector count rather than from the chunk array length.
* @return 0 when the vector count is 0, otherwise {@code (count - 1) / vectorsPerChunk + 1}
*/
@Override
protected int validChunkCount() {
if (vectorCount == 0)
if (vectorCount.get() == 0)
return 0;
// Integer division of the last counted ordinal (count - 1) by the chunk size.
int chunkOrdinal = (vectorCount - 1) / vectorsPerChunk;
int chunkOrdinal = (vectorCount.get() - 1) / vectorsPerChunk;
return chunkOrdinal + 1;
}

/**
* Returns the current value of the {@code vectorCount} AtomicInteger.
* @return the value read by {@code vectorCount.get()}
*/
@Override
public int count() {
return vectorCount.get();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ public abstract class PQVectors implements CompressedVectors {

final ProductQuantization pq;
protected ByteSequence<?>[] compressedDataChunks;
protected int vectorCount;
protected int vectorsPerChunk;

protected PQVectors(ProductQuantization pq) {
Expand Down Expand Up @@ -155,19 +154,14 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
}

@Override
public int count() {
return vectorCount;
}

@Override
public void write(DataOutput out, int version) throws IOException
{
// pq codebooks
pq.write(out, version);

// compressed vectors
out.writeInt(vectorCount);
out.writeInt(count());
out.writeInt(pq.getSubspaceCount());
for (int i = 0; i < validChunkCount(); i++) {
vectorTypeSupport.writeByteSequence(out, compressedDataChunks[i]);
Expand Down Expand Up @@ -286,8 +280,8 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}

/**
* Looks up the entry for {@code ordinal} by delegating to the static
* {@code get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount())} overload,
* after checking the ordinal against the vector count.
* @param ordinal the position to look up
* @return the ByteSequence returned by the delegated lookup
* @throws IndexOutOfBoundsException if {@code ordinal} is negative or not less than the vector count
*/
public ByteSequence<?> get(int ordinal) {
if (ordinal < 0 || ordinal >= vectorCount)
throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + vectorCount);
if (ordinal < 0 || ordinal >= count())
throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count());
return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
}

Expand Down Expand Up @@ -341,7 +335,7 @@ public long ramBytesUsed() {
public String toString() {
// Reports only the quantizer and the vector count; chunk contents are not included.
return "PQVectors{" +
"pq=" + pq +
", count=" + vectorCount +
", count=" + count() +
'}';
}
}

0 comments on commit 72044bf

Please sign in to comment.