Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding type check for corpus_file argument #2469

Merged
merged 16 commits into from
May 5, 2019
Merged
19 changes: 19 additions & 0 deletions gensim/models/fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,12 @@
"""

import logging
import os

import numpy as np
from numpy import ones, vstack, float32 as REAL, sum as np_sum
import six
from collections import Iterable

import gensim.models._fasttext_bin

Expand Down Expand Up @@ -901,6 +903,23 @@ def train(self, sentences=None, corpus_file=None, total_examples=None, total_wor
>>> model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)

"""

# Check if both sentences and corpus_file are None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove these comments. They add no value to the source code.

if corpus_file is None and sentences is None:
raise TypeError("Either one of corpus_file or sentences value must be provided")

# Check if both sentences and corpus_file are not None
if corpus_file is not None and sentences is not None:
raise TypeError("Both corpus_file and sentences must not be provided at the same time")

# Check if corpus_file is string type
if sentences is None and not os.path.isfile(corpus_file):
raise TypeError("Parameter corpus_file must be a valid path to a file, got %r instead" % corpus_file)

# Check if sentences is iterable
if sentences is not None and not isinstance(sentences, Iterable):
raise TypeError("sentences must be an iterable of list, got %r instead" % sentences)

super(FastText, self).train(
sentences=sentences, corpus_file=corpus_file, total_examples=total_examples, total_words=total_words,
epochs=epochs, start_alpha=start_alpha, end_alpha=end_alpha, word_count=word_count,
Expand Down
16 changes: 16 additions & 0 deletions gensim/test/test_fasttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ def test_training(self):
oov_vec = model.wv['minor'] # oov word
self.assertEqual(len(oov_vec), 10)

def testFastTextTrainParameters(self):

model = FT_gensim(size=10, min_count=1, hs=1, negative=0, seed=42, workers=1)
model.build_vocab(sentences=sentences)

# check if corpus_file is not a string
self.assertRaises(TypeError, model.train, corpus_file=11111)
# check if sentences is an iterable
self.assertRaises(TypeError, model.train, sentences=11111)
# check is both the parameters are provided
self.assertRaises(TypeError, model.train, sentences=sentences, corpus_file='test')
# check if both the parameters are left empty
self.assertRaises(TypeError, model.train, sentences=None, corpus_file=None)
# check if corpus_file is an iterable
self.assertRaises(TypeError, model.train, corpus_file=sentences)

@unittest.skipIf(os.name == 'nt' and six.PY2, "corpus_file training is not supported on Windows + Py27")
def test_training_fromfile(self):
with temporary_file(get_tmpfile('gensim_fasttext.tst')) as corpus_file:
Expand Down