Skip to content

Commit c73c827

Browse files
resolved TODOs
1 parent c1ecebc commit c73c827

File tree

6 files changed

+42
-27
lines changed

6 files changed

+42
-27
lines changed

index/scorch/optimize_knn.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (o *OptimizeVR) Finish() error {
109109
// for each VR, populate postings list and iterators
110110
// by passing the obtained vector index and getting similar vectors.
111111
pl, err := vecIndex.Search(vr.vector, vr.k,
112-
eligibleLocalDocNums, vr.searchParams)
112+
vr.requireFiltering, eligibleLocalDocNums, vr.searchParams)
113113
if err != nil {
114114
errorsM.Lock()
115115
errors = append(errors, err)

index/scorch/snapshot_index_vr.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ type IndexSnapshotVectorReader struct {
5151
currID index.IndexInternalID
5252
ctx context.Context
5353

54-
searchParams json.RawMessage
55-
eligibleDocIDs []index.IndexInternalID
54+
searchParams json.RawMessage
55+
eligibleDocIDs []index.IndexInternalID
56+
requireFiltering bool
5657
}
5758

5859
func (i *IndexSnapshotVectorReader) EligibleDocIDs() *roaring.Bitmap {
@@ -121,7 +122,7 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
121122

122123
if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
123124
i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k,
124-
i.searchParams, i.eligibleDocIDs)
125+
i.searchParams, i.eligibleDocIDs, i.requireFiltering)
125126
if err != nil {
126127
return nil, err
127128
}

index/scorch/snapshot_vector_index.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@ import (
2626
)
2727

2828
func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32,
29-
field string, k int64, searchParams json.RawMessage, filterIDs []index.IndexInternalID) (
29+
field string, k int64, searchParams json.RawMessage,
30+
filterIDs []index.IndexInternalID, requireFiltering bool) (
3031
index.VectorReader, error) {
3132

3233
rv := &IndexSnapshotVectorReader{
33-
vector: vector,
34-
field: field,
35-
k: k,
36-
snapshot: is,
37-
searchParams: searchParams,
38-
eligibleDocIDs: filterIDs,
34+
vector: vector,
35+
field: field,
36+
k: k,
37+
snapshot: is,
38+
searchParams: searchParams,
39+
eligibleDocIDs: filterIDs,
40+
requireFiltering: requireFiltering,
3941
}
4042

4143
if rv.postings == nil {

search/query/knn.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ type KNNQuery struct {
3535
BoostVal *Boost `json:"boost,omitempty"`
3636

3737
// see KNNRequest.Params for description
38-
Params json.RawMessage `json:"params"`
39-
FilterQuery Query `json:"filter,omitempty"`
40-
FilterResults []index.IndexInternalID
38+
Params json.RawMessage `json:"params"`
39+
FilterQuery Query `json:"filter,omitempty"`
40+
FilterResults []index.IndexInternalID
41+
RequireFiltering bool
4142
}
4243

4344
func NewKNNQuery(vector []float32) *KNNQuery {
@@ -69,8 +70,9 @@ func (q *KNNQuery) SetParams(params json.RawMessage) {
6970
q.Params = params
7071
}
7172

72-
func (q *KNNQuery) SetFilterQuery(f Query) {
73+
func (q *KNNQuery) SetFilterQuery(f Query, requireFiltering bool) {
7374
q.FilterQuery = f
75+
q.RequireFiltering = requireFiltering
7476
}
7577

7678
func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
@@ -89,5 +91,6 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
8991
}
9092

9193
return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField,
92-
q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params, q.FilterResults)
94+
q.Vector, q.K, q.BoostVal.Value(), similarityMetric, q.Params,
95+
q.FilterResults, q.RequireFiltering)
9396
}

