-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Poincare Model implementation #1696
Changes from 108 commits
6afdd22
a804006
6bd0d4b
98f94a7
b727523
1e6aee1
e286a0b
3e28e8b
99a2270
2e9e31c
d72cb10
d439501
e1ed24d
3b2a383
7d68aae
ba82d42
71f61d1
2a5a7fb
0c57aa1
9c51609
f22d9b2
7905c8c
075df25
6060e56
8ea8f23
b8d77e3
0011b93
b52ee2e
d247384
8c4f5a3
34b0ad3
1779cd7
faacb43
c68088e
0c2f2cb
386f602
f0fb9e9
7d8fbec
315f95c
0802dd5
a106191
1aa586d
13b00dc
5978af6
e40c3e3
ec8b516
86ae4d6
ac51e9c
5900c6f
4ac4d2e
9eb6f48
2ded72b
81960e1
5de194b
6dd6915
4b502af
12be121
29e799c
b4ff1dd
e2f72bc
953b4a7
eebc12a
21a1c82
f9325ea
5db8456
53030a0
db0d293
1adf81a
3898089
5cd913a
6305228
fb13eb5
ea2fd48
110fb1e
0aeec2f
38feb7a
0e7ebb3
52a1e57
630771d
7c6d972
16dcf0b
3501d6f
d690a25
2383e82
3ed0bea
9f562cb
98e078d
530146d
d17c075
be0249a
dc2ab95
a306f20
f9750e6
b7212ff
3556ee4
4644eda
94a2a18
7a4ec79
613ca38
3029d41
055044c
f75491f
59fcf8b
dcbe7aa
b69f51f
001ec76
9446a05
930dfd4
355e521
0d5175c
00ca7ab
8ff23ae
30ac3e6
a928ca1
dfc19cb
e967c54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
:mod:`models.poincare` -- Train and use Poincare embeddings | ||
============================================================= | ||
|
||
.. automodule:: gensim.models.poincare | ||
:synopsis: Train and use Poincare embeddings | ||
:members: | ||
:inherited-members: | ||
:special-members: __iter__, __getitem__, __contains__ | ||
:undoc-members: | ||
:show-inheritance: |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
t�mto bude� | ||
budem byli |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
kangaroo.n.01 marsupial.n.01 | ||
kangaroo.n.01 metatherian.n.01 | ||
kangaroo.n.01 mammal.n.01 | ||
gib.n.02 cat.n.01 | ||
striped_skunk.n.01 mammal.n.01 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
kangaroo.n.01 marsupial.n.01 | ||
kangaroo.n.01 metatherian.n.01 | ||
kangaroo.n.01 mammal.n.01 | ||
gib.n.02 cat.n.01 | ||
striped_skunk.n.01 mammal.n.01 | ||
domestic_goat.n.01 even-toed_ungulate.n.01 | ||
rock_squirrel.n.01 ground_squirrel.n.02 | ||
vizsla.n.01 dog.n.01 | ||
dandie_dinmont.n.01 mammal.n.01 | ||
broodmare.n.01 horse.n.01 | ||
spotted_skunk.n.01 spotted_skunk.n.01 | ||
hispid_pocket_mouse.n.01 hispid_pocket_mouse.n.01 | ||
lesser_kudu.n.01 placental.n.01 | ||
water_shrew.n.01 insectivore.n.01 | ||
silky_anteater.n.01 placental.n.01 | ||
giant_kangaroo.n.01 metatherian.n.01 | ||
bronco.n.01 bronco.n.01 | ||
pekinese.n.01 pekinese.n.01 | ||
seattle_slew.n.01 thoroughbred.n.02 | ||
kinkajou.n.01 kinkajou.n.01 | ||
boxer.n.04 mammal.n.01 | ||
rabbit.n.01 placental.n.01 | ||
longhorn.n.01 bovid.n.01 | ||
blue_fox.n.01 fox.n.01 | ||
woolly_monkey.n.01 new_world_monkey.n.01 | ||
jungle_cat.n.01 jungle_cat.n.01 | ||
vole.n.01 mammal.n.01 | ||
western_big-eared_bat.n.01 long-eared_bat.n.01 | ||
leopard.n.02 leopard.n.02 | ||
hackney.n.02 hackney.n.02 | ||
shetland_sheepdog.n.01 placental.n.01 | ||
coati.n.01 carnivore.n.01 | ||
wild_boar.n.01 mammal.n.01 | ||
post_horse.n.01 placental.n.01 | ||
porker.n.01 porker.n.01 | ||
mouflon.n.01 mouflon.n.01 | ||
australian_sea_lion.n.01 seal.n.09 | ||
coondog.n.01 placental.n.01 | ||
schipperke.n.01 mammal.n.01 | ||
black_rat.n.01 rodent.n.01 | ||
waterbuck.n.01 placental.n.01 | ||
hack.n.06 odd-toed_ungulate.n.01 | ||
central_chimpanzee.n.01 anthropoid_ape.n.01 | ||
harrier.n.02 harrier.n.02 | ||
lesser_panda.n.01 mammal.n.01 | ||
wether.n.01 ruminant.n.01 | ||
collie.n.01 shepherd_dog.n.01 | ||
prancer.n.01 horse.n.01 | ||
doberman.n.01 placental.n.01 | ||
pygmy_marmoset.n.01 monkey.n.01 | ||
phalanger.n.01 metatherian.n.01 | ||
black-and-tan_coonhound.n.01 black-and-tan_coonhound.n.01 | ||
woolly_monkey.n.01 primate.n.02 | ||
ferret_badger.n.01 badger.n.02 | ||
mountain_chinchilla.n.01 placental.n.01 | ||
english_foxhound.n.01 english_foxhound.n.01 | ||
leveret.n.01 leporid.n.01 | ||
shetland_sheepdog.n.01 canine.n.02 | ||
beagle.n.01 beagle.n.01 | ||
tibetan_mastiff.n.01 tibetan_mastiff.n.01 | ||
bouvier_des_flandres.n.01 canine.n.02 | ||
wheel_horse.n.01 placental.n.01 | ||
pocket_rat.n.01 rat.n.01 | ||
malinois.n.01 working_dog.n.01 | ||
white_elephant.n.02 white_elephant.n.02 | ||
camel.n.01 camel.n.01 | ||
mexican_pocket_mouse.n.01 rat.n.01 | ||
vaquita.n.01 toothed_whale.n.01 | ||
manchester_terrier.n.01 hunting_dog.n.01 | ||
chacma.n.01 monkey.n.01 | ||
binturong.n.01 viverrine.n.01 | ||
mastiff_bat.n.01 mammal.n.01 | ||
goat.n.01 mammal.n.01 | ||
pembroke.n.01 canine.n.02 | ||
steenbok.n.01 ungulate.n.01 | ||
tarsius_syrichta.n.01 mammal.n.01 | ||
maltese.n.03 domestic_cat.n.01 | ||
pacific_bottlenose_dolphin.n.01 toothed_whale.n.01 | ||
tamandua.n.01 tamandua.n.01 | ||
murine.n.01 rodent.n.01 | ||
coyote.n.01 canine.n.02 | ||
king_charles_spaniel.n.01 placental.n.01 | ||
basset.n.01 canine.n.02 | ||
pygmy_mouse.n.01 pygmy_mouse.n.01 | ||
toy_spaniel.n.01 carnivore.n.01 | ||
cactus_mouse.n.01 mouse.n.01 | ||
hart.n.03 ruminant.n.01 | ||
broodmare.n.01 equine.n.01 | ||
sussex_spaniel.n.01 sporting_dog.n.01 | ||
omaha.n.04 odd-toed_ungulate.n.01 | ||
alaska_fur_seal.n.01 placental.n.01 | ||
cattalo.n.01 bovine.n.01 | ||
soft-coated_wheaten_terrier.n.01 mammal.n.01 | ||
harness_horse.n.01 horse.n.01 | ||
banteng.n.01 even-toed_ungulate.n.01 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tímto budeš | ||
budem byli |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Author: Jayant Jain <jayantjain1992@gmail.com> | ||
# Copyright (C) 2017 Radim Rehurek <me@radimrehurek.com> | ||
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html | ||
|
||
""" | ||
Automated tests for checking the poincare module from the models package. | ||
""" | ||
|
||
import logging | ||
import os | ||
import tempfile | ||
import unittest | ||
try: | ||
from mock import Mock | ||
except ImportError: | ||
from unittest.mock import Mock | ||
|
||
import numpy as np | ||
|
||
from gensim.models.poincare import PoincareRelations, PoincareModel | ||
from gensim.test.utils import datapath | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def testfile(): | ||
# temporary data will be stored to this file | ||
return os.path.join(tempfile.gettempdir(), 'gensim_word2vec.tst') | ||
|
||
|
||
class TestPoincareData(unittest.TestCase): | ||
def test_encoding_handling(self): | ||
"""Tests whether utf8 and non-utf8 data loaded correctly.""" | ||
non_utf8_file = datapath('poincare_cp852.tsv') | ||
relations = [relation for relation in PoincareRelations(non_utf8_file, encoding='cp852')] | ||
self.assertEqual(len(relations), 2) | ||
self.assertEqual(relations[0], (u'tímto', u'budeš')) | ||
|
||
utf8_file = datapath('poincare_utf8.tsv') | ||
relations = [relation for relation in PoincareRelations(utf8_file)] | ||
self.assertEqual(len(relations), 2) | ||
self.assertEqual(relations[0], (u'tímto', u'budeš')) | ||
|
||
|
||
class TestPoincareModel(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice test coverage! What I'm missing is a test that checks the embedding actually works. Would it be possible to train on some very small data, such as this graph:
and assert, for instance, that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, I plan on adding this, it relies on the |
||
def setUp(self): | ||
self.data = PoincareRelations(datapath('poincare_hypernyms.tsv')) | ||
self.data_large = PoincareRelations(datapath('poincare_hypernyms_large.tsv')) | ||
|
||
def models_equal(self, model_1, model_2): | ||
self.assertEqual(len(model_1.kv.vocab), len(model_2.kv.vocab)) | ||
self.assertEqual(set(model_1.kv.vocab.keys()), set(model_2.kv.vocab.keys())) | ||
self.assertTrue(np.allclose(model_1.kv.syn0, model_2.kv.syn0)) | ||
|
||
def test_data_counts(self): | ||
"""Tests whether data has been loaded correctly and completely.""" | ||
model = PoincareModel(self.data) | ||
self.assertEqual(len(model.all_relations), 5) | ||
self.assertEqual(len(model.node_relations[model.kv.vocab['kangaroo.n.01'].index]), 3) | ||
self.assertEqual(len(model.kv.vocab), 7) | ||
self.assertTrue('mammal.n.01' not in model.node_relations) | ||
|
||
def test_persistence(self): | ||
"""Tests whether the model is saved and loaded correctly.""" | ||
model = PoincareModel(self.data, burn_in=0, negative=3) | ||
model.train(epochs=1) | ||
model.save(testfile()) | ||
loaded = PoincareModel.load(testfile()) | ||
self.models_equal(model, loaded) | ||
|
||
def test_persistence_separate_file(self): | ||
"""Tests whether the model is saved and loaded correctly when the arrays are stored separately.""" | ||
model = PoincareModel(self.data, burn_in=0, negative=3) | ||
model.train(epochs=1) | ||
model.save(testfile(), sep_limit=1) | ||
loaded = PoincareModel.load(testfile()) | ||
self.models_equal(model, loaded) | ||
|
||
def test_invalid_data_raises_error(self): | ||
"""Tests that error is raised on invalid input data.""" | ||
with self.assertRaises(ValueError): | ||
PoincareModel([("a", "b", "c")]) | ||
with self.assertRaises(ValueError): | ||
PoincareModel(["a", "b", "c"]) | ||
with self.assertRaises(ValueError): | ||
PoincareModel("ab") | ||
|
||
def test_vector_shape(self): | ||
"""Tests whether vectors are initialized with the correct size.""" | ||
model = PoincareModel(self.data, size=20) | ||
self.assertEqual(model.kv.syn0.shape, (7, 20)) | ||
|
||
def test_training(self): | ||
"""Tests that vectors are different before and after training.""" | ||
model = PoincareModel(self.data_large, burn_in=0, negative=3) | ||
old_vectors = np.copy(model.kv.syn0) | ||
model.train(epochs=2) | ||
self.assertFalse(np.allclose(old_vectors, model.kv.syn0)) | ||
|
||
def test_training_multiple(self): | ||
"""Tests that calling train multiple times results in different vectors.""" | ||
model = PoincareModel(self.data_large, burn_in=0, negative=3) | ||
model.train(epochs=2) | ||
old_vectors = np.copy(model.kv.syn0) | ||
|
||
model.train(epochs=1) | ||
self.assertFalse(np.allclose(old_vectors, model.kv.syn0)) | ||
|
||
old_vectors = np.copy(model.kv.syn0) | ||
model.train(epochs=0) | ||
self.assertTrue(np.allclose(old_vectors, model.kv.syn0)) | ||
|
||
def test_gradients_check(self): | ||
"""Tests that the gradients check succeeds during training.""" | ||
model = PoincareModel(self.data, negative=3) | ||
old_vectors = np.copy(model.kv.syn0) | ||
model.train(epochs=1, batch_size=1, check_gradients_every=1) | ||
self.assertFalse(np.allclose(old_vectors, model.kv.syn0)) | ||
|
||
def test_wrong_gradients_raises_assertion(self): | ||
"""Tests that discrepancy in gradients raises an error.""" | ||
model = PoincareModel(self.data, negative=3) | ||
model._loss_grad = Mock(return_value=np.zeros((2 + model.negative, model.size))) | ||
with self.assertRaises(AssertionError): | ||
model.train(epochs=1, batch_size=1, check_gradients_every=1) | ||
|
||
def test_reproducible(self): | ||
"""Tests that vectors are same for two independent models trained with the same seed.""" | ||
model_1 = PoincareModel(self.data_large, seed=1, negative=3, burn_in=1) | ||
model_1.train(epochs=2) | ||
|
||
model_2 = PoincareModel(self.data_large, seed=1, negative=3, burn_in=1) | ||
model_2.train(epochs=2) | ||
self.assertTrue(np.allclose(model_1.kv.syn0, model_2.kv.syn0)) | ||
|
||
def test_burn_in(self): | ||
"""Tests that vectors are different after burn-in.""" | ||
model = PoincareModel(self.data, burn_in=1, negative=3) | ||
original_vectors = np.copy(model.kv.syn0) | ||
model.train(epochs=0) | ||
self.assertFalse(np.allclose(model.kv.syn0, original_vectors)) | ||
|
||
def test_burn_in_only_done_once(self): | ||
"""Tests that burn-in does not happen when train is called a second time.""" | ||
model = PoincareModel(self.data, negative=3, burn_in=1) | ||
model.train(epochs=0) | ||
original_vectors = np.copy(model.kv.syn0) | ||
model.train(epochs=0) | ||
self.assertTrue(np.allclose(model.kv.syn0, original_vectors)) | ||
|
||
def test_negatives(self): | ||
"""Tests that correct number of negatives are sampled.""" | ||
model = PoincareModel(self.data, negative=5) | ||
self.assertEqual(len(model._get_candidate_negatives()), 5) | ||
|
||
def test_error_if_negative_more_than_population(self): | ||
"""Tests error is rased if number of negatives to sample is more than remaining nodes.""" | ||
model = PoincareModel(self.data, negative=5) | ||
with self.assertRaises(ValueError): | ||
model.train(epochs=1) | ||
|
||
def test_no_duplicates_and_positives_in_negative_sample(self): | ||
"""Tests that no duplicates or positively related nodes are present in negative samples.""" | ||
model = PoincareModel(self.data_large, negative=3) | ||
positive_nodes = model.node_relations[0] # Positive nodes for node 0 | ||
num_samples = 100 # Repeat experiment multiple times | ||
for i in range(num_samples): | ||
negatives = model._sample_negatives(0) | ||
self.assertFalse(positive_nodes & set(negatives)) | ||
self.assertEqual(len(negatives), len(set(negatives))) | ||
|
||
def test_handle_duplicates(self): | ||
"""Tests that correct number of negatives are used.""" | ||
vector_updates = np.array([[0.5, 0.5], [0.1, 0.2], [0.3, -0.2]]) | ||
node_indices = [0, 1, 0] | ||
PoincareModel._handle_duplicates(vector_updates, node_indices) | ||
vector_updates_expected = np.array([[0.0, 0.0], [0.1, 0.2], [0.8, 0.3]]) | ||
self.assertTrue((vector_updates == vector_updates_expected).all()) | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
try: | ||
os.unlink(testfile()) | ||
except OSError: | ||
pass | ||
|
||
|
||
if __name__ == '__main__': | ||
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add tests for save/load.