Skip to content

Commit

Permalink
Make ANNs compatible with Experiment (#563)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg authored Dec 12, 2023
1 parent 3706c5e commit a60bb86
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 58 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ One important aspect of deploying recommender model is efficient retrieval via A

| Supported framework | Cornac wrapper | Examples |
| :---: | :---: | :---: |
| [spotify/annoy](https://github.com/spotify/annoy) | [AnnoyANN](cornac/models/ann/recom_ann_annoy.py) | [ann_all.ipynb](examples/ann_all.ipynb)
| [meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_all.ipynb](examples/ann_all.ipynb)
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb)
| [spotify/annoy](https://github.com/spotify/annoy) | [AnnoyANN](cornac/models/ann/recom_ann_annoy.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb)
| [meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb)
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_example.py](examples/ann_example.py), [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_example.py](examples/ann_example.py), [ann_all.ipynb](examples/ann_all.ipynb)


## Models
Expand Down
6 changes: 5 additions & 1 deletion cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def ranking_eval(
if len(metrics) == 0:
return [], []

max_k = max(m.k for m in metrics)

avg_results = []
user_results = [{} for _ in enumerate(metrics)]

Expand Down Expand Up @@ -203,7 +205,9 @@ def pos_items(csr_row):
u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(user_idx, item_indices)
item_rank, item_scores = model.rank(
user_idx=user_idx, item_indices=item_indices, k=max_k
)

for i, mt in enumerate(metrics):
mt_score = mt.compute(
Expand Down
11 changes: 4 additions & 7 deletions cornac/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@
# limitations under the License.
# ============================================================================

class CornacException(Exception):
"""Exception base class to extend from

"""
class CornacException(Exception):
"""Exception base class to extend from"""

pass


class ScoreException(CornacException):
"""Exception raised in score function when facing unknowns
"""Exception raised in score function when facing unknowns"""

"""

pass
pass
2 changes: 1 addition & 1 deletion cornac/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def run(self):
if self.val_result is not None:
self.val_result.append(val_result)

if not isinstance(self.result, CVExperimentResult):
if self.save_dir and (not isinstance(self.result, CVExperimentResult)):
model.save(self.save_dir)

output = ""
Expand Down
12 changes: 10 additions & 2 deletions cornac/models/ann/recom_ann_annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(
):
super().__init__(model=model, name=name, verbose=verbose)

self.model = model
self.n_trees = n_trees
self.search_k = search_k
self.num_threads = num_threads
Expand All @@ -85,14 +84,18 @@ def __init__(

def build_index(self):
"""Building index from the base recommender model."""
super().build_index()

from annoy import AnnoyIndex

assert self.measure in SUPPORTED_MEASURES

self.index = AnnoyIndex(
self.item_vectors.shape[1], SUPPORTED_MEASURES[self.measure]
)
self.index.set_seed(self.seed)

if self.seed is not None:
self.index.set_seed(self.seed)

for i, v in enumerate(self.item_vectors):
self.index.add_item(i, v)
Expand All @@ -115,6 +118,11 @@ def knn_query(self, query, k):
]
neighbors = np.array([r[0] for r in result], dtype="int")
distances = np.array([r[1] for r in result], dtype="float32")

# make sure distances respect the notion of nearest neighbors (smaller is better)
if self.higher_is_better:
distances = 1.0 - distances

return neighbors, distances

def save(self, save_dir=None):
Expand Down
103 changes: 93 additions & 10 deletions cornac/models/ann/recom_ann_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# ============================================================================

import copy
import warnings
import numpy as np

from ..recommender import Recommender
from ..recommender import is_ann_supported
from ..recommender import MEASURE_DOT, MEASURE_COSINE


class BaseANN(Recommender):
Expand All @@ -41,20 +43,50 @@ def __init__(self, model, name="BaseANN", verbose=False):
if not is_ann_supported(model):
raise ValueError(f"{model.name} doesn't support ANN search")

# ANN required attributes
self.measure = copy.deepcopy(model.get_vector_measure())
self.user_vectors = copy.deepcopy(model.get_user_vectors())
self.item_vectors = copy.deepcopy(model.get_item_vectors())
self.model = model

# get basic attributes to be a proper recommender
super().fit(train_set=model.train_set, val_set=model.val_set)
self.ignored_attrs.append("model") # not to save the base model with ANN

def build_index(self):
"""Building index from the base recommender model.
if model.is_fitted:
Recommender.fit(self, model.train_set, model.val_set)

:raise NotImplementedError
def fit(self, train_set, val_set=None):
"""Fit the model to observations.
Parameters
----------
train_set: :obj:`cornac.data.Dataset`, required
User-Item preference data as well as additional modalities.
val_set: :obj:`cornac.data.Dataset`, optional, default: None
User-Item preference data for model selection purposes (e.g., early stopping).
Returns
-------
self : object
"""
raise NotImplementedError()
Recommender.fit(self, train_set, val_set)

if not self.model.is_fitted:
if self.verbose:
print(f"Fitting base recommender model {self.model.name}...")
self.model.fit(train_set, val_set)

self.build_index()

return self

def build_index(self):
"""Building index from the base recommender model."""
if not self.model.is_fitted:
warnings.warn(f"Base recommender model {self.model.name} is not fitted!")

# ANN required attributes
self.measure = copy.deepcopy(self.model.get_vector_measure())
self.user_vectors = copy.deepcopy(self.model.get_user_vectors())
self.item_vectors = copy.deepcopy(self.model.get_item_vectors())

self.higher_is_better = self.measure in {MEASURE_DOT, MEASURE_COSINE}

def knn_query(self, query, k):
"""Implementing ANN search for a given query.
Expand All @@ -65,6 +97,57 @@ def knn_query(self, query, k):
"""
raise NotImplementedError()

def rank(self, user_idx, item_indices=None, k=-1, **kwargs):
"""Rank all test items for a given user.
Parameters
----------
user_idx: int, required
The index of the user for whom to perform item raking.
item_indices: 1d array, optional, default: None
A list of candidate item indices to be ranked by the user.
If `None`, list of ranked known item indices and their scores will be returned.
k: int, required
Cut-off length for recommendations, k=-1 will return ranked list of all items.
Returns
-------
(ranked_items, item_scores): tuple
`ranked_items` contains item indices being ranked by their scores.
`item_scores` contains scores of items corresponding to index in `item_indices` input.
"""
query = self.user_vectors[[user_idx]]
knn_items, distances = self.knn_query(query, k=k)

top_k_items = knn_items[0]
top_k_scores = -distances[0]

item_scores = np.full(self.total_items, -np.Inf)
item_scores[top_k_items] = top_k_scores

all_items = np.arange(self.total_items)
ranked_items = np.concatenate(
[
top_k_items,
all_items[~np.isin(all_items, top_k_items, assume_unique=True)],
]
)

# rank items based on their scores
if item_indices is None:
item_scores = item_scores[: self.num_items]
ranked_items = ranked_items[: self.num_items]
else:
item_scores = item_scores[item_indices]
ranked_items = ranked_items[
np.isin(ranked_items, item_indices, assume_unique=True)
]

return ranked_items, item_scores

def recommend(self, user_id, k=-1, remove_seen=False, train_set=None):
"""Generate top-K item recommendations for a given user. Backward compatibility.
Expand Down
8 changes: 7 additions & 1 deletion cornac/models/ann/recom_ann_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(
):
super().__init__(model=model, name=name, verbose=verbose)

self.model = model
self.nlist = nlist
self.nprobe = nprobe
self.use_gpu = use_gpu
Expand All @@ -87,6 +86,8 @@ def __init__(

def build_index(self):
"""Building index from the base recommender model."""
super().build_index()

import faiss

faiss.omp_set_num_threads(self.num_threads)
Expand Down Expand Up @@ -129,6 +130,11 @@ def knn_query(self, query, k):
Array of k-nearest neighbors and corresponding distances for the given query.
"""
distances, neighbors = self.index.search(query, k)

# make sure distances respect the notion of nearest neighbors (smaller is better)
if self.higher_is_better:
distances = 1.0 - distances

return neighbors, distances

def save(self, save_dir=None):
Expand Down
3 changes: 3 additions & 0 deletions cornac/models/ann/recom_ann_hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
verbose=False,
):
super().__init__(model=model, name=name, verbose=verbose)

self.M = M
self.ef_construction = ef_construction
self.ef = ef
Expand All @@ -96,6 +97,8 @@ def __init__(

def build_index(self):
"""Building index from the base recommender model."""
super().build_index()

import hnswlib

assert self.measure in SUPPORTED_MEASURES
Expand Down
19 changes: 17 additions & 2 deletions cornac/models/ann/recom_ann_scann.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,18 @@ def __init__(
):
super().__init__(model=model, name=name, verbose=verbose)

if partition_params is None:
partition_params = {"num_leaves": 100, "num_leaves_to_search": 50}

if score_params is None:
score_params = {}
score_params = {
"dimensions_per_block": 2,
"anisotropic_quantization_threshold": 0.2,
}

if rescore_params is None:
rescore_params = {"reordering_num_neighbors": 100}

self.model = model
self.partition_params = partition_params
self.score_params = score_params
self.score_brute_force = score_brute_force
Expand All @@ -103,6 +111,8 @@ def __init__(

def build_index(self):
"""Building index from the base recommender model."""
super().build_index()

import scann

assert self.measure in SUPPORTED_MEASURES
Expand Down Expand Up @@ -148,6 +158,11 @@ def knn_query(self, query, k):
Array of k-nearest neighbors and corresponding distances for the given query.
"""
neighbors, distances = self.index.search_batched(query, final_num_neighbors=k)

# make sure distances respect the notion of nearest neighbors (smaller is better)
if self.higher_is_better:
distances = 1.0 - distances

return neighbors, distances

def save(self, save_dir=None):
Expand Down
Loading

0 comments on commit a60bb86

Please sign in to comment.