search/searcher/search_knn.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ type KNNSearcher struct {
5050
func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMapping,
5151
options search.SearcherOptions, field string, vector []float32, k int64,
5252
boost float64, similarityMetric string, searchParams json.RawMessage,
53-
filterIDs []index.IndexInternalID) (
53+
filterIDs []index.IndexInternalID, requireFiltering bool) (
5454
search.Searcher, error) {
5555

5656
if vr, ok := i.(index.VectorIndexReader); ok {
5757
vectorReader, err := vr.VectorReader(ctx, vector, field, k, searchParams,
58-
filterIDs)
58+
filterIDs, requireFiltering)
5959
if err != nil {
6060
return nil, err
6161
}

search_knn.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ var (
228228
knnOperatorOr = knnOperator("or")
229229
)
230230

231-
func createKNNQuery(req *SearchRequest, eligibleDocsMap map[int][]index.IndexInternalID) (
231+
func createKNNQuery(req *SearchRequest, eligibleDocsMap map[int][]index.IndexInternalID,
232+
requiresFiltering map[int]bool) (
232233
query.Query, []int64, int64, error) {
233234
if requestHasKNN(req) {
234235
// first perform validation
@@ -245,8 +246,9 @@ func createKNNQuery(req *SearchRequest, eligibleDocsMap map[int][]index.IndexInt
245246
knnQuery.SetK(knn.K)
246247
knnQuery.SetBoost(knn.Boost.Value())
247248
knnQuery.SetParams(knn.Params)
248-
knnQuery.SetFilterQuery(knn.FilterQuery)
249-
if filterResults, exists := eligibleDocsMap[i]; exists {
249+
knnQuery.SetFilterQuery(knn.FilterQuery, requiresFiltering[i])
250+
filterResults, exists := eligibleDocsMap[i]
251+
if exists && requiresFiltering[i] {
250252
knnQuery.FilterResults = filterResults
251253
}
252254
subQueries = append(subQueries, knnQuery)
@@ -330,20 +332,26 @@ func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, rea
330332
// maps the index of the KNN query in the req to the pre-filter hits aka
331333
// eligible docs' internal IDs.
332334
filterHitsMap := make(map[int][]index.IndexInternalID)
335+
// Indicates, per kNN query index, whether that query requires filtering downstream.
336+
// No filtering is required if the filter is a match-all query or no filter was supplied.
337+
requiresFiltering := make(map[int]bool)
333338

334339
for idx, knnReq := range req.KNN {
335340
// TODO Can use goroutines for this filter query stuff - do it if perf results
336341
// show this to be significantly slow otherwise.
337342
filterQ := knnReq.FilterQuery
338-
// If there are no filters here, add a match all since that will ensure that
339-
// all the live docs in the index are eligible.
340-
// TODO See if running MatchAll queries can be skipped too - if perf shows
341-
// them to be time-consuming in existing kNN tests?
342343
if filterQ == nil {
343-
filterQ = query.NewMatchAllQuery()
344+
requiresFiltering[idx] = false
345+
}
346+
347+
if _, ok := filterQ.(*query.MatchAllQuery); ok {
348+
requiresFiltering[idx] = false
349+
continue
344350
}
345351

346352
if _, ok := filterQ.(*query.MatchNoneQuery); ok {
353+
// Filtering required since no hits are eligible.
354+
requiresFiltering[idx] = true
347355
// a match none query just means none of the documents are eligible
348356
// hence, we can save on running the query.
349357
continue
@@ -369,10 +377,11 @@ func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, rea
369377
for _, docMatch := range filterHits {
370378
filterHitsMap[idx] = append(filterHitsMap[idx], docMatch.IndexInternalID)
371379
}
380+
requiresFiltering[idx] = true
372381
}
373382

374383
// Add the filter hits when creating the kNN query
375-
KNNQuery, kArray, sumOfK, err := createKNNQuery(req, filterHitsMap)
384+
KNNQuery, kArray, sumOfK, err := createKNNQuery(req, filterHitsMap, requiresFiltering)
376385
if err != nil {
377386
return nil, err
378387
}

0 commit comments

Comments
 (0)