Skip to content

Commit

Permalink
Re-Call Issue Fix with Binary Quantized Vectors (opensearch-project#2071
Browse files Browse the repository at this point in the history
)

* Re-Call Issue Fix with Binary Quantized Vectors

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>

* Feedback Fix

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>

---------

Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
Signed-off-by: Vikasht34 <viktari@amazon.com>
  • Loading branch information
Vikasht34 authored Sep 9, 2024
1 parent 8f6b177 commit ce735c4
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
class QuantizationIndexUtils {

/**
* Processes and returns the vector based on whether quantization is applied or not.
* Processes the vector from {@link KNNVectorValues} and returns either a cloned quantized vector or a cloned original vector.
*
* @param knnVectorValues the KNN vector values to be processed.
* @param indexBuildSetup the setup containing quantization state and output, along with other parameters.
* @return the processed vector, either quantized or original.
* @throws IOException if an I/O error occurs during processing.
* @param knnVectorValues The KNN vector values containing the original vector.
* @param indexBuildSetup The setup containing the quantization state and output details.
* @return The quantized vector (as a byte array) or the original/cloned vector.
* @throws IOException If an I/O error occurs while processing the vector.
*/
static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException {
QuantizationService quantizationService = QuantizationService.getInstance();
Expand All @@ -33,7 +33,11 @@ static Object processAndReturnVector(KNNVectorValues<?> knnVectorValues, IndexBu
knnVectorValues.getVector(),
indexBuildSetup.getQuantizationOutput()
);
return indexBuildSetup.getQuantizationOutput().getQuantizedVector();
/**
* Returns a copy of the quantized vector. This is because of during transfer same vectors was getting
* added due to reference.
*/
return indexBuildSetup.getQuantizationOutput().getQuantizedVectorCopy();
} else {
return knnVectorValues.conditionalCloneVector();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,16 @@ public boolean isPrepared(int vectorLength) {
public byte[] getQuantizedVector() {
return quantizedVector;
}

/**
* Returns a copy of the quantized vector.
*
* @return a copy of the quantized vector byte array.
*/
@Override
public byte[] getQuantizedVectorCopy() {
byte[] clonedByteArray = new byte[quantizedVector.length];
System.arraycopy(quantizedVector, 0, clonedByteArray, 0, quantizedVector.length);
return clonedByteArray;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,32 @@ public interface QuantizationOutput<T> {
/**
* Returns the quantized vector.
*
* @return the quantized data.
* This method provides access to the quantized data in its current state.
* It returns the same reference to the internal quantized vector on each call, meaning any modifications
* to the returned array will directly affect the internal state of the object. This design is intentional
* to avoid unnecessary copying of data and improve performance, especially in scenarios where frequent
* access to the quantized vector is required.
*
* <p><b>Important:</b> As this method returns a direct reference to the internal array, care must be taken
* when modifying the returned array. If the returned vector is altered, the changes will reflect in the
* quantized vector managed by the object, which could lead to unintended side effects.</p>
*
* <p><b>Usage Example:</b></p>
* <pre>
* byte[] quantizedData = quantizationOutput.getQuantizedVector();
* // Use or modify quantizedData, but be cautious that changes affect the internal state.
* </pre>
*
* This method does not create a deep copy of the vector to avoid performance overhead in real-time
* or high-frequency operations. If a separate copy of the vector is needed, the caller should manually
* clone or copy the returned array.
*
* <p><b>Example to clone the array:</b></p>
* <pre>
* byte[] clonedData = Arrays.copyOf(quantizationOutput.getQuantizedVector(), quantizationOutput.getQuantizedVector().length);
* </pre>
*
* @return the quantized vector (same reference on each invocation).
*/
T getQuantizedVector();

Expand All @@ -33,4 +58,11 @@ public interface QuantizationOutput<T> {
* @return true if the quantized vector is already prepared, false otherwise.
*/
boolean isPrepared(int vectorLength);

/**
* Returns a copy of the quantized vector.
*
* @return a copy of the quantized data.
*/
T getQuantizedVectorCopy();
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds[0].length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ public void testBuildAndWrite_withQuantization() {
ArgumentCaptor<float[]> vectorCaptor = ArgumentCaptor.forClass(float[].class);
// New: Create QuantizationOutput and mock the quantization process
QuantizationOutput<byte[]> quantizationOutput = mock(QuantizationOutput.class);
when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 });
when(quantizationOutput.getQuantizedVectorCopy()).thenReturn(new byte[] { 1, 2 });
when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn(
quantizationOutput
);

// Quantize the vector with the quantization output
when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer(
invocation -> {
quantizationOutput.getQuantizedVector();
return quantizationOutput.getQuantizedVector();
quantizationOutput.getQuantizedVectorCopy();
return quantizationOutput.getQuantizedVectorCopy();
}
);
when(quantizationState.getDimensions()).thenReturn(2);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.quantization.output;

import org.junit.Before;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;

public class BinaryQuantizationOutputTests extends KNNTestCase {

private static final int BITS_PER_COORDINATE = 1;
private BinaryQuantizationOutput quantizationOutput;

@Before
public void setUp() throws Exception {
super.setUp();
quantizationOutput = new BinaryQuantizationOutput(BITS_PER_COORDINATE);
}

public void testPrepareQuantizedVector_ShouldInitializeCorrectly_WhenVectorLengthIsValid() {
// Arrange
int vectorLength = 10;

// Act
quantizationOutput.prepareQuantizedVector(vectorLength);

// Assert
assertNotNull(quantizationOutput.getQuantizedVector());
}

public void testPrepareQuantizedVector_ShouldThrowException_WhenVectorLengthIsZeroOrNegative() {
// Act and Assert
expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(0));
expectThrows(IllegalArgumentException.class, () -> quantizationOutput.prepareQuantizedVector(-1));
}

public void testIsPrepared_ShouldReturnTrue_WhenCalledWithSameVectorLength() {
// Arrange
int vectorLength = 8;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act and Assert
assertTrue(quantizationOutput.isPrepared(vectorLength));
}

public void testIsPrepared_ShouldReturnFalse_WhenCalledWithDifferentVectorLength() {
// Arrange
int vectorLength = 8;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act and Assert
assertFalse(quantizationOutput.isPrepared(vectorLength + 1));
}

public void testGetQuantizedVector_ShouldReturnSameReference() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);
// Act
byte[] vector = quantizationOutput.getQuantizedVector();
// Assert
assertEquals(vector, quantizationOutput.getQuantizedVector());
}

public void testGetQuantizedVectorCopy_ShouldReturnCopyOfVector() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);

// Act
byte[] vectorCopy = quantizationOutput.getQuantizedVectorCopy();

// Assert
assertNotSame(vectorCopy, quantizationOutput.getQuantizedVector());
assertArrayEquals(vectorCopy, quantizationOutput.getQuantizedVector());
}

public void testGetQuantizedVectorCopy_ShouldReturnNewCopyOnEachCall() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);

