Skip to content

Commit

Permalink
MB-61029: Added ref counts to track number of live queries using the …
Browse files Browse the repository at this point in the history
…cache

 - This prevents accidental clearing of cache entries while in use
  • Loading branch information
Likith101 committed Apr 8, 2024
1 parent eb696f4 commit 8376458
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 10 deletions.
50 changes: 41 additions & 9 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ type ewma struct {
type cacheEntry struct {
tracker *ewma

m sync.RWMutex
index *faiss.IndexImpl
m sync.RWMutex
useCount int64
index *faiss.IndexImpl
}

func newVectorIndexCache() *vecIndexCache {
Expand All @@ -63,13 +64,14 @@ func (vc *vecIndexCache) loadVectorIndex(fieldID uint16,
cachedIndex, present := vc.isIndexCached(fieldID)
if present {
vecIndex = cachedIndex
vc.addRef(fieldID)
vc.incHit(fieldID)
} else {
// if the cache doesn't have vector index, just construct it out of the
// index bytes and update the cache.
vecIndex, err = faiss.ReadIndexFromBuffer(indexBytes, faiss.IOFlagReadOnly)
vc.update(fieldID, vecIndex)
}
vc.addRef(fieldID)
return vecIndex, err
}

Expand All @@ -88,6 +90,14 @@ func (vc *vecIndexCache) isIndexCached(fieldID uint16) (*faiss.IndexImpl, bool)
return rv, present && (rv != nil)
}

// incHit records one cache hit against the entry stored under fieldIDPlus1.
//
// The entry is looked up while holding the cache's read lock; the hit itself
// is recorded after the lock has been released. If no entry is stored under
// the key (for example because it was removed from the cache between the
// caller's presence check and this call), the call is a no-op instead of
// dereferencing a nil entry.
func (vc *vecIndexCache) incHit(fieldIDPlus1 uint16) {
	vc.m.RLock()
	entry, ok := vc.cache[fieldIDPlus1]
	vc.m.RUnlock()

	// indexing a map with an absent key yields a nil *cacheEntry
	if !ok || entry == nil {
		return
	}
	entry.incHit()
}

func (vc *vecIndexCache) addRef(fieldIDPlus1 uint16) {
vc.m.RLock()
entry := vc.cache[fieldIDPlus1]
Expand All @@ -96,6 +106,14 @@ func (vc *vecIndexCache) addRef(fieldIDPlus1 uint16) {
entry.addRef()
}

// decRef releases one reference on the entry stored under fieldIDPlus1.
//
// The entry is looked up while holding the cache's read lock; the reference
// count is decremented after the lock has been released. Callers may release
// a key that was never cached (for example when the field was not found and
// no index was loaded for it), so a missing entry is treated as a no-op
// instead of dereferencing a nil entry.
func (vc *vecIndexCache) decRef(fieldIDPlus1 uint16) {
	vc.m.RLock()
	entry, ok := vc.cache[fieldIDPlus1]
	vc.m.RUnlock()

	// indexing a map with an absent key yields a nil *cacheEntry
	if !ok || entry == nil {
		return
	}
	entry.decRef()
}

func (vc *vecIndexCache) refresh() (rv int) {
vc.m.Lock()
cache := vc.cache
Expand All @@ -104,15 +122,21 @@ func (vc *vecIndexCache) refresh() (rv int) {
for fieldIDPlus1, entry := range cache {
sample := atomic.LoadUint64(&entry.tracker.sample)
entry.tracker.add(sample)

refCount := atomic.LoadInt64(&entry.useCount)
// The comparison threshold is currently (1 - alpha). An average at or below
// it means that, going by the history, there has been only about one query
// per second on average, and that no queries were performed against this
// index in the current second.
if entry.tracker.avg <= (1 - entry.tracker.alpha) {
atomic.StoreUint64(&entry.tracker.sample, 0)
entry.closeIndex()
delete(vc.cache, fieldIDPlus1)
continue
if entry.tracker.avg <= (1-entry.tracker.alpha) && refCount <= 0 {
if refCount == 0 {
atomic.StoreUint64(&entry.tracker.sample, 0)
entry.closeIndex()
delete(vc.cache, fieldIDPlus1)
continue
} else {
atomic.StoreUint64(&entry.tracker.sample, 0)
}
}
atomic.StoreUint64(&entry.tracker.sample, 0)
}
Expand Down Expand Up @@ -180,13 +204,21 @@ func initCacheEntry(index *faiss.IndexImpl, alpha float64) *cacheEntry {
return vc
}

func (vc *cacheEntry) addRef() {
// incHit counts one access to this cache entry.
//
// Accesses are accumulated atomically in the tracker's sample, which feeds
// the average on the next cycle of average computation.
func (vc *cacheEntry) incHit() {
	sample := &vc.tracker.sample
	atomic.AddUint64(sample, 1)
}

// addRef atomically adds one to the entry's use count, marking one more
// user of the entry.
func (vc *cacheEntry) addRef() {
	const acquire int64 = 1
	atomic.AddInt64(&vc.useCount, acquire)
}

// decRef atomically subtracts one from the entry's use count, releasing one
// user of the entry. The count is not clamped at zero.
func (vc *cacheEntry) decRef() {
	const release int64 = -1
	atomic.AddInt64(&vc.useCount, release)
}

func (vc *cacheEntry) closeIndex() {
vc.m.Lock()
vc.index.Close()
Expand Down
4 changes: 3 additions & 1 deletion faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
var vecIndex *faiss.IndexImpl
vecDocIDMap := make(map[int64]uint32)
var vectorIDsToExclude []int64
var fieldIDPlus1 uint16

var (
wrapVecIndex = &vectorIndexWrapper{
Expand Down Expand Up @@ -337,6 +338,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
close: func() {
// The index is cached, so it is not closed here; closing it is deferred
// to a later point in time.
sb.vectorCache.decRef(fieldIDPlus1)
},
size: func() uint64 {
if vecIndex != nil {
Expand All @@ -349,7 +351,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
err error
)

fieldIDPlus1 := sb.fieldsMap[field]
fieldIDPlus1 = sb.fieldsMap[field]
if fieldIDPlus1 <= 0 {
return wrapVecIndex, nil
}
Expand Down

0 comments on commit 8376458

Please sign in to comment.