Skip to content

Commit

Permalink
Adds option to specify dtype for PoincareModel and corresponding unit…
Browse files Browse the repository at this point in the history
…test
  • Loading branch information
jayantj committed Nov 15, 2017
1 parent a928ca1 commit dfc19cb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
7 changes: 5 additions & 2 deletions gensim/models/poincare.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class PoincareModel(utils.SaveLoad):
"""
def __init__(self, train_data, size=50, alpha=0.1, negative=10, workers=1, epsilon=1e-5,
burn_in=10, burn_in_alpha=0.01, init_range=(-0.001, 0.001), seed=0):
burn_in=10, burn_in_alpha=0.01, init_range=(-0.001, 0.001), dtype=np.float64, seed=0):
"""Initialize and train a Poincare embedding model from an iterable of relations.
Parameters
Expand All @@ -95,6 +95,8 @@ def __init__(self, train_data, size=50, alpha=0.1, negative=10, workers=1, epsil
Learning rate for burn-in initialization, ignored if `burn_in` is 0.
init_range : 2-tuple (float, float)
Range within which the vectors are randomly initialized.
dtype : numpy.dtype
The numpy dtype to use for the vectors in the model (numpy.float64, numpy.float32 etc).
seed : int, optional
Seed for random to ensure reproducibility.
Expand Down Expand Up @@ -127,6 +129,7 @@ def __init__(self, train_data, size=50, alpha=0.1, negative=10, workers=1, epsil
self.epsilon = epsilon
self.burn_in = burn_in
self._burn_in_done = False
self.dtype = dtype
self.seed = seed
self._np_random = np_random.RandomState(seed)
self.init_range = init_range
Expand Down Expand Up @@ -172,7 +175,7 @@ def _load_relations(self):
def _init_embeddings(self):
"""Randomly initialize vectors for the items in the vocab."""
shape = (len(self.kv.index2word), self.size)
self.kv.syn0 = self._np_random.uniform(self.init_range[0], self.init_range[1], shape)
self.kv.syn0 = self._np_random.uniform(self.init_range[0], self.init_range[1], shape).astype(self.dtype)

def _get_candidate_negatives(self):
"""Returns candidate negatives of size `self.negative` from the negative examples buffer.
Expand Down
5 changes: 5 additions & 0 deletions gensim/test/test_poincare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def test_vector_shape(self):
model = PoincareModel(self.data, size=20)
self.assertEqual(model.kv.syn0.shape, (7, 20))

def test_vector_dtype(self):
"""Tests whether vectors are initialized with the correct dtype."""
model = PoincareModel(self.data, dtype=np.float32)
self.assertEqual(model.kv.syn0.dtype, np.float32)

def test_training(self):
"""Tests that vectors are different before and after training."""
model = PoincareModel(self.data_large, burn_in=0, negative=3)
Expand Down

0 comments on commit dfc19cb

Please sign in to comment.