From a6fc86e7e925cbd5433794949450d5739095213f Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 28 Jan 2025 20:18:29 -0800 Subject: [PATCH] Address initial comments Signed-off-by: John Mazanec --- .../DerivedSourceStoredFieldsFormat.java | 20 ++--- .../DerivedSourceStoredFieldsReader.java | 86 +++++++++++++++---- .../DerivedSourceStoredFieldsWriter.java | 4 - .../derivedsource/DerivedSourceReaders.java | 11 ++- .../DerivedSourceVectorInjector.java | 14 ++- 5 files changed, 101 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java index 095222b54..55d8868dc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsFormat.java @@ -19,7 +19,6 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; -import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.io.IOException; @@ -40,24 +39,25 @@ public class DerivedSourceStoredFieldsFormat extends StoredFieldsFormat { @Override public StoredFieldsReader fieldsReader(Directory directory, SegmentInfo segmentInfo, FieldInfos fieldInfos, IOContext ioContext) throws IOException { - List derivedVectorFields = new ArrayList<>(); + List derivedVectorFields = null; for (FieldInfo fieldInfo : fieldInfos) { if (DERIVED_VECTOR_FIELD_ATTRIBUTE_TRUE_VALUE.equals(fieldInfo.attributes().get(DERIVED_VECTOR_FIELD_ATTRIBUTE_KEY))) { + // Lazily initialize the list of fields + if (derivedVectorFields == null) { + derivedVectorFields = new ArrayList<>(); + } derivedVectorFields.add(fieldInfo); } } - // If no 
fields have it enabled, - if (derivedVectorFields.isEmpty()) { + // If no fields have it enabled, we can just short-circuit and return the delegate's fieldReader + if (derivedVectorFields == null || derivedVectorFields.isEmpty()) { return delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext); } - DerivedSourceVectorInjector derivedSourceVectorInjector = new DerivedSourceVectorInjector( - derivedSourceReadersSupplier, - new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext), - derivedVectorFields - ); return new DerivedSourceStoredFieldsReader( delegate.fieldsReader(directory, segmentInfo, fieldInfos, ioContext), - derivedSourceVectorInjector + derivedVectorFields, + derivedSourceReadersSupplier, + new SegmentReadState(directory, segmentInfo, fieldInfos, ioContext) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java index ef9eba126..6c1ade140 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsReader.java @@ -5,24 +5,63 @@ package org.opensearch.knn.index.codec.KNN9120Codec; -import lombok.RequiredArgsConstructor; -import lombok.Setter; import org.apache.lucene.codecs.StoredFieldsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.StoredFieldVisitor; +import org.apache.lucene.util.IOUtils; import org.opensearch.index.fieldvisitor.FieldsVisitor; +import org.opensearch.knn.index.codec.derivedsource.DerivedSourceReadersSupplier; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceStoredFieldVisitor; import org.opensearch.knn.index.codec.derivedsource.DerivedSourceVectorInjector; import java.io.IOException; +import java.util.List; 
-@RequiredArgsConstructor public class DerivedSourceStoredFieldsReader extends StoredFieldsReader { private final StoredFieldsReader delegate; - // Given docId and source, process source + private final List derivedVectorFields; + private final DerivedSourceReadersSupplier derivedSourceReadersSupplier; + private final SegmentReadState segmentReadState; + private final boolean shouldInject; + private final DerivedSourceVectorInjector derivedSourceVectorInjector; - @Setter - private boolean shouldInject = true; + /** + * + * @param delegate delegate StoredFieldsReader + * @param derivedVectorFields List of fields that are derived source fields + * @param derivedSourceReadersSupplier Supplier for the derived source readers + * @param segmentReadState SegmentReadState for the segment + * @throws IOException in case of I/O error + */ + public DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState + ) throws IOException { + this(delegate, derivedVectorFields, derivedSourceReadersSupplier, segmentReadState, true); + } + + private DerivedSourceStoredFieldsReader( + StoredFieldsReader delegate, + List derivedVectorFields, + DerivedSourceReadersSupplier derivedSourceReadersSupplier, + SegmentReadState segmentReadState, + boolean shouldInject + ) throws IOException { + this.delegate = delegate; + this.derivedVectorFields = derivedVectorFields; + this.derivedSourceReadersSupplier = derivedSourceReadersSupplier; + this.segmentReadState = segmentReadState; + this.shouldInject = shouldInject; + this.derivedSourceVectorInjector = createDerivedSourceVectorInjector(); + } + + private DerivedSourceVectorInjector createDerivedSourceVectorInjector() throws IOException { + return new DerivedSourceVectorInjector(derivedSourceReadersSupplier, segmentReadState, derivedVectorFields); + } @Override public void document(int docId, StoredFieldVisitor 
storedFieldVisitor) throws IOException { @@ -43,7 +82,17 @@ public void document(int docId, StoredFieldVisitor storedFieldVisitor) throws IO @Override public StoredFieldsReader clone() { - return new DerivedSourceStoredFieldsReader(delegate.clone(), derivedSourceVectorInjector); + try { + return new DerivedSourceStoredFieldsReader( + delegate.clone(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + shouldInject + ); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Override @@ -53,22 +102,27 @@ public void checkIntegrity() throws IOException { @Override public void close() throws IOException { - delegate.close(); + IOUtils.close(delegate, derivedSourceVectorInjector); } /** * For merging, we need to tell the derived source stored fields reader to skip injecting the source. Otherwise, * on merge we will end up just writing the source to disk * - * @param storedFieldsReader stored fields reader to wrap - * @return wrapped stored fields reader + * @return Merge instance that won't inject by default */ - public static StoredFieldsReader wrapForMerge(StoredFieldsReader storedFieldsReader) { - if (storedFieldsReader instanceof DerivedSourceStoredFieldsReader) { - StoredFieldsReader storedFieldsReaderClone = storedFieldsReader.clone(); - ((DerivedSourceStoredFieldsReader) storedFieldsReaderClone).setShouldInject(false); - return storedFieldsReaderClone; + @Override + public StoredFieldsReader getMergeInstance() { + try { + return new DerivedSourceStoredFieldsReader( + delegate.getMergeInstance(), + derivedVectorFields, + derivedSourceReadersSupplier, + segmentReadState, + false + ); + } catch (IOException e) { + throw new RuntimeException(e); } - return storedFieldsReader; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java index 1b3c8b3b1..b01da6001 100644 ---
a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/DerivedSourceStoredFieldsWriter.java @@ -65,10 +65,6 @@ public void writeField(FieldInfo info, DataInput value, int length) throws IOExc @Override public int merge(MergeState mergeState) throws IOException { - // We have to wrap these here to avoid storing the vectors during merge - for (int i = 0; i < mergeState.storedFieldsReaders.length; i++) { - mergeState.storedFieldsReaders[i] = DerivedSourceStoredFieldsReader.wrapForMerge(mergeState.storedFieldsReaders[i]); - } return delegate.merge(mergeState); } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java index 5bdcc5181..1b3cdb3f8 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceReaders.java @@ -10,14 +10,23 @@ import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.util.IOUtils; + +import java.io.Closeable; +import java.io.IOException; /** * Class holds the readers necessary to implement derived source. 
*/ @RequiredArgsConstructor @Getter -public class DerivedSourceReaders { +public class DerivedSourceReaders implements Closeable { private final KnnVectorsReader knnVectorsReader; private final DocValuesProducer docValuesProducer; private final FieldsProducer fieldsProducer; + + @Override + public void close() throws IOException { + IOUtils.close(knnVectorsReader, docValuesProducer, fieldsProducer); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java index 7218735ec..c59bd4379 100644 --- a/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java +++ b/src/main/java/org/opensearch/knn/index/codec/derivedsource/DerivedSourceVectorInjector.java @@ -8,6 +8,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.util.IOUtils; import org.opensearch.common.collect.Tuple; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentHelper; @@ -17,6 +18,7 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; +import java.io.Closeable; import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -31,8 +33,9 @@ * format readers and information about the fields to inject vectors into the source. 
*/ @Log4j2 -public class DerivedSourceVectorInjector { +public class DerivedSourceVectorInjector implements Closeable { + private final DerivedSourceReaders derivedSourceReaders; private final List perFieldDerivedVectorInjectors; private final Set fieldNames; @@ -48,7 +51,7 @@ public DerivedSourceVectorInjector( SegmentReadState segmentReadState, List fieldsToInjectVector ) throws IOException { - DerivedSourceReaders derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); + this.derivedSourceReaders = derivedSourceReadersSupplier.getReaders(segmentReadState); this.perFieldDerivedVectorInjectors = new ArrayList<>(); this.fieldNames = new HashSet<>(); for (FieldInfo fieldInfo : fieldsToInjectVector) { @@ -67,7 +70,7 @@ public DerivedSourceVectorInjector( * @return byte array of the source with the vector fields added * @throws IOException if there is an issue reading from the formats */ - public byte[] injectVectors(Integer docId, byte[] sourceAsBytes) throws IOException { + public byte[] injectVectors(int docId, byte[] sourceAsBytes) throws IOException { // Deserialize the source into a modifiable map Tuple> mapTuple = XContentHelper.convertToMap( BytesReference.fromByteBuffer(ByteBuffer.wrap(sourceAsBytes)), @@ -121,4 +124,9 @@ public boolean shouldInject(String[] includes, String[] excludes) { } return true; } + + @Override + public void close() throws IOException { + IOUtils.close(derivedSourceReaders); + } }