Skip to content

Commit

Permalink
Change vector training to work with latest gensim (fix #3749) (#3757)
Browse files Browse the repository at this point in the history
  • Loading branch information
polm authored and honnibal committed Jun 16, 2019
1 parent d8573ee commit 3f52e12
Showing 1 changed file with 13 additions and 38 deletions.
51 changes: 13 additions & 38 deletions bin/train_word_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,21 @@


class Corpus(object):
    """Streaming corpus for gensim's Word2Vec.

    Iterates over every file in `directory` (via the module-level
    `iter_dir` helper), splits each file into paragraphs, and yields each
    paragraph as a list of token strings produced by the supplied spaCy
    pipeline.  Being re-iterable (a class with `__iter__`, not a one-shot
    generator) matters: gensim consumes the corpus multiple times — once
    to build the vocabulary and once per training epoch.

    Args:
        directory: Path-like root containing the plain-text input files.
        nlp: A spaCy Language object used for tokenization (a blank
            pipeline is sufficient; only tokenization is required).
    """

    def __init__(self, directory, nlp):
        self.directory = directory
        self.nlp = nlp

    def __iter__(self):
        for text_loc in iter_dir(self.directory):
            with text_loc.open("r", encoding="utf-8") as file_:
                text = file_.read()

            # This is to keep the input to the blank model (which doesn't
            # sentencize) from being too long. It works particularly well with
            # the output of [WikiExtractor](https://github.com/attardi/wikiextractor)
            paragraphs = text.split('\n\n')
            for par in paragraphs:
                yield [word.orth_ for word in self.nlp(par)]


def iter_dir(loc):
Expand Down Expand Up @@ -62,46 +60,23 @@ def main(
window=5,
size=128,
min_count=10,
nr_iter=2,
nr_iter=5,
):
logging.basicConfig(
format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
)
nlp = spacy.blank(lang)
corpus = Corpus(in_dir, nlp)
model = Word2Vec(
sentences=corpus,
size=size,
window=window,
min_count=min_count,
workers=n_workers,
sample=1e-5,
negative=negative,
)
nlp = spacy.blank(lang)
corpus = Corpus(in_dir)
total_words = 0
total_sents = 0
for text_no, text_loc in enumerate(iter_dir(corpus.directory)):
with text_loc.open("r", encoding="utf-8") as file_:
text = file_.read()
total_sents += text.count("\n")
doc = nlp(text)
total_words += corpus.count_doc(doc)
logger.info(
"PROGRESS: at batch #%i, processed %i words, keeping %i word types",
text_no,
total_words,
len(corpus.strings),
)
model.corpus_count = total_sents
model.raw_vocab = defaultdict(int)
for orth, freq in corpus.counts:
if freq >= min_count:
model.raw_vocab[nlp.vocab.strings[orth]] = freq
model.scale_vocab()
model.finalize_vocab()
model.iter = nr_iter
model.train(corpus)
model.save(out_loc)


if __name__ == "__main__":
plac.call(main)

0 comments on commit 3f52e12

Please sign in to comment.