// Act
byte[] vectorCopy1 = quantizationOutput.getQuantizedVectorCopy();
byte[] vectorCopy2 = quantizationOutput.getQuantizedVectorCopy();

// Assert
assertNotSame(vectorCopy1, vectorCopy2);
}

public void testPrepareQuantizedVector_ShouldResetQuantizedVector_WhenCalledWithDifferentLength() {
// Arrange
int initialLength = 5;
int newLength = 10;
quantizationOutput.prepareQuantizedVector(initialLength);
byte[] initialVector = quantizationOutput.getQuantizedVector();

// Act
quantizationOutput.prepareQuantizedVector(newLength);
byte[] newVector = quantizationOutput.getQuantizedVector();

// Assert
assertNotSame(initialVector, newVector); // The array reference should change
assertEquals(newVector.length, (BITS_PER_COORDINATE * newLength + 7) / 8); // Correct size for new vector
}

public void testPrepareQuantizedVector_ShouldRetainSameArray_WhenCalledWithSameLength() {
// Arrange
int vectorLength = 5;
quantizationOutput.prepareQuantizedVector(vectorLength);
byte[] initialVector = quantizationOutput.getQuantizedVector();

// Act
quantizationOutput.prepareQuantizedVector(vectorLength);
byte[] newVector = quantizationOutput.getQuantizedVector();

// Assert
assertSame(newVector, initialVector); // The array reference should remain the same
}
}

0 comments on commit ce735c4

Please sign in to comment.