Skip to content

Commit

Permalink
Fix bytes bug and add uTs for derived source
Browse files Browse the repository at this point in the history
Fixes a bug in the derived source writer where we are reading the entire
bytes array from the bytes ref instead of just the offset+length.

Along with that, touches up the ParentChildHelper (no prod impact) and
also adds some unit tests.

Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 authored and Vikasht34 committed Feb 12, 2025
1 parent 878c57d commit d821749
Show file tree
Hide file tree
Showing 12 changed files with 445 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.StoredFieldsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.KNN9120Codec.DerivedSourceStoredFieldsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 10.0.1
Expand All @@ -24,12 +28,15 @@ public class KNN10010Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_10_01_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

/**
* No arg constructor that uses Lucene99 as the delegate
*/
public KNN10010Codec() {
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat(), null);
}

/**
Expand All @@ -40,10 +47,12 @@ public KNN10010Codec() {
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
protected KNN10010Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat, MapperService mapperService) {
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -60,4 +69,36 @@ public CompoundFormat compoundFormat() {
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
}
return null;
}, (segmentReadState) -> {
if (segmentReadState.fieldInfos.hasDocValues()) {
return docValuesFormat().fieldsProducer(segmentReadState);
}
return null;

}, (segmentReadState) -> {
if (segmentReadState.fieldInfos.hasPostings()) {
return postingsFormat().fieldsProducer(segmentReadState);
}
return null;

}, (segmentReadState -> {
if (segmentReadState.fieldInfos.hasNorms()) {
return normsFormat().normsProducer(segmentReadState);
}
return null;
}));
return new DerivedSourceStoredFieldsFormat(delegate.storedFieldsFormat(), derivedSourceReadersSupplier, mapperService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void writeField(FieldInfo fieldInfo, BytesRef bytesRef) throws IOExceptio
// Reference:
// https://github.com/opensearch-project/OpenSearch/blob/2.18.0/server/src/main/java/org/opensearch/index/mapper/SourceFieldMapper.java#L322
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes)),
BytesReference.fromByteBuffer(ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)),
true,
MediaTypeRegistry.JSON
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;
private final StoredFieldsFormat storedFieldsFormat;

private final MapperService mapperService;

Expand All @@ -48,6 +49,7 @@ protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
this.mapperService = mapperService;
this.storedFieldsFormat = getStoredFieldsFormat();
}

@Override
Expand All @@ -67,6 +69,10 @@ public KnnVectorsFormat knnVectorsFormat() {

@Override
public StoredFieldsFormat storedFieldsFormat() {
return storedFieldsFormat;
}

private StoredFieldsFormat getStoredFieldsFormat() {
DerivedSourceReadersSupplier derivedSourceReadersSupplier = new DerivedSourceReadersSupplier((segmentReadState) -> {
if (segmentReadState.fieldInfos.hasVectorValues()) {
return knnVectorsFormat().fieldsReader(segmentReadState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ public enum KNNCodecVersion {
(userCodec, mapperService) -> KNN10010Codec.builder()
.delegate(userCodec)
.knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.mapperService(mapperService)
.build(),
KNN10010Codec::new
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public class ParentChildHelper {
* this would return "parent.to".
*
* @param field nested field path
* @return parent field path without the child
* @return parent field path without the child. Null if no parent exists
*/
public static String getParentField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
Expand All @@ -30,10 +33,16 @@ public static String getParentField(String field) {
* return "child".
*
* @param field nested field path
* @return child field path without the parent path
* @return child field path without the parent path. Null if no child exists
*/
public static String getChildField(String field) {
if (field == null) {
return null;
}
int lastDot = field.lastIndexOf('.');
if (lastDot == -1) {
return null;
}
return field.substring(lastDot + 1);
}

Expand All @@ -46,7 +55,11 @@ public static String getChildField(String field) {
* @return sibling field path
*/
public static String constructSiblingField(String field, String sibling) {
return getParentField(field) + "." + sibling;
String parent = getParentField(field);
if (parent == null) {
return sibling;
}
return parent + "." + sibling;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/**
* Provides different strategies to extract the vectors from different {@link KNNVectorValuesIterator}
*/
interface VectorValueExtractorStrategy {
public interface VectorValueExtractorStrategy {

/**
* Extract a float vector from KNNVectorValuesIterator.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.StoredFieldsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.mock;

public class DerivedSourceStoredFieldsWriterTests extends KNNTestCase {

@SneakyThrows
public void testWriteField() {
StoredFieldsWriter delegate = mock(StoredFieldsWriter.class);
FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build();
List<String> fields = List.of("test");

DerivedSourceStoredFieldsWriter derivedSourceStoredFieldsWriter = new DerivedSourceStoredFieldsWriter(delegate, fields);

Map<String, Object> source = Map.of("test", new float[] { 1.0f, 2.0f, 3.0f }, "text_field", "text_value");
BytesStreamOutput bStream = new BytesStreamOutput();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(MediaTypeRegistry.JSON, bStream).map(source);
builder.close();
byte[] originalBytes = bStream.bytes().toBytesRef().bytes;
byte[] shiftedBytes = new byte[originalBytes.length + 2];
System.arraycopy(originalBytes, 0, shiftedBytes, 1, originalBytes.length);
derivedSourceStoredFieldsWriter.writeField(fieldInfo, new BytesRef(shiftedBytes, 1, originalBytes.length));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.derivedsource;

import org.apache.lucene.index.StoredFieldVisitor;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DerivedSourceStoredFieldVisitorTests extends KNNTestCase {

public void testBinaryField() throws Exception {
StoredFieldVisitor delegate = mock(StoredFieldVisitor.class);
doAnswer(invocationOnMock -> null).when(delegate).binaryField(any(), any());
DerivedSourceVectorInjector derivedSourceVectorInjector = mock(DerivedSourceVectorInjector.class);
when(derivedSourceVectorInjector.injectVectors(anyInt(), any())).thenReturn(new byte[0]);
DerivedSourceStoredFieldVisitor derivedSourceStoredFieldVisitor = new DerivedSourceStoredFieldVisitor(
delegate,
0,
derivedSourceVectorInjector
);

// When field is not _source, then do not call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("test").build(), null);
verify(derivedSourceVectorInjector, times(0)).injectVectors(anyInt(), any());
verify(delegate, times(1)).binaryField(any(), any());

// When field is not _source, then do call the injector
derivedSourceStoredFieldVisitor.binaryField(KNNCodecTestUtil.FieldInfoBuilder.builder("_source").build(), null);
verify(derivedSourceVectorInjector, times(1)).injectVectors(anyInt(), any());
verify(delegate, times(2)).binaryField(any(), any());
}
}
Loading

0 comments on commit d821749

Please sign in to comment.