Skip to content

Commit

Permalink
Adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TLouf committed Apr 15, 2022
1 parent b824cd5 commit 0a39416
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
8 changes: 4 additions & 4 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
import gensim.models.fasttext

try:
from pyemd import emd # noqa:F401
PYEMD_EXT = True
from ot import emd2 # noqa:F401
POT_EXT = True
except (ImportError, ValueError):
PYEMD_EXT = False
POT_EXT = False

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -394,7 +394,7 @@ def test_contains(self):
self.assertFalse('nights' in self.test_model.wv.key_to_index)
self.assertTrue('nights' in self.test_model.wv)

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_wm_distance(self):
doc = ['night', 'payment']
oov_doc = ['nights', 'forests', 'payments']
Expand Down
48 changes: 24 additions & 24 deletions gensim/test/test_similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
from gensim.similarities.fastss import editdist

try:
from pyemd import emd # noqa:F401
PYEMD_EXT = True
from ot import emd2 # noqa:F401
POT_EXT = True
except (ImportError, ValueError):
PYEMD_EXT = False
POT_EXT = False

SENTENCES = [doc2vec.TaggedDocument(words, [i]) for i, words in enumerate(TEXTS)]

Expand Down Expand Up @@ -88,8 +88,8 @@ def test_full(self, num_best=None, shardsize=100):
index.destroy()

def test_num_best(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

for num_best in [None, 0, 1, 9, 1000]:
self.testFull(num_best=num_best)
Expand Down Expand Up @@ -119,8 +119,8 @@ def test_scipy2scipy_clipped(self):

def test_empty_query(self):
index = self.factoryMethod()
if isinstance(index, similarities.WmdSimilarity) and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if isinstance(index, similarities.WmdSimilarity) and not POT_EXT:
self.skipTest("POT not installed")

query = []
try:
Expand Down Expand Up @@ -177,8 +177,8 @@ def test_iter(self):
index.destroy()

def test_persistency(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl')
index = self.factoryMethod()
Expand All @@ -197,8 +197,8 @@ def test_persistency(self):
self.assertEqual(index.num_best, index2.num_best)

def test_persistency_compressed(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
index = self.factoryMethod()
Expand All @@ -217,8 +217,8 @@ def test_persistency_compressed(self):
self.assertEqual(index.num_best, index2.num_best)

def test_large(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl')
index = self.factoryMethod()
Expand All @@ -239,8 +239,8 @@ def test_large(self):
self.assertEqual(index.num_best, index2.num_best)

def test_large_compressed(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
index = self.factoryMethod()
Expand All @@ -261,8 +261,8 @@ def test_large_compressed(self):
self.assertEqual(index.num_best, index2.num_best)

def test_mmap(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl')
index = self.factoryMethod()
Expand All @@ -284,8 +284,8 @@ def test_mmap(self):
self.assertEqual(index.num_best, index2.num_best)

def test_mmap_compressed(self):
if self.cls == similarities.WmdSimilarity and not PYEMD_EXT:
self.skipTest("pyemd not installed")
if self.cls == similarities.WmdSimilarity and not POT_EXT:
self.skipTest("POT not installed")

fname = get_tmpfile('gensim_similarities.tst.pkl.gz')
index = self.factoryMethod()
Expand All @@ -310,7 +310,7 @@ def factoryMethod(self):
# Override factoryMethod.
return self.cls(TEXTS, self.w2v_model)

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_full(self, num_best=None):
# Override testFull.

Expand All @@ -329,7 +329,7 @@ def test_full(self, num_best=None):
self.assertTrue(numpy.alltrue(sims[1:] > 0.0))
self.assertTrue(numpy.alltrue(sims[1:] < 1.0))

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_non_increasing(self):
''' Check that similarities are non-increasing when `num_best` is not
`None`.'''
Expand All @@ -345,7 +345,7 @@ def test_non_increasing(self):
cond = sum(numpy.diff(sims2) < 0) == len(sims2) - 1
self.assertTrue(cond)

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_chunking(self):
# Override testChunking.

Expand All @@ -364,7 +364,7 @@ def test_chunking(self):
self.assertTrue(numpy.alltrue(sim > 0.0))
self.assertTrue(numpy.alltrue(sim <= 1.0))

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_iter(self):
# Override testIter.

Expand All @@ -373,7 +373,7 @@ def test_iter(self):
self.assertTrue(numpy.alltrue(sims >= 0.0))
self.assertTrue(numpy.alltrue(sims <= 1.0))

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_str(self):
index = self.cls(TEXTS, self.w2v_model)
self.assertTrue(str(index))
Expand Down
12 changes: 6 additions & 6 deletions gensim/test/test_word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from testfixtures import log_capture

try:
from pyemd import emd # noqa:F401
PYEMD_EXT = True
from ot import emd2 # noqa:F401
POT_EXT = True
except (ImportError, ValueError):
PYEMD_EXT = False
POT_EXT = False

from gensim import utils
from gensim.models import word2vec, keyedvectors
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def test_negative_ns_exp(self):

class TestWMD(unittest.TestCase):

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_nonzero(self):
'''Test basic functionality with a test sentence.'''

Expand All @@ -1094,7 +1094,7 @@ def test_nonzero(self):
# Check that distance is non-zero.
self.assertFalse(distance == 0.0)

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_symmetry(self):
'''Check that distance is symmetric.'''

Expand All @@ -1105,7 +1105,7 @@ def test_symmetry(self):
distance2 = model.wv.wmdistance(sentence2, sentence1)
self.assertTrue(np.allclose(distance1, distance2))

@unittest.skipIf(PYEMD_EXT is False, "pyemd not installed")
@unittest.skipIf(POT_EXT is False, "POT not installed")
def test_identical_sentences(self):
'''Check that the distance from a sentence to itself is zero.'''

Expand Down

0 comments on commit 0a39416

Please sign in to comment.