-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added save/load functionality to AnnoyIndexer #845
Changes from 4 commits
e0af7b1
d92e501
61d3ef1
8ce445e
7df7d1e
31b5e53
3a546ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,11 @@ | |
# | ||
# Copyright (C) 2013 Radim Rehurek <me@radimrehurek.com> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
import os | ||
try: | ||
import cPickle as pickle | ||
except ImportError: | ||
import pickle | ||
|
||
from gensim.models.doc2vec import Doc2Vec | ||
from gensim.models.word2vec import Word2Vec | ||
|
@@ -15,16 +19,32 @@ | |
|
||
class AnnoyIndexer(object): | ||
|
||
def __init__(self, model, num_trees): | ||
def __init__(self, model=None, num_trees=None): | ||
self.index = None | ||
self.labels = None | ||
self.model = model | ||
self.num_trees = num_trees | ||
|
||
if isinstance(self.model, Doc2Vec): | ||
self.build_from_doc2vec() | ||
elif isinstance(self.model, Word2Vec): | ||
self.build_from_word2vec() | ||
else: | ||
raise ValueError("Only a Word2Vec or Doc2Vec instance can be used") | ||
if model and num_trees: | ||
if isinstance(self.model, Doc2Vec): | ||
self.build_from_doc2vec() | ||
elif isinstance(self.model, Word2Vec): | ||
self.build_from_word2vec() | ||
else: | ||
raise ValueError("Only a Word2Vec or Doc2Vec instance can be used") | ||
|
||
def save(self, fname): | ||
self.index.save(fname) | ||
d = {'f': self.model.vector_size, 'num_trees': self.num_trees, 'labels': self.labels} | ||
pickle.dump(d, open(fname+'.d', 'wb'), 2) | ||
|
||
def load(self, fname): | ||
if os.path.exists(fname) and os.path.exists(fname+'.d'): | ||
d = pickle.load(open(fname+'.d', 'rb')) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use smart_open as in https://github.com/RaRe-Technologies/gensim/blob/6a289fefd72f038c8cc14826f63624950f5de1f8/gensim/utils.py#L907 |
||
self.num_trees = d['num_trees'] | ||
self.index = AnnoyIndex(d['f']) | ||
self.index.load(fname) | ||
self.labels = d['labels'] | ||
|
||
def build_from_word2vec(self): | ||
"""Build an Annoy index using word vectors from a Word2Vec model""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -464,6 +464,32 @@ def testApproxNeighborsMatchExact(self): | |
|
||
self.assertEqual(approx_words, exact_words) | ||
|
||
def testSave(self): | ||
self.index.save('index') | ||
self.assertTrue(os.path.exists('index')) | ||
self.assertTrue(os.path.exists('index.d')) | ||
|
||
def testLoadNotExist(self): | ||
from gensim.similarities.index import AnnoyIndexer | ||
self.test_index = AnnoyIndexer() | ||
self.test_index.load('test-index') | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has to raise IOError |
||
self.assertEqual(self.test_index.index, None) | ||
self.assertEqual(self.test_index.labels, None) | ||
|
||
def testSaveLoad(self): | ||
from gensim.similarities.index import AnnoyIndexer | ||
|
||
self.index.save('index') | ||
|
||
self.index2 = AnnoyIndexer() | ||
self.index2.load('index') | ||
self.index2.model = self.model | ||
|
||
self.assertEqual(self.index.index.f, self.index2.index.f) | ||
self.assertEqual(self.index.labels, self.index2.labels) | ||
self.assertEqual(self.index.num_trees, self.index2.num_trees) | ||
|
||
|
||
class TestDoc2VecAnnoyIndexer(unittest.TestCase): | ||
|
||
|
@@ -497,6 +523,32 @@ def testApproxNeighborsMatchExact(self): | |
|
||
self.assertEqual(approx_words, exact_words) | ||
|
||
def testSave(self): | ||
self.index.save('index') | ||
self.assertTrue(os.path.exists('index')) | ||
self.assertTrue(os.path.exists('index.d')) | ||
|
||
def testLoadNotExist(self): | ||
from gensim.similarities.index import AnnoyIndexer | ||
self.test_index = AnnoyIndexer() | ||
self.test_index.load('test-index') | ||
|
||
self.assertEqual(self.test_index.index, None) | ||
self.assertEqual(self.test_index.labels, None) | ||
|
||
def testSaveLoad(self): | ||
from gensim.similarities.index import AnnoyIndexer | ||
|
||
self.index.save('index') | ||
|
||
self.index2 = AnnoyIndexer() | ||
self.index2.load('index') | ||
self.index2.model = self.model | ||
|
||
self.assertEqual(self.index.index.f, self.index2.index.f) | ||
self.assertEqual(self.index.labels, self.index2.labels) | ||
self.assertEqual(self.index.num_trees, self.index2.num_trees) | ||
|
||
|
||
if __name__ == '__main__': | ||
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use smart_open as in https://github.com/RaRe-Technologies/gensim/blob/6a289fefd72f038c8cc14826f63624950f5de1f8/gensim/utils.py#L896