diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java index b1d538b3..0ac4c145 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java @@ -19,6 +19,8 @@ import io.github.jbellis.jvector.vector.types.ByteSequence; public class ImmutablePQVectors extends PQVectors { + private final int vectorCount; + /** * Construct an immutable PQVectors instance with the given ProductQuantization and compressed data chunks. * @param pq the ProductQuantization to use @@ -37,4 +39,9 @@ public ImmutablePQVectors(ProductQuantization pq, ByteSequence[] compressedDa protected int validChunkCount() { return compressedDataChunks.length; } + + @Override + public int count() { + return vectorCount; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java index 62e6a522..7627cf62 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.vector.VectorizationProvider; +import java.util.concurrent.atomic.AtomicInteger; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; @@ -30,6 +31,8 @@ public class MutablePQVectors extends PQVectors implements MutableCompressedVect private static final int INITIAL_CHUNKS = 10; private static final float GROWTH_FACTOR = 1.5f; + protected AtomicInteger vectorCount; + /** * Construct a mutable PQVectors instance with the given ProductQuantization. * The vectors storage will grow dynamically as needed. @@ -37,7 +40,7 @@ public class MutablePQVectors extends PQVectors implements MutableCompressedVect */ public MutablePQVectors(ProductQuantization pq) { super(pq); - this.vectorCount = 0; + this.vectorCount = new AtomicInteger(0); this.vectorsPerChunk = VECTORS_PER_CHUNK; this.compressedDataChunks = new ByteSequence[INITIAL_CHUNKS]; } @@ -45,14 +48,14 @@ public MutablePQVectors(ProductQuantization pq) { @Override public void encodeAndSet(int ordinal, VectorFloat vector) { ensureChunkCapacity(ordinal); - vectorCount = max(vectorCount, ordinal + 1); + vectorCount.updateAndGet(current -> max(current, ordinal + 1)); pq.encodeTo(vector, get(ordinal)); } @Override public void setZero(int ordinal) { ensureChunkCapacity(ordinal); - vectorCount = max(vectorCount, ordinal + 1); + vectorCount.updateAndGet(current -> max(current, ordinal + 1)); get(ordinal).zero(); } @@ -78,9 +81,14 @@ private void ensureChunkCapacity(int ordinal) { @Override protected int validChunkCount() { - if (vectorCount == 0) + if (vectorCount.get() == 0) return 0; - int chunkOrdinal = (vectorCount - 1) / vectorsPerChunk; + int chunkOrdinal = (vectorCount.get() - 1) / vectorsPerChunk; return chunkOrdinal + 1; } + + @Override + public int count() { + return vectorCount.get(); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 897bc610..f6a41cef 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -41,7 +41,6 @@ public abstract class PQVectors implements CompressedVectors { final ProductQuantization pq; protected ByteSequence[] compressedDataChunks; - protected int vectorCount; protected int vectorsPerChunk; protected PQVectors(ProductQuantization pq) { @@ -155,11 +154,6 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk); } - @Override - public int count() { - return vectorCount; - } - @Override public void write(DataOutput out, int version) throws IOException { @@ -167,7 +161,7 @@ public void write(DataOutput out, int version) throws IOException pq.write(out, version); // compressed vectors - out.writeInt(vectorCount); + out.writeInt(count()); out.writeInt(pq.getSubspaceCount()); for (int i = 0; i < validChunkCount(); i++) { vectorTypeSupport.writeByteSequence(out, compressedDataChunks[i]); @@ -286,8 +280,8 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, } public ByteSequence get(int ordinal) { - if (ordinal < 0 || ordinal >= vectorCount) - throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + vectorCount); + if (ordinal < 0 || ordinal >= count()) + throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count()); return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount()); } @@ -341,7 +335,7 @@ public long ramBytesUsed() { public String toString() { return "PQVectors{" + "pq=" + pq + - ", count=" + vectorCount + + ", count=" + count() + '}'; } }