Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Dec 16, 2024
1 parent 4bbb6c0 commit c44162a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
50 changes: 39 additions & 11 deletions logics/non_personalized.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/zhenghaoz/gorse/storage/data"
"go.uber.org/zap"
"reflect"
"sort"
"time"
)

Expand All @@ -35,7 +36,8 @@ type NonPersonalized struct {
timestamp time.Time
scoreFunc *vm.Program
filterFunc *vm.Program
heap *heap.TopKFilter[cache.Score, float64]
heapSize int
heaps map[string]*heap.TopKFilter[string, float64]
}

func NewNonPersonalized(cfg config.NonPersonalizedConfig, n int, timestamp time.Time) (*NonPersonalized, error) {
Expand Down Expand Up @@ -66,12 +68,16 @@ func NewNonPersonalized(cfg config.NonPersonalizedConfig, n int, timestamp time.
return nil, errors.New("filter function must return bool")
}
}
// Initialize heap
heaps := make(map[string]*heap.TopKFilter[string, float64])
heaps[""] = heap.NewTopKFilter[string, float64](n)
return &NonPersonalized{
name: cfg.Name,
timestamp: timestamp,
scoreFunc: scoreFunc,
filterFunc: filterFunc,
heap: heap.NewTopKFilter[cache.Score, float64](n),
heapSize: n,
heaps: heaps,
}, nil
}

Expand Down Expand Up @@ -140,18 +146,40 @@ func (l *NonPersonalized) Push(item data.Item, feedback []data.Feedback) {
log.Logger().Error("score function must return float64", zap.Any("result", result))
return
}
l.heap.Push(cache.Score{
Id: item.ItemId,
Score: score,
IsHidden: item.IsHidden,
Categories: item.Categories,
Timestamp: l.timestamp,
}, score)
// Add to heap
l.heaps[""].Push(item.ItemId, score)
for _, group := range item.Categories {
if _, exist := l.heaps[group]; !exist {
l.heaps[group] = heap.NewTopKFilter[string, float64](l.heapSize)
}
l.heaps[group].Push(item.ItemId, score)
}
}

func (l *NonPersonalized) PopAll() []cache.Score {
scores, _ := l.heap.PopAll()
return scores
scores := make(map[string]*cache.Score)
for category, h := range l.heaps {
names, values := h.PopAll()
for i, name := range names {
if _, exist := scores[name]; !exist {
scores[name] = &cache.Score{
Id: name,
Score: values[i],
Categories: []string{category},
Timestamp: l.timestamp,
}
} else {
scores[name].Categories = append(scores[name].Categories, category)
}
}
}
result := lo.MapToSlice(scores, func(_ string, v *cache.Score) cache.Score {
return *v
})
sort.Slice(result, func(i, j int) bool {
return result[i].Score > result[j].Score
})
return result
}

func (l *NonPersonalized) Name() string {
Expand Down
9 changes: 9 additions & 0 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,15 @@ func (m *Master) LoadDataFromDatabase(
if item.IsHidden { // set hidden flag
rankingDataset.HiddenItems[itemIndex] = true
}
// TODO: Refactor
// add item to non-personalized recommenders
feedback, err := database.GetItemFeedback(newCtx, item.ItemId, posFeedbackTypes...)
if err != nil {
return nil, nil, errors.Trace(err)
}
for _, recommender := range nonPersonalizedRecommenders {
recommender.Push(item, feedback)
}
}
}
if err = <-errChan; err != nil {
Expand Down
3 changes: 2 additions & 1 deletion master/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ func (s *MasterTestSuite) TestLoadDataFromDatabase() {
UserId: strconv.Itoa(j),
FeedbackType: "positive",
},
Timestamp: time.Now(),
// TODO: Refactor
Timestamp: time.Now().Add(-time.Second),
})
}
// negative feedback
Expand Down

0 comments on commit c44162a

Please sign in to comment.