Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Faiss to the list of supported ANN frameworks #555

Merged
merged 6 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ One important aspect of deploying recommender model is efficient retrieval via A

| Supported framework | Cornac wrapper | Examples |
| :---: | :---: | :---: |
| [meta/faiss](https://github.com/facebookresearch/faiss) | [FaissANN](cornac/models/ann/recom_ann_faiss.py) | [ann_all.ipynb](examples/ann_all.ipynb)
| [nmslib/hnswlib](https://github.com/nmslib/hnswlib) | [HNSWLibANN](cornac/models/ann/recom_ann_hnswlib.py) | [ann_hnswlib.ipynb](tutorials/ann_hnswlib.ipynb), [ann_all.ipynb](examples/ann_all.ipynb)
| [google/scann](https://github.com/google-research/google-research/tree/master/scann) | [ScaNNANN](cornac/models/ann/recom_ann_scann.py) | [ann_all.ipynb](examples/ann_all.ipynb)

Expand Down
1 change: 1 addition & 0 deletions cornac/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .recommender import NextBasketRecommender

from .amr import AMR
from .ann import FaissANN
from .ann import HNSWLibANN
from .ann import ScaNNANN
from .baseline_only import BaselineOnly
Expand Down
1 change: 1 addition & 0 deletions cornac/models/ann/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .recom_ann_faiss import FaissANN
from .recom_ann_hnswlib import HNSWLibANN
from .recom_ann_scann import ScaNNANN
153 changes: 153 additions & 0 deletions cornac/models/ann/recom_ann_faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import multiprocessing
import numpy as np

from ..recommender import MEASURE_L2, MEASURE_DOT, MEASURE_COSINE
from .recom_ann_base import BaseANN


class FaissANN(BaseANN):
"""Approximate Nearest Neighbor Search with Faiss (https://github.com/facebookresearch/faiss).
Faiss provides both CPU and GPU implementation. More on the algorithms:
https://github.com/facebookresearch/faiss/wiki

Parameters
----------------
model: object: :obj:`cornac.models.Recommender`, required
Trained recommender model which to get user/item vectors from.

nlist: int, default: 100
The number of cells used for building the index.

nprobe: int, default: 50
The number of cells (out of nlist) that are visited to perform a search.

use_gpu : bool, optional
Whether or not to run Faiss on GPU. Requires faiss-gpu to be installed
instead of faiss-cpu.

num_threads: int, optional, default: -1
Default number of threads used for building index. If num_threads = -1,
all cores will be used.

seed: int, optional, default: None
Random seed for reproducibility.

name: str, required
Name of the recommender model.

verbose: boolean, optional, default: False
When True, running logs are displayed.
"""

def __init__(
self,
model,
nlist=100,
nprobe=50,
use_gpu=False,
num_threads=-1,
seed=None,
name="FaissANN",
verbose=False,
):
super().__init__(model=model, name=name, verbose=verbose)

self.model = model
self.nlist = nlist
self.nprobe = nprobe
self.use_gpu = use_gpu
self.num_threads = (
num_threads if num_threads != -1 else multiprocessing.cpu_count()
)
self.seed = seed

self.index = None
self.ignored_attrs.extend(
[
"index", # will be saved separately
"item_vectors", # redundant after index is built
]
)

def build_index(self):
"""Building index from the base recommender model."""
import faiss

faiss.omp_set_num_threads(self.num_threads)

SUPPORTED_MEASURES = {
MEASURE_L2: faiss.METRIC_L2,
MEASURE_DOT: faiss.METRIC_INNER_PRODUCT,
MEASURE_COSINE: faiss.METRIC_INNER_PRODUCT,
}

assert self.measure in SUPPORTED_MEASURES

if self.measure == MEASURE_COSINE:
self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[
:, np.newaxis
]

self.item_vectors = self.item_vectors.astype("float32")

self.index = faiss.IndexIVFFlat(
faiss.IndexFlat(self.item_vectors.shape[1]),
self.item_vectors.shape[1],
self.nlist,
SUPPORTED_MEASURES[self.measure],
)

if self.use_gpu:
self.index = faiss.index_cpu_to_all_gpus(self.index)

self.index.train(self.item_vectors)
self.index.add(self.item_vectors)
self.index.nprobe = self.nprobe

def knn_query(self, query, k):
"""Implementing ANN search for a given query.

Returns
-------
neighbors, distances: numpy.array and numpy.array
Array of k-nearest neighbors and corresponding distances for the given query.
"""
distances, neighbors = self.index.search(query, k)
return neighbors, distances

def save(self, save_dir=None):
import faiss

saved_path = super().save(save_dir)
idx_path = saved_path + ".index"
if self.use_gpu:
self.index = faiss.index_gpu_to_cpu(self.index)
faiss.write_index(self.index, idx_path)
return saved_path

@staticmethod
def load(model_path, trainable=False):
import faiss

ann = BaseANN.load(model_path, trainable)
idx_path = ann.load_from + ".index"
ann.index = faiss.read_index(idx_path)
if ann.use_gpu:
ann.index = faiss.index_cpu_to_all_gpus(ann.index)
return ann
4 changes: 2 additions & 2 deletions cornac/models/ann/recom_ann_hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def knn_query(self, query, k):

def save(self, save_dir=None):
saved_path = super().save(save_dir)
self.index.save_index(saved_path + ".idx")
self.index.save_index(saved_path + ".index")
return saved_path

@staticmethod
Expand All @@ -144,7 +144,7 @@ def load(model_path, trainable=False):
ann.index = hnswlib.Index(
space=SUPPORTED_MEASURES[ann.measure], dim=ann.user_vectors.shape[1]
)
ann.index.load_index(ann.load_from + ".idx")
ann.index.load_index(ann.load_from + ".index")
ann.index.set_ef(ann.ef)
ann.index.set_num_threads(ann.num_threads)
return ann
11 changes: 7 additions & 4 deletions cornac/models/ann/recom_ann_scann.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from .recom_ann_base import BaseANN


SUPPORTED_MEASURES = {MEASURE_L2: "squared_l2", MEASURE_DOT: "dot_product"}
SUPPORTED_MEASURES = {
MEASURE_L2: "squared_l2",
MEASURE_DOT: "dot_product",
MEASURE_COSINE: "dot_product",
}


class ScaNNANN(BaseANN):
Expand Down Expand Up @@ -108,7 +112,6 @@ def build_index(self):
self.item_vectors /= np.linalg.norm(self.item_vectors, axis=1)[
:, np.newaxis
]
self.measure = MEASURE_DOT
else:
self.partition_params["spherical"] = False

Expand Down Expand Up @@ -149,7 +152,7 @@ def knn_query(self, query, k):

def save(self, save_dir=None):
saved_path = super().save(save_dir)
idx_path = saved_path + ".idx"
idx_path = saved_path + ".index"
os.makedirs(idx_path, exist_ok=True)
self.index.searcher.serialize(idx_path)
return saved_path
Expand All @@ -159,6 +162,6 @@ def load(model_path, trainable=False):
from scann.scann_ops.py import scann_ops_pybind

ann = BaseANN.load(model_path, trainable)
idx_path = ann.load_from + ".idx"
idx_path = ann.load_from + ".index"
ann.index = scann_ops_pybind.load_searcher(idx_path)
return ann
Loading
Loading