MB-62230 - Pre-Filtering Support for kNN
metonymic-smokey committed Aug 12, 2024
1 parent f7fea09 commit c332538
Showing 12 changed files with 340 additions and 25 deletions.
20 changes: 19 additions & 1 deletion index/scorch/optimize_knn.go
@@ -23,6 +23,7 @@ import (
"sync"
"sync/atomic"

"github.com/RoaringBitmap/roaring"
"github.com/blevesearch/bleve/v2/search"
index "github.com/blevesearch/bleve_index_api"
segment_api "github.com/blevesearch/scorch_segment_api/v2"
@@ -62,6 +63,8 @@ func (o *OptimizeVR) Finish() error {
var errorsM sync.Mutex
var errors []error

snapshotGlobalDocNums := o.snapshot.globalDocNums()

defer o.invokeSearcherEndCallback()

wg := sync.WaitGroup{}
@@ -89,9 +92,24 @@ func (o *OptimizeVR) Finish() error {
vectorIndexSize := vecIndex.Size()
origSeg.cachedMeta.updateMeta(field, vectorIndexSize)
for _, vr := range vrs {
eligibleVectorInternalIDs := vr.EligibleDocIDs()
// The eligible doc IDs are global doc numbers spanning all segments;
// intersect with this segment's global doc numbers to retain only
// the eligible documents that belong to this segment.
eligibleVectorInternalIDs.And(snapshotGlobalDocNums[index])

eligibleLocalDocNums := roaring.NewBitmap()
// Convert the global doc numbers to segment-local doc numbers.
for _, docNum := range eligibleVectorInternalIDs.ToArray() {
localDocNum := o.snapshot.localDocNumFromGlobal(index,
uint64(docNum))
eligibleLocalDocNums.Add(uint32(localDocNum))
}

// For each vector reader, populate the postings list and iterators
// by searching the obtained vector index for similar vectors.
pl, err := vecIndex.Search(vr.vector, vr.k, vr.searchParams)
pl, err := vecIndex.Search(vr.vector, vr.k,
eligibleLocalDocNums, vr.searchParams)
if err != nil {
errorsM.Lock()
errors = append(errors, err)
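
A minimal standalone sketch of the intersect-then-localize step above, since the eligible doc IDs produced by the filter query are global doc numbers while each segment's vector index only understands segment-local ones. The segment boundaries and offset below are made up; only the roaring calls mirror the commit:

package main

import (
    "fmt"

    "github.com/RoaringBitmap/roaring"
)

func main() {
    // Assume this segment holds the global doc numbers [100, 200).
    segGlobalDocNums := roaring.New()
    segGlobalDocNums.AddRange(100, 200)

    // Eligible docs from the filter query, as global doc numbers.
    eligible := roaring.BitmapOf(5, 120, 150, 500)

    // Keep only the eligible docs that live in this segment.
    eligible.And(segGlobalDocNums)

    // Translate global doc numbers to segment-local ones by
    // subtracting the segment's starting offset.
    offset := uint32(100)
    local := roaring.New()
    for _, g := range eligible.ToArray() {
        local.Add(g - offset)
    }
    fmt.Println(local.ToArray()) // prints [20 50]
}
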
21 changes: 20 additions & 1 deletion index/scorch/snapshot_index.go
@@ -471,16 +471,35 @@ func (is *IndexSnapshot) Document(id string) (rv index.Document, err error) {
return rvd, nil
}

func (is *IndexSnapshot) localDocNumFromGlobal(segmentIndex int, docNum uint64) uint64 {
return docNum - is.offsets[segmentIndex]
}

func (is *IndexSnapshot) segmentIndexAndLocalDocNumFromGlobal(docNum uint64) (int, uint64) {
segmentIndex := sort.Search(len(is.offsets),
func(x int) bool {
return is.offsets[x] > docNum
}) - 1

localDocNum := docNum - is.offsets[segmentIndex]
localDocNum := is.localDocNumFromGlobal(segmentIndex, docNum)
return int(segmentIndex), localDocNum
}

// globalDocNums returns a mapping from the segment index to the global
// doc numbers of the live documents in that segment of the snapshot.
func (is *IndexSnapshot) globalDocNums() map[int]*roaring.Bitmap {
segmentIndexGlobalDocNums := make(map[int]*roaring.Bitmap)

for i := range is.segment {
segmentIndexGlobalDocNums[i] = roaring.NewBitmap()
for _, localDocNum := range is.segment[i].DocNumbersLive().ToArray() {
segmentIndexGlobalDocNums[i].Add(localDocNum + uint32(is.offsets[i]))
}
}
return segmentIndexGlobalDocNums
}

func (is *IndexSnapshot) ExternalID(id index.IndexInternalID) (string, error) {
docNum, err := docInternalToNumber(id)
if err != nil {
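
The snapshot helpers above recover a (segment, local doc number) pair from a global doc number by binary-searching the snapshot's cumulative offsets. A self-contained sketch of that lookup, with made-up offsets for three segments of 100, 50, and 25 documents:

package main

import (
    "fmt"
    "sort"
)

// offsets[i] is the global doc number at which segment i starts.
var offsets = []uint64{0, 100, 150}

func segmentIndexAndLocalDocNum(docNum uint64) (int, uint64) {
    // The first offset strictly greater than docNum, minus one,
    // indexes the segment containing docNum.
    segmentIndex := sort.Search(len(offsets), func(x int) bool {
        return offsets[x] > docNum
    }) - 1
    return segmentIndex, docNum - offsets[segmentIndex]
}

func main() {
    seg, local := segmentIndexAndLocalDocNum(120)
    fmt.Println(seg, local) // prints: 1 20
}
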
17 changes: 15 additions & 2 deletions index/scorch/snapshot_index_vr.go
@@ -24,6 +24,7 @@ import (
"fmt"
"reflect"

"github.com/RoaringBitmap/roaring"
"github.com/blevesearch/bleve/v2/size"
index "github.com/blevesearch/bleve_index_api"
segment_api "github.com/blevesearch/scorch_segment_api/v2"
@@ -50,7 +51,18 @@ type IndexSnapshotVectorReader struct {
currID index.IndexInternalID
ctx context.Context

searchParams json.RawMessage
searchParams json.RawMessage
eligibleDocIDs []index.IndexInternalID
}

func (i *IndexSnapshotVectorReader) EligibleDocIDs() *roaring.Bitmap {
res := roaring.NewBitmap()
// Convert each internal doc ID to its global doc number; conversion
// errors are ignored here.
for _, eligibleDocInternalID := range i.eligibleDocIDs {
internalDocID, _ := docInternalToNumber(eligibleDocInternalID)
res.Add(uint32(internalDocID))
}
return res
}

func (i *IndexSnapshotVectorReader) Size() int {
@@ -108,7 +120,8 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
preAlloced *index.VectorDoc) (*index.VectorDoc, error) {

if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k, i.searchParams)
i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k,
i.searchParams, i.eligibleDocIDs)
if err != nil {
return nil, err
}
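
EligibleDocIDs above folds the reader's internal doc IDs into a roaring bitmap via docInternalToNumber. A rough standalone illustration of that conversion, assuming scorch's 8-byte big-endian uint64 encoding of internal IDs; toBitmap is a hypothetical helper, and malformed IDs are simply skipped here:

package main

import (
    "encoding/binary"
    "fmt"

    "github.com/RoaringBitmap/roaring"
)

// toBitmap collects global doc numbers from raw internal doc IDs.
func toBitmap(ids [][]byte) *roaring.Bitmap {
    res := roaring.NewBitmap()
    for _, id := range ids {
        if len(id) != 8 {
            continue // skip malformed IDs in this sketch
        }
        res.Add(uint32(binary.BigEndian.Uint64(id)))
    }
    return res
}

func main() {
    id := make([]byte, 8)
    binary.BigEndian.PutUint64(id, 42)
    fmt.Println(toBitmap([][]byte{id}).ToArray()) // prints [42]
}
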
13 changes: 7 additions & 6 deletions index/scorch/snapshot_vector_index.go
@@ -26,15 +26,16 @@ import (
)

func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32,
field string, k int64, searchParams json.RawMessage) (
field string, k int64, searchParams json.RawMessage, filterIDs []index.IndexInternalID) (
index.VectorReader, error) {

rv := &IndexSnapshotVectorReader{
vector: vector,
field: field,
k: k,
snapshot: is,
searchParams: searchParams,
vector: vector,
field: field,
k: k,
snapshot: is,
searchParams: searchParams,
eligibleDocIDs: filterIDs,
}

if rv.postings == nil {
152 changes: 152 additions & 0 deletions search/collector/eligible.go
@@ -0,0 +1,152 @@
package collector

import (
"context"
"time"

"github.com/blevesearch/bleve/v2/search"
index "github.com/blevesearch/bleve_index_api"
)

type EligibleCollector struct {
size int
total uint64
maxScore float64
took time.Duration
results search.DocumentMatchCollection

store collectorStore

needDocIds bool
neededFields []string
cachedDesc []bool

lowestMatchOutsideResults *search.DocumentMatch
updateFieldVisitor index.DocValueVisitor
dvReader index.DocValueReader
searchAfter *search.DocumentMatch
}

func NewEligibleCollector(size int) *EligibleCollector {
return newEligibleCollector(size)
}

func newEligibleCollector(size int) *EligibleCollector {
// No sort order & skip always 0 since this is only to filter eligible docs.
hc := &EligibleCollector{size: size}

// comparator is a dummy here
hc.store = getOptimalCollectorStore(size, 0, func(i, j *search.DocumentMatch) int {
return 0
})

return hc
}

func (hc *EligibleCollector) Collect(ctx context.Context, searcher search.Searcher, reader index.IndexReader) error {
startTime := time.Now()
var err error
var next *search.DocumentMatch

backingSize := hc.size
if backingSize > PreAllocSizeSkipCap {
backingSize = PreAllocSizeSkipCap + 1
}
searchContext := &search.SearchContext{
DocumentMatchPool: search.NewDocumentMatchPool(backingSize+searcher.DocumentMatchPoolSize(), 0),
Collector: hc,
IndexReader: reader,
}

dmHandlerMaker := MakeEligibleDocumentMatchHandler
if cv := ctx.Value(search.MakeDocumentMatchHandlerKey); cv != nil {
dmHandlerMaker = cv.(search.MakeDocumentMatchHandler)
}
// use the application given builder for making the custom document match
// handler and perform callbacks/invocations on the newly made handler.
dmHandler, _, err := dmHandlerMaker(searchContext)
if err != nil {
return err
}
select {
case <-ctx.Done():
search.RecordSearchCost(ctx, search.AbortM, 0)
return ctx.Err()
default:
next, err = searcher.Next(searchContext)
}
for err == nil && next != nil {
if hc.total%CheckDoneEvery == 0 {
select {
case <-ctx.Done():
search.RecordSearchCost(ctx, search.AbortM, 0)
return ctx.Err()
default:
}
}
hc.total++

err = dmHandler(next)
if err != nil {
break
}

next, err = searcher.Next(searchContext)
}
if err != nil {
return err
}

// help finalize/flush the results in case
// of custom document match handlers.
err = dmHandler(nil)
if err != nil {
return err
}

// compute search duration
hc.took = time.Since(startTime)

// finalize actual results
err = hc.finalizeResults(reader)
if err != nil {
return err
}
return nil
}

func (hc *EligibleCollector) finalizeResults(r index.IndexReader) error {
var err error
hc.results, err = hc.store.Final(0, func(doc *search.DocumentMatch) error {
// Keep the hits as-is; the external IDs of the filtered hits
// are not needed here.
return nil
})
return err
}

func (hc *EligibleCollector) Results() search.DocumentMatchCollection {
return hc.results
}

func (hc *EligibleCollector) Total() uint64 {
return hc.total
}

// No concept of scoring in the eligible collector.
func (hc *EligibleCollector) MaxScore() float64 {
return 0
}

func (hc *EligibleCollector) Took() time.Duration {
return hc.took
}

func (hc *EligibleCollector) SetFacetsBuilder(facetsBuilder *search.FacetsBuilder) {
// facet unsupported for pre-filtering in KNN search
}

func (hc *EligibleCollector) FacetResults() search.FacetResults {
// facet unsupported for pre-filtering in KNN search
return nil
}
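
The collector above is driven like any other bleve collector: run the filter query's searcher through it, then read the matching internal IDs back out of Results(). A sketch of that wiring under the assumption that the filter searcher and index reader already exist; collectEligible is a hypothetical helper, not part of the commit:

package prefilter

import (
    "context"

    "github.com/blevesearch/bleve/v2/search"
    "github.com/blevesearch/bleve/v2/search/collector"
    index "github.com/blevesearch/bleve_index_api"
)

// collectEligible runs the filter searcher through an EligibleCollector
// and returns the internal IDs of every matching document.
func collectEligible(ctx context.Context, filterSearcher search.Searcher,
    reader index.IndexReader, size int) ([]index.IndexInternalID, error) {

    coll := collector.NewEligibleCollector(size)
    if err := coll.Collect(ctx, filterSearcher, reader); err != nil {
        return nil, err
    }
    ids := make([]index.IndexInternalID, 0, len(coll.Results()))
    for _, hit := range coll.Results() {
        ids = append(ids, hit.IndexInternalID)
    }
    return ids, nil
}
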
5 changes: 5 additions & 0 deletions search/collector/heap.go
@@ -34,6 +34,11 @@ func newStoreHeap(capacity int, compare collectorCompare) *collectStoreHeap {
return rv
}

func (c *collectStoreHeap) Add(doc *search.DocumentMatch) *search.DocumentMatch {
c.add(doc)
return nil
}

func (c *collectStoreHeap) AddNotExceedingSize(doc *search.DocumentMatch,
size int) *search.DocumentMatch {
c.add(doc)
5 changes: 5 additions & 0 deletions search/collector/list.go
@@ -34,6 +34,11 @@ func newStoreList(capacity int, compare collectorCompare) *collectStoreList {
return rv
}

func (c *collectStoreList) Add(doc *search.DocumentMatch) *search.DocumentMatch {
c.results.PushBack(doc)
return nil
}

func (c *collectStoreList) AddNotExceedingSize(doc *search.DocumentMatch, size int) *search.DocumentMatch {
c.add(doc)
if c.len() > size {
5 changes: 5 additions & 0 deletions search/collector/slice.go
@@ -29,6 +29,11 @@ func newStoreSlice(capacity int, compare collectorCompare) *collectStoreSlice {
return rv
}

func (c *collectStoreSlice) Add(doc *search.DocumentMatch) *search.DocumentMatch {
c.slice = append(c.slice, doc)
return nil
}

func (c *collectStoreSlice) AddNotExceedingSize(doc *search.DocumentMatch,
size int) *search.DocumentMatch {
c.add(doc)
25 changes: 25 additions & 0 deletions search/collector/topn.go
@@ -33,6 +33,10 @@ func init() {
}

type collectorStore interface {
// Add adds a doc to the store without enforcing a size limit.
// It always returns nil, since no document is ever evicted.
Add(doc *search.DocumentMatch) *search.DocumentMatch

// Add the document, and if the new store size exceeds the provided size
// the last element is removed and returned. If the size has not been
// exceeded, nil is returned.
@@ -382,6 +386,27 @@ func (hc *TopNCollector) prepareDocumentMatch(ctx *search.SearchContext,
return nil
}

// Unlike the handler made by MakeTopNDocumentMatchHandler, this handler
// does not eliminate documents based on score.
func MakeEligibleDocumentMatchHandler(
ctx *search.SearchContext) (search.DocumentMatchHandler, bool, error) {

var hc *EligibleCollector
var ok bool

if hc, ok = ctx.Collector.(*EligibleCollector); ok {
return func(d *search.DocumentMatch) error {
if d == nil {
return nil
}

// No elements removed from the store here.
_ = hc.store.Add(d)
return nil
}, false, nil
}
return nil, false, nil
}

func MakeTopNDocumentMatchHandler(
ctx *search.SearchContext) (search.DocumentMatchHandler, bool, error) {
var hc *TopNCollector
11 changes: 9 additions & 2 deletions search/query/knn.go
@@ -35,7 +35,9 @@ type KNNQuery struct {
BoostVal *Boost `json:"boost,omitempty"`

// see KNNRequest.Params for description
Params json.RawMessage `json:"params"`
Params json.RawMessage `json:"params"`
FilterQuery Query `json:"filter,omitempty"`
FilterResults []index.IndexInternalID
}

func NewKNNQuery(vector []float32) *KNNQuery {
@@ -67,6 +69,10 @@ func (q *KNNQuery) SetParams(params json.RawMessage) {
q.Params = params
}

func (q *KNNQuery) SetFilterQuery(f Query) {
q.FilterQuery = f
}

func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
m mapping.IndexMapping, options search.SearcherOptions) (search.Searcher, error) {
fieldMapping := m.FieldMappingForPath(q.VectorField)
@@ -81,6 +87,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
// normalize the vector
q.Vector = mapping.NormalizeVector(q.Vector)
}

return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField,
q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params)
q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params, q.FilterResults)
}
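
From the query side, the new hook is SetFilterQuery: attach any bleve query as a pre-filter, and only documents matching it remain candidates for the kNN search. A usage sketch; the "embedding" and "category" field names are made up, and building with vector support (the vectors build tag) is assumed:

package main

import (
    "fmt"

    "github.com/blevesearch/bleve/v2/search/query"
)

func main() {
    // A 3-dimensional query vector; real embeddings are larger.
    knn := query.NewKNNQuery([]float32{0.1, 0.2, 0.3})
    knn.VectorField = "embedding"
    knn.K = 10

    // Pre-filter: restrict kNN candidates to a single category.
    filter := query.NewTermQuery("electronics")
    filter.SetField("category")
    knn.SetFilterQuery(filter)

    fmt.Printf("%+v\n", knn)
}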