diff --git a/index/scorch/optimize_knn.go b/index/scorch/optimize_knn.go
index 6b10a207c..2222f3d95 100644
--- a/index/scorch/optimize_knn.go
+++ b/index/scorch/optimize_knn.go
@@ -23,6 +23,7 @@ import (
 	"sync"
 	"sync/atomic"
 
+	"github.com/RoaringBitmap/roaring"
 	"github.com/blevesearch/bleve/v2/search"
 	index "github.com/blevesearch/bleve_index_api"
 	segment_api "github.com/blevesearch/scorch_segment_api/v2"
@@ -62,6 +63,8 @@ func (o *OptimizeVR) Finish() error {
 	var errorsM sync.Mutex
 	var errors []error
 
+	snapshotGlobalDocNums := o.snapshot.globalDocNums()
+
 	defer o.invokeSearcherEndCallback()
 
 	wg := sync.WaitGroup{}
@@ -89,9 +92,24 @@
 			vectorIndexSize := vecIndex.Size()
 			origSeg.cachedMeta.updateMeta(field, vectorIndexSize)
 			for _, vr := range vrs {
+				eligibleVectorInternalIDs := vr.EligibleDocIDs()
+				// the eligible doc IDs are global doc numbers spanning all
+				// segments; intersect with this segment's global doc nums
+				// to keep only the eligible docs within this segment
+				eligibleVectorInternalIDs.And(snapshotGlobalDocNums[index])
+
+				eligibleLocalDocNums := roaring.NewBitmap()
+				// get the (segment-)local document numbers
+				for _, docNum := range eligibleVectorInternalIDs.ToArray() {
+					localDocNum := o.snapshot.localDocNumFromGlobal(index,
+						uint64(docNum))
+					eligibleLocalDocNums.Add(uint32(localDocNum))
+				}
+
 				// for each VR, populate postings list and iterators
 				// by passing the obtained vector index and getting similar vectors.
-				pl, err := vecIndex.Search(vr.vector, vr.k, vr.searchParams)
+				pl, err := vecIndex.Search(vr.vector, vr.k,
+					eligibleLocalDocNums, vr.searchParams)
 				if err != nil {
 					errorsM.Lock()
 					errors = append(errors, err)
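
The intersection above leans on scorch's doc-number layout: segment i owns the contiguous global range [offsets[i], offsets[i]+numDocs(i)), so a global doc number is localized by subtracting the owning segment's offset. A minimal standalone sketch of that invariant, with hypothetical segment sizes (this is not scorch code, just the arithmetic it relies on):

```go
package main

import "fmt"

// Sketch of scorch's global/local doc-number invariant, assuming three
// hypothetical segment sizes. offsets[i] is the count of docs in all
// segments before segment i.
func main() {
	segmentSizes := []uint64{100, 250, 50}
	offsets := make([]uint64, len(segmentSizes))
	for i := 1; i < len(segmentSizes); i++ {
		offsets[i] = offsets[i-1] + segmentSizes[i-1]
	}

	globalDocNum := uint64(120)
	// the owning segment is the last one whose offset <= globalDocNum
	segmentIndex := 0
	for i, off := range offsets {
		if off <= globalDocNum {
			segmentIndex = i
		}
	}
	// mirrors IndexSnapshot.localDocNumFromGlobal below
	localDocNum := globalDocNum - offsets[segmentIndex]
	fmt.Println(segmentIndex, localDocNum) // 1 20
}
```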
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index f0e7ae1cf..e8fd4882c 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -471,16 +471,35 @@ func (is *IndexSnapshot) Document(id string) (rv index.Document, err error) {
 	return rvd, nil
 }
 
+func (is *IndexSnapshot) localDocNumFromGlobal(segmentIndex int, docNum uint64) uint64 {
+	return docNum - is.offsets[segmentIndex]
+}
+
 func (is *IndexSnapshot) segmentIndexAndLocalDocNumFromGlobal(docNum uint64) (int, uint64) {
 	segmentIndex := sort.Search(len(is.offsets),
 		func(x int) bool {
 			return is.offsets[x] > docNum
 		}) - 1
 
-	localDocNum := docNum - is.offsets[segmentIndex]
+	localDocNum := is.localDocNumFromGlobal(segmentIndex, docNum)
 	return int(segmentIndex), localDocNum
 }
 
+// globalDocNums returns a mapping of the segment index to the global
+// doc numbers of the live (undeleted) documents in that segment, at
+// this index snapshot.
+func (is *IndexSnapshot) globalDocNums() map[int]*roaring.Bitmap {
+	segmentIndexGlobalDocNums := make(map[int]*roaring.Bitmap)
+
+	for i := range is.segment {
+		segmentIndexGlobalDocNums[i] = roaring.NewBitmap()
+		for _, localDocNum := range is.segment[i].DocNumbersLive().ToArray() {
+			segmentIndexGlobalDocNums[i].Add(localDocNum + uint32(is.offsets[i]))
+		}
+	}
+	return segmentIndexGlobalDocNums
+}
+
 func (is *IndexSnapshot) ExternalID(id index.IndexInternalID) (string, error) {
 	docNum, err := docInternalToNumber(id)
 	if err != nil {
diff --git a/index/scorch/snapshot_index_vr.go b/index/scorch/snapshot_index_vr.go
index 30e03dcba..e6162a4c7 100644
--- a/index/scorch/snapshot_index_vr.go
+++ b/index/scorch/snapshot_index_vr.go
@@ -24,6 +24,7 @@ import (
 	"fmt"
 	"reflect"
 
+	"github.com/RoaringBitmap/roaring"
 	"github.com/blevesearch/bleve/v2/size"
 	index "github.com/blevesearch/bleve_index_api"
 	segment_api "github.com/blevesearch/scorch_segment_api/v2"
@@ -50,7 +51,18 @@ type IndexSnapshotVectorReader struct {
 	currID       index.IndexInternalID
 	ctx          context.Context
 
-	searchParams json.RawMessage
+	searchParams   json.RawMessage
+	eligibleDocIDs []index.IndexInternalID
+}
+
+func (i *IndexSnapshotVectorReader) EligibleDocIDs() *roaring.Bitmap {
+	res := roaring.NewBitmap()
+	// convert the eligible internal doc IDs (global doc numbers) to uint32
+	for _, eligibleDocInternalID := range i.eligibleDocIDs {
+		internalDocID, _ := docInternalToNumber(eligibleDocInternalID)
+		res.Add(uint32(internalDocID))
+	}
+	return res
 }
 
 func (i *IndexSnapshotVectorReader) Size() int {
@@ -108,7 +120,8 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
 	preAlloced *index.VectorDoc) (*index.VectorDoc, error) {
 
 	if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
-		i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k, i.searchParams)
+		i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k,
+			i.searchParams, i.eligibleDocIDs)
 		if err != nil {
 			return nil, err
 		}
diff --git a/index/scorch/snapshot_vector_index.go b/index/scorch/snapshot_vector_index.go
index 70546d4e3..c8985802b 100644
--- a/index/scorch/snapshot_vector_index.go
+++ b/index/scorch/snapshot_vector_index.go
@@ -26,15 +26,16 @@ import (
 )
 
 func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32,
-	field string, k int64, searchParams json.RawMessage) (
+	field string, k int64, searchParams json.RawMessage, filterIDs []index.IndexInternalID) (
 	index.VectorReader, error) {
 
 	rv := &IndexSnapshotVectorReader{
-		vector:       vector,
-		field:        field,
-		k:            k,
-		snapshot:     is,
-		searchParams: searchParams,
+		vector:         vector,
+		field:          field,
+		k:              k,
+		snapshot:       is,
+		searchParams:   searchParams,
+		eligibleDocIDs: filterIDs,
 	}
 
 	if rv.postings == nil {
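
Taken together, globalDocNums plus the per-segment intersection in OptimizeVR.Finish reduces to plain bitmap arithmetic over the same roaring calls the patch uses. A self-contained sketch with hypothetical doc numbers (one segment at offset 100 holding local docs [0, 250)):

```go
package main

import (
	"fmt"

	"github.com/RoaringBitmap/roaring"
)

// Sketch of the per-segment filtering done in OptimizeVR.Finish;
// all doc numbers here are hypothetical, not from a real index.
func main() {
	// eligible docs, expressed as global doc numbers
	eligible := roaring.BitmapOf(5, 120, 130, 310)

	// the segment's live docs as global doc numbers (cf. globalDocNums)
	segGlobal := roaring.NewBitmap()
	segGlobal.AddRange(100, 350) // [100, 350)

	// keep only the eligible docs that live in this segment
	eligible.And(segGlobal)

	// localize: subtract the segment offset (cf. localDocNumFromGlobal)
	local := roaring.NewBitmap()
	for _, g := range eligible.ToArray() {
		local.Add(g - 100)
	}
	fmt.Println(local.ToArray()) // [20 30 210]
}
```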
diff --git a/search/collector/eligible.go b/search/collector/eligible.go
new file mode 100644
index 000000000..eecb4f88f
--- /dev/null
+++ b/search/collector/eligible.go
@@ -0,0 +1,152 @@
+package collector
+
+import (
+	"context"
+	"time"
+
+	"github.com/blevesearch/bleve/v2/search"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+type EligibleCollector struct {
+	size     int
+	total    uint64
+	maxScore float64
+	took     time.Duration
+	results  search.DocumentMatchCollection
+
+	store collectorStore
+
+	needDocIds   bool
+	neededFields []string
+	cachedDesc   []bool
+
+	lowestMatchOutsideResults *search.DocumentMatch
+	updateFieldVisitor        index.DocValueVisitor
+	dvReader                  index.DocValueReader
+	searchAfter               *search.DocumentMatch
+}
+
+func NewEligibleCollector(size int) *EligibleCollector {
+	return newEligibleCollector(size)
+}
+
+func newEligibleCollector(size int) *EligibleCollector {
+	// No sort order & skip always 0 since this is only to filter eligible docs.
+	hc := &EligibleCollector{size: size}
+
+	// the comparator is a dummy here, since eligible docs are not ranked
+	hc.store = getOptimalCollectorStore(size, 0, func(i, j *search.DocumentMatch) int {
+		return 0
+	})
+
+	return hc
+}
+
+func (hc *EligibleCollector) Collect(ctx context.Context, searcher search.Searcher, reader index.IndexReader) error {
+	startTime := time.Now()
+	var err error
+	var next *search.DocumentMatch
+
+	backingSize := hc.size
+	if backingSize > PreAllocSizeSkipCap {
+		backingSize = PreAllocSizeSkipCap + 1
+	}
+	searchContext := &search.SearchContext{
+		DocumentMatchPool: search.NewDocumentMatchPool(backingSize+searcher.DocumentMatchPoolSize(), 0),
+		Collector:         hc,
+		IndexReader:       reader,
+	}
+
+	dmHandlerMaker := MakeEligibleDocumentMatchHandler
+	if cv := ctx.Value(search.MakeDocumentMatchHandlerKey); cv != nil {
+		dmHandlerMaker = cv.(search.MakeDocumentMatchHandler)
+	}
+	// use the application given builder for making the custom document match
+	// handler and perform callbacks/invocations on the newly made handler.
+	dmHandler, _, err := dmHandlerMaker(searchContext)
+	if err != nil {
+		return err
+	}
+	select {
+	case <-ctx.Done():
+		search.RecordSearchCost(ctx, search.AbortM, 0)
+		return ctx.Err()
+	default:
+		next, err = searcher.Next(searchContext)
+	}
+	for err == nil && next != nil {
+		if hc.total%CheckDoneEvery == 0 {
+			select {
+			case <-ctx.Done():
+				search.RecordSearchCost(ctx, search.AbortM, 0)
+				return ctx.Err()
+			default:
+			}
+		}
+		hc.total++
+
+		err = dmHandler(next)
+		if err != nil {
+			break
+		}
+
+		next, err = searcher.Next(searchContext)
+	}
+	if err != nil {
+		return err
+	}
+
+	// help finalize/flush the results in case
+	// of custom document match handlers.
+	err = dmHandler(nil)
+	if err != nil {
+		return err
+	}
+
+	// compute search duration
+	hc.took = time.Since(startTime)
+
+	// finalize actual results
+	err = hc.finalizeResults(reader)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (hc *EligibleCollector) finalizeResults(r index.IndexReader) error {
+	var err error
+	hc.results, err = hc.store.Final(0, func(doc *search.DocumentMatch) error {
+		// take the hits out of the store as-is: only the internal doc
+		// numbers of the filtered hits are needed, not their external IDs.
+		return nil
+	})
+	return err
+}
+
+func (hc *EligibleCollector) Results() search.DocumentMatchCollection {
+	return hc.results
+}
+
+func (hc *EligibleCollector) Total() uint64 {
+	return hc.total
+}
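
A rough sketch of how this collector is meant to be driven, mirroring the runKnnCollector changes later in this patch; the reader, filter query, mapping, and doc count are assumed to come from the caller, and the package/function names here are illustrative only:

```go
package eligiblefilter

import (
	"context"

	"github.com/blevesearch/bleve/v2/mapping"
	"github.com/blevesearch/bleve/v2/search"
	"github.com/blevesearch/bleve/v2/search/collector"
	"github.com/blevesearch/bleve/v2/search/query"
	index "github.com/blevesearch/bleve_index_api"
)

// eligibleIDs runs the filter query with scoring disabled, collects
// every hit, and keeps only the internal doc IDs of the eligible docs.
func eligibleIDs(ctx context.Context, reader index.IndexReader, filterQ query.Query,
	m mapping.IndexMapping, docCount uint64) ([]index.IndexInternalID, error) {
	searcher, err := filterQ.Searcher(ctx, reader, m,
		search.SearcherOptions{Score: "none"}) // scores are irrelevant here
	if err != nil {
		return nil, err
	}
	coll := collector.NewEligibleCollector(int(docCount)) // docCount is an upper bound
	if err := coll.Collect(ctx, searcher, reader); err != nil {
		return nil, err
	}
	ids := make([]index.IndexInternalID, 0, len(coll.Results()))
	for _, hit := range coll.Results() {
		ids = append(ids, hit.IndexInternalID)
	}
	return ids, nil
}
```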
+
+// There is no concept of scoring in the eligible collector.
+func (hc *EligibleCollector) MaxScore() float64 {
+	return 0
+}
+
+func (hc *EligibleCollector) Took() time.Duration {
+	return hc.took
+}
+
+func (hc *EligibleCollector) SetFacetsBuilder(facetsBuilder *search.FacetsBuilder) {
+	// facets are unsupported for pre-filtering in KNN search
+}
+
+func (hc *EligibleCollector) FacetResults() search.FacetResults {
+	// facets are unsupported for pre-filtering in KNN search
+	return nil
+}
diff --git a/search/collector/heap.go b/search/collector/heap.go
index cd662bcf9..ab068b084 100644
--- a/search/collector/heap.go
+++ b/search/collector/heap.go
@@ -34,6 +34,11 @@ func newStoreHeap(capacity int, compare collectorCompare) *collectStoreHeap {
 	return rv
 }
 
+func (c *collectStoreHeap) Add(doc *search.DocumentMatch) *search.DocumentMatch {
+	c.add(doc)
+	return nil
+}
+
 func (c *collectStoreHeap) AddNotExceedingSize(doc *search.DocumentMatch,
 	size int) *search.DocumentMatch {
 	c.add(doc)
diff --git a/search/collector/list.go b/search/collector/list.go
index f73505e7d..b8b645199 100644
--- a/search/collector/list.go
+++ b/search/collector/list.go
@@ -34,6 +34,11 @@ func newStoreList(capacity int, compare collectorCompare) *collectStoreList {
 	return rv
 }
 
+func (c *collectStoreList) Add(doc *search.DocumentMatch) *search.DocumentMatch {
+	c.results.PushBack(doc)
+	return nil
+}
+
 func (c *collectStoreList) AddNotExceedingSize(doc *search.DocumentMatch, size int) *search.DocumentMatch {
 	c.add(doc)
 	if c.len() > size {
diff --git a/search/collector/slice.go b/search/collector/slice.go
index 07534e693..03b212b0f 100644
--- a/search/collector/slice.go
+++ b/search/collector/slice.go
@@ -29,6 +29,11 @@ func newStoreSlice(capacity int, compare collectorCompare) *collectStoreSlice {
 	return rv
 }
 
+func (c *collectStoreSlice) Add(doc *search.DocumentMatch) *search.DocumentMatch {
+	c.slice = append(c.slice, doc)
+	return nil
+}
+
 func (c *collectStoreSlice) AddNotExceedingSize(doc *search.DocumentMatch,
 	size int) *search.DocumentMatch {
 	c.add(doc)
diff --git a/search/collector/topn.go b/search/collector/topn.go
index fc338f54e..5de473785 100644
--- a/search/collector/topn.go
+++ b/search/collector/topn.go
@@ -33,6 +33,10 @@ func init() {
 }
 
 type collectorStore interface {
+	// Add adds a doc to the store without considering size, and always
+	// returns nil since no element is ever displaced from the store.
+	Add(doc *search.DocumentMatch) *search.DocumentMatch
+
 	// Add the document, and if the new store size exceeds the provided size
 	// the last element is removed and returned. If the size has not been
 	// exceeded, nil is returned.
@@ -382,6 +386,27 @@ func (hc *TopNCollector) prepareDocumentMatch(ctx *search.SearchContext,
 	return nil
 }
 
+// Unlike MakeTopNDocumentMatchHandler, this handler does not eliminate docs based on score.
+func MakeEligibleDocumentMatchHandler(
+	ctx *search.SearchContext) (search.DocumentMatchHandler, bool, error) {
+
+	var hc *EligibleCollector
+	var ok bool
+
+	if hc, ok = ctx.Collector.(*EligibleCollector); ok {
+		return func(d *search.DocumentMatch) error {
+			if d == nil {
+				return nil
+			}
+
+			// no elements are ever removed from the store here
+			_ = hc.store.Add(d)
+			return nil
+		}, false, nil
+	}
+	return nil, false, nil
+}
+
 func MakeTopNDocumentMatchHandler(
 	ctx *search.SearchContext) (search.DocumentMatchHandler, bool, error) {
 	var hc *TopNCollector
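
The new Add sits alongside AddNotExceedingSize because the eligible pass keeps every matching doc: with all scores zero there is no ranking to evict by. A toy, self-contained illustration of the two contracts (not bleve code; the eviction policy here is invented purely for brevity):

```go
package main

import "fmt"

// toyStore illustrates the collectorStore contracts: Add appends
// unconditionally and returns nil, while AddNotExceedingSize may
// displace an element once the size cap is exceeded.
type toyStore struct{ docs []int }

func (s *toyStore) Add(d int) *int {
	s.docs = append(s.docs, d)
	return nil // nothing is ever displaced
}

func (s *toyStore) AddNotExceedingSize(d, size int) *int {
	s.docs = append(s.docs, d)
	if len(s.docs) > size {
		evicted := s.docs[len(s.docs)-1] // toy policy: drop the newest
		s.docs = s.docs[:len(s.docs)-1]
		return &evicted
	}
	return nil
}

func main() {
	s := &toyStore{}
	fmt.Println(s.AddNotExceedingSize(1, 1) == nil) // true
	fmt.Println(s.AddNotExceedingSize(2, 1) == nil) // false: a doc was displaced
	fmt.Println(s.Add(3) == nil)                    // true: Add never displaces
}
```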
diff --git a/search/query/knn.go b/search/query/knn.go
index 46eccb2a5..1493d4d2d 100644
--- a/search/query/knn.go
+++ b/search/query/knn.go
@@ -35,7 +35,9 @@ type KNNQuery struct {
 	BoostVal *Boost `json:"boost,omitempty"`
 
 	// see KNNRequest.Params for description
-	Params json.RawMessage `json:"params"`
+	Params        json.RawMessage `json:"params"`
+	FilterQuery   Query           `json:"filter,omitempty"`
+	FilterResults []index.IndexInternalID
 }
 
 func NewKNNQuery(vector []float32) *KNNQuery {
@@ -67,6 +69,10 @@ func (q *KNNQuery) SetParams(params json.RawMessage) {
 	q.Params = params
 }
 
+func (q *KNNQuery) SetFilterQuery(f Query) {
+	q.FilterQuery = f
+}
+
 func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
 	m mapping.IndexMapping, options search.SearcherOptions) (search.Searcher, error) {
 	fieldMapping := m.FieldMappingForPath(q.VectorField)
@@ -81,6 +87,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
 		// normalize the vector
 		q.Vector = mapping.NormalizeVector(q.Vector)
 	}
+
 	return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField,
-		q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params)
+		q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params, q.FilterResults)
 }
diff --git a/search/searcher/search_knn.go b/search/searcher/search_knn.go
index e17bb7a0f..f4b72f449 100644
--- a/search/searcher/search_knn.go
+++ b/search/searcher/search_knn.go
@@ -49,11 +49,13 @@ type KNNSearcher struct {
 
 func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMapping,
 	options search.SearcherOptions, field string, vector []float32, k int64,
-	boost float64, similarityMetric string, searchParams json.RawMessage) (
+	boost float64, similarityMetric string, searchParams json.RawMessage,
+	filterIDs []index.IndexInternalID) (
 	search.Searcher, error) {
 
 	if vr, ok := i.(index.VectorIndexReader); ok {
-		vectorReader, err := vr.VectorReader(ctx, vector, field, k, searchParams)
+		vectorReader, err := vr.VectorReader(ctx, vector, field, k, searchParams,
+			filterIDs)
 		if err != nil {
 			return nil, err
 		}
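
For callers constructing queries programmatically, the new setter wires the filter in, while FilterResults is deliberately left for the library to fill from the pre-filter pass in createKNNQuery. A short sketch with a hypothetical field name, vector, and term (requires a vectors-enabled build of bleve):

```go
package example

import "github.com/blevesearch/bleve/v2/search/query"

// buildFilteredKNN sketches the programmatic path: attach the filter
// via SetFilterQuery; do not set FilterResults yourself.
func buildFilteredKNN() *query.KNNQuery {
	knnQ := query.NewKNNQuery([]float32{0.1, 0.2, 0.3})
	knnQ.SetFieldVal("vec") // hypothetical vector field
	knnQ.SetK(10)
	knnQ.SetBoost(1.0)

	filter := query.NewTermQuery("available")
	filter.SetField("in_stock") // hypothetical text field
	knnQ.SetFilterQuery(filter)

	return knnQ
}
```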
`json:"boost,omitempty"` + FilterQuery json.RawMessage `JSON:"filter,omitempty"` + } + var temp struct { Q json.RawMessage `json:"query"` Size *int `json:"size"` @@ -119,7 +131,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { Score string `json:"score"` SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` - KNN []*KNNRequest `json:"knn"` + KNN []*tempKNNReq `json:"knn"` KNNOperator knnOperator `json:"knn_operator"` PreSearchData json.RawMessage `json:"pre_search_data"` } @@ -163,7 +175,14 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { r.From = 0 } - r.KNN = temp.KNN + for i, knnReq := range r.KNN { + knnReq.Field = temp.KNN[i].Field + knnReq.Vector = temp.KNN[i].Vector + knnReq.VectorBase64 = temp.KNN[i].VectorBase64 + knnReq.K = temp.KNN[i].K + knnReq.Boost = temp.KNN[i].Boost + knnReq.FilterQuery, err = query.ParseQuery(temp.KNN[i].FilterQuery) + } r.KNNOperator = temp.KNNOperator if r.KNNOperator == "" { r.KNNOperator = knnOperatorOr @@ -209,7 +228,8 @@ var ( knnOperatorOr = knnOperator("or") ) -func createKNNQuery(req *SearchRequest) (query.Query, []int64, int64, error) { +func createKNNQuery(req *SearchRequest, eligibleDocsMap map[int][]index.IndexInternalID) ( + query.Query, []int64, int64, error) { if requestHasKNN(req) { // first perform validation err := validateKNN(req) @@ -219,12 +239,16 @@ func createKNNQuery(req *SearchRequest) (query.Query, []int64, int64, error) { var subQueries []query.Query kArray := make([]int64, 0, len(req.KNN)) sumOfK := int64(0) - for _, knn := range req.KNN { + for i, knn := range req.KNN { knnQuery := query.NewKNNQuery(knn.Vector) knnQuery.SetFieldVal(knn.Field) knnQuery.SetK(knn.K) knnQuery.SetBoost(knn.Boost.Value()) knnQuery.SetParams(knn.Params) + knnQuery.SetFilterQuery(knn.FilterQuery) + if filterResults, exists := eligibleDocsMap[i]; exists { + knnQuery.FilterResults = filterResults + } subQueries = append(subQueries, knnQuery) kArray = append(kArray, knn.K) sumOfK += knn.K @@ -303,7 +327,46 @@ func addSortAndFieldsToKNNHits(req *SearchRequest, knnHits []*search.DocumentMat } func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, reader index.IndexReader, preSearch bool) ([]*search.DocumentMatch, error) { - KNNQuery, kArray, sumOfK, err := createKNNQuery(req) + // maps the index of the KNN query in the req to the pre-filter hits aka + // eligible docs' internal IDs . + filterHitsMap := make(map[int][]index.IndexInternalID) + + for idx, knnReq := range req.KNN { + // TODO Can use goroutines for this filter query stuff - do it if perf results + // show this to be significantly slow otherwise. + filterQ := knnReq.FilterQuery + + if _, ok := filterQ.(*query.MatchNoneQuery); ok { + // a match none query just means none the documents are eligible + // hence, we can save on running the query. + continue + } + // TODO See if running MatchAll queries can be skipped too + + // Applies to all supported types of queries. + filterSearcher, _ := filterQ.Searcher(ctx, reader, i.m, search.SearcherOptions{ + Score: "none", // just want eligible hits --> don't compute scores if not needed + }) + // Using the index doc count to determine collector size since we do not + // have an estimate of the number of eligible docs in the index yet. 
@@ -303,7 +327,53 @@ func addSortAndFieldsToKNNHits(req *SearchRequest, knnHits []*search.DocumentMat
 }
 
 func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, reader index.IndexReader, preSearch bool) ([]*search.DocumentMatch, error) {
+	// maps the index of the KNN query in the req to the pre-filter hits,
+	// i.e. the eligible docs' internal IDs.
+	filterHitsMap := make(map[int][]index.IndexInternalID)
+
+	for idx, knnReq := range req.KNN {
+		// TODO: the filter queries can be run concurrently with goroutines;
+		// revisit if profiling shows this sequential loop to be a bottleneck.
+		filterQ := knnReq.FilterQuery
+		if filterQ == nil {
+			// no filter query given, so every doc is eligible
+			continue
+		}
+
+		if _, ok := filterQ.(*query.MatchNoneQuery); ok {
+			// a match_none filter query means none of the documents are
+			// eligible, so running the query can be skipped entirely.
+			continue
+		}
+		// TODO: see if running MatchAll queries can be skipped too.
+
+		// Applies to all supported types of queries.
+		filterSearcher, err := filterQ.Searcher(ctx, reader, i.m, search.SearcherOptions{
+			Score: "none", // only eligible hits are needed, so skip scoring
+		})
+		if err != nil {
+			return nil, err
+		}
+		// Use the index doc count to size the collector, since there is no
+		// estimate of the number of eligible docs in the index yet.
+		indexDocCount, err := i.DocCount()
+		if err != nil {
+			return nil, err
+		}
+		filterColl := collector.NewEligibleCollector(int(indexDocCount))
+		err = filterColl.Collect(ctx, filterSearcher, reader)
+		if err != nil {
+			return nil, err
+		}
+		filterHits := filterColl.Results()
+		filterHitsMap[idx] = make([]index.IndexInternalID, 0, len(filterHits))
+		for _, docMatch := range filterHits {
+			filterHitsMap[idx] = append(filterHitsMap[idx], docMatch.IndexInternalID)
+		}
+	}
+
+	// Pass the pre-filter hits along when creating the kNN query.
+	KNNQuery, kArray, sumOfK, err := createKNNQuery(req, filterHitsMap)
 	if err != nil {
 		return nil, err
 	}
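
End to end, the feature is driven through the extended AddKNN signature: the filter query is resolved into eligible internal doc IDs before the kNN search runs, so similar vectors are drawn only from matching docs. A hedged usage sketch, assuming a vectors-enabled build and an existing index with hypothetical fields "vec" and "in_stock":

```go
package main

import (
	"fmt"

	"github.com/blevesearch/bleve/v2"
	"github.com/blevesearch/bleve/v2/search/query"
)

func main() {
	idx, err := bleve.Open("example.bleve") // hypothetical index path
	if err != nil {
		panic(err)
	}
	defer idx.Close()

	filter := query.NewTermQuery("available")
	filter.SetField("in_stock")

	req := bleve.NewSearchRequest(query.NewMatchNoneQuery())
	// kNN hits are now drawn only from docs matching the filter query.
	req.AddKNN("vec", []float32{0.1, 0.2, 0.3}, 10, 1.0, filter)

	res, err := idx.Search(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(res.Hits)
}
```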