Skip to content

Commit

Permalink
Address initial comments
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <jmazane@amazon.com>
  • Loading branch information
jmazanec15 committed Jan 29, 2025
1 parent 4bc2828 commit a6fc86e
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;

import java.io.IOException;
Expand All @@ -40,24 +39,25 @@ public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat {
@Override
public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext)
throws IOException {
List<FieldInfo> derivedVectorFields = new ArrayList<>();
List<FieldInfo> derivedVectorFields = null;
for (FieldInfo fieldInfo : fieldInfos) {
if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) {
// Lazily initialize the list of fields
if (derivedVectorFields == null) {
derivedVectorFields = new ArrayList<>();
}
derivedVectorFields.add(fieldInfo);
}
}
// If no fields have it enabled,
if (derivedVectorFields.isEmpty()) {
// If no fields have it enabled, we can just short-circuit and return the delegate's fieldReader
if (derivedVectorFields == null || derivedVectorFields.isEmpty()) {
return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext);
}
DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector(
derivedSourceReadersSupplier,
new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext),
derivedVectorFields
);
return new DerivedSourceStoredFieldsReader(
delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext),
derivedSourceVectorInjector
derivedVectorFields,
derivedSourceReadersSupplier,
new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,63 @@

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.StoredFieldVisitor;
import org.apache.lucene.util.IOUtils;
import org.opensearch.index.fieldvisitor.FieldsVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor;
import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector;

import java.io.IOException;
import java.util.List;

@RequiredArgsConstructor
public class DerivedSourceStoredFieldsReader extends StoredFieldsReader {
private final StoredFieldsReader delegate;
// Given docId and source, process source
private final List<FieldInfo> derivedVectorFields;
private final DerivedSourceReadersSupplier derivedSourceReadersSupplier;
private final SegmentReadState segmentReadState;
private final boolean shouldInject;

private final DerivedSourceVectorInjector derivedSourceVectorInjector;

@Setter
private boolean shouldInject = true;
/**
*
* @param delegate delegate StoredFieldsReader
* @param derivedVectorFields List of fields that are derived source fields
* @param derivedSourceReadersSupplier Supplier for the derived source readers
* @param segmentReadState SegmentReadState for the segment
* @throws IOException in case of I/O error
*/
public DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
SegmentReadState segmentReadState
) throws IOException {
this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true);
}

private DerivedSourceStoredFieldsReader(
StoredFieldsReader delegate,
List<FieldInfo> derivedVectorFields,
DerivedSourceReadersSupplier derivedSourceReadersSupplier,
SegmentReadState segmentReadState,
boolean shouldInject
) throws IOException {
this.delegate = delegate;
this.derivedVectorFields = derivedVectorFields;
this.derivedSourceReadersSupplier = derivedSourceReadersSupplier;
this.segmentReadState = segmentReadState;
this.shouldInject = shouldInject;
this.derivedSourceVectorInjector = createDerivedSourceVectorInjector();
}

private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException {
return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields);
}

@Override
public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IOException {
Expand All @@ -43,7 +82,17 @@ public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IO

@Override
public StoredFieldsReader clone() {
return new DerivedSourceStoredFieldsReader(delegate.clone(), derivedSourceVectorInjector);
try {
return new DerivedSourceStoredFieldsReader(
delegate.clone(),
derivedVectorFields,
derivedSourceReadersSupplier,
segmentReadState,
shouldInject
);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@Override
Expand All @@ -53,22 +102,27 @@ public void checkIntegrity() throws IOException {

@Override
public void close() throws IOException {
delegate.close();
IOUtils.close(delegate, derivedSourceVectorInjector);
}

/**
* For merging, we need to tell the derived source stored fields reader to skip injecting the source. Otherwise,
* on merge we will end up just writing the source to disk
*
* @param storedFieldsReader stored fields reader to wrap
* @return wrapped stored fields reader
* @return Merged instance that wont inject by default
*/
public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) {
if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) {
StoredFieldsReader storedFieldsReaderClone = storedFieldsReader.clone();
((DerivedSourceStoredFieldsReader) storedFieldsReaderClone).setShouldInject(false);
return storedFieldsReaderClone;
@Override
public StoredFieldsReader getMergeInstance() {
try {
return new DerivedSourceStoredFieldsReader(
delegate.getMergeInstance(),
derivedVectorFields,
derivedSourceReadersSupplier,
segmentReadState,
false
);
} catch (IOException e) {
throw new RuntimeException(e);
}
return storedFieldsReader;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ public void writeField(FieldInfo info, DataInput value, int length) throws IOExc

@Override
public int merge(MergeState mergeState) throws IOException {
// We have to wrap these here to avoid storing the vectors during merge
for (int i = 0; i < mergeState.storedFieldsReaders.length; i++) {
mergeState.storedFieldsReaders[i] = DerivedSourceStoredFieldsReader.wrapForMerge(mergeState.storedFieldsReaders[i]);
}
return delegate.merge(mergeState);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,23 @@
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.util.IOUtils;

import java.io.Closeable;
import java.io.IOException;

/**
* Class holds the readers necessary to implement derived source.
*/
@RequiredArgsConstructor
@Getter
public class DerivedSourceReaders {
public class DerivedSourceReaders implements Closeable {
private final KnnVectorsReader knnVectorsReader;
private final DocValuesProducer docValuesProducer;
private final FieldsProducer fieldsProducer;

@Override
public void close() throws IOException {
IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.util.IOUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentHelper;
Expand All @@ -17,6 +18,7 @@
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand All @@ -31,8 +33,9 @@
* format readers and information about the fields to inject vectors into the source.
*/
@Log4j2
public class DerivedSourceVectorInjector {
public class DerivedSourceVectorInjector implements Closeable {

private final DerivedSourceReaders derivedSourceReaders;
private final List<PerFieldDerivedVectorInjector> perFieldDerivedVectorInjectors;
private final Set<String> fieldNames;

Expand All @@ -48,7 +51,7 @@ public DerivedSourceVectorInjector(
SegmentReadState segmentReadState,
List<FieldInfo> fieldsToInjectVector
) throws IOException {
DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState);
this.perFieldDerivedVectorInjectors = new ArrayList<>();
this.fieldNames = new HashSet<>();
for (FieldInfo fieldInfo : fieldsToInjectVector) {
Expand All @@ -67,7 +70,7 @@ public DerivedSourceVectorInjector(
* @return byte array of the source with the vector fields added
* @throws IOException if there is an issue reading from the formats
*/
public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOException {
public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException {
// Deserialize the source into a modifiable map
Tuple<? extends MediaType, Map<String, Object>> mapTuple = XContentHelper.convertToMap(
BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)),
Expand Down Expand Up @@ -121,4 +124,9 @@ public boolean shouldInject(String[] includes, String[] excludes) {
}
return true;
}

@Override
public void close() throws IOException {
IOUtils.close(derivedSourceReaders);
}
}

0 comments on commit a6fc86e

Please sign in to comment.