-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·118 lines (95 loc) · 3.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
import argparse
import os
from collections import Counter
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd
from scipy.sparse.linalg import svds
from data import load_docs
from plot import emb_scatter, heatmap
def ppmi(pxy, py, px=None):
    """Compute positive pointwise mutual information (PPMI).

    Args:
        pxy: (xdim, ydim) array of probabilities p(y|x) (or joint p(x, y)).
        py: (1, ydim) array of marginal probabilities p(y).
        px: optional (xdim, 1) array of marginals p(x). When None, pxy is
            treated as already conditioned on x, so PMI reduces to
            log p(y|x) - log p(y).

    Returns:
        (xdim, ydim) array with all negative PMI values — including the
        -inf produced by zero-probability cells — clipped to 0.
    """
    xdim, ydim = pxy.shape
    assert py.shape == (1, ydim)
    # Silence log(0) warnings: zero cells become -inf, which the final
    # clipping maps to 0 — the conventional PPMI value for unseen pairs.
    with np.errstate(divide='ignore', invalid='ignore'):
        if px is None:
            pmi = np.log(pxy) - np.log(py)
        else:
            assert px.shape == (xdim, 1)
            pmi = np.log(pxy) - np.log(px) - np.log(py)
    # np.where also maps NaN (0/0 cells) to 0, which the original in-place
    # mask `pmi[pmi < 0] = 0` silently let through (NaN < 0 is False).
    return np.where(pmi > 0, pmi, 0.0)
def make_matrices(unigrams, bigrams, w2i):
    """Assemble the dense probability matrices consumed by the PPMI step.

    Args:
        unigrams: mapping word -> marginal probability p(y).
        bigrams: mapping title -> {word -> per-document probability p(y|x)}.
        w2i: mapping word -> column index; must cover the unigram vocab.

    Returns:
        A pair (pxy, py): pxy is (num_titles, vocab_size) with one row per
        document, py is the (1, vocab_size) unigram row vector.
    """
    assert len(unigrams) == len(w2i)
    vocab_size = len(w2i)
    py = np.zeros((1, vocab_size))
    for word, col in w2i.items():
        py[0, col] = unigrams[word]
    pxy = np.zeros((len(bigrams), vocab_size))
    # Row order follows dict insertion order of `bigrams`, which matches
    # the caller's `titles` ordering.
    for row, title in enumerate(tqdm(bigrams)):
        doc_probs = bigrams[title]
        for word, prob in doc_probs.items():
            pxy[row, w2i[word]] = prob
    return pxy, py
def write_vectors(vectors, titles, path, gensim=True):
    """Write row vectors to a word2vec-style whitespace-separated text file.

    Args:
        vectors: (num_titles, dim) array; row i is the embedding of titles[i].
        titles: sequence of document titles; internal whitespace in a title
            is replaced with underscores so each title stays a single token.
        path: output file path.
        gensim: when True, prepend the "<count> <dim>" header line that
            gensim's KeyedVectors text loader expects.
    """
    assert vectors.shape[0] == len(titles)
    # Explicit utf-8: the platform-default encoding may fail on
    # non-ASCII titles (the original relied on the locale default).
    with open(path, 'w', encoding='utf-8') as f:
        if gensim:
            print(len(titles), vectors.shape[1], file=f)
        for title, vector in zip(titles, vectors):
            token = '_'.join(title.split())  # must be one word
            print(token, *(str(val) for val in vector), file=f)
def main(args):
    """End-to-end pipeline: load docs, build a (P)PMI matrix, factorize, save.

    Args:
        args: parsed CLI namespace; reads attributes data, lower, num_words,
            ppmi, dim, outpath, no_tsne, and perplexity (see the argparse
            setup in the __main__ block).
    """
    print(f'Loading data from `{args.data}`...')
    docs = load_docs(args.data)
    # Fix an error in the conversion: drop entries whose text is missing or
    # empty. TODO: repair the conversion itself upstream.
    docs = {title: text for title, text in docs.items()
            if isinstance(text, str) and len(text) > 0}
    if args.lower:
        docs = {title: text.lower() for title, text in docs.items()}
    titles = tuple(docs.keys())
    # Single join instead of repeated `+=` (quadratic string concatenation).
    all_text = ' '.join(docs[title] for title in titles)
    word_freqs = Counter(all_text.split())
    if args.num_words is None:
        args.num_words = len(word_freqs)
    counts = word_freqs.most_common(args.num_words)
    # A set, not a list: membership is tested once per token below, and a
    # list scan made that loop accidentally quadratic in vocabulary size.
    vocab = {word for word, _ in counts}
    print(f'Num docs: {len(titles):,}')
    print(f'Num words: {len(vocab):,}')
    print('Collecting unigrams...')
    corpus_total = sum(count for _, count in counts)
    unigrams = {word: count / corpus_total for word, count in counts}
    w2i = {word: i for i, word in enumerate(unigrams.keys())}
    print('Collecting bigrams...')
    bigrams = dict()
    for title in tqdm(titles):
        word_counts = Counter(word for word in docs[title].split()
                              if word in vocab)
        # NOTE(review): a document containing no in-vocab words would make
        # doc_total zero and raise ZeroDivisionError — same as the original;
        # confirm whether such documents can occur after the filter above.
        doc_total = sum(word_counts.values())
        bigrams[title] = {word: count / doc_total
                          for word, count in word_counts.items()}
    print('Making PPMI matrix...')
    pxy, py = make_matrices(unigrams, bigrams, w2i)
    mat = ppmi(pxy, py) if args.ppmi else pxy
    # Truncated sparse SVD: rows of U are the document embeddings.
    U, s, V = svds(mat, k=args.dim)
    print('Saving results...')
    write_vectors(U, titles, args.outpath)
    emb_scatter(U, titles, model_name='wikitext-2', tsne=args.no_tsne, perplexity=args.perplexity)
    heatmap(U, 'plots/U.pdf')
    heatmap(mat, 'plots/mat.pdf')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # JSON file mapping document titles to raw text (consumed by load_docs).
    parser.add_argument('--data', default='data/wikitext-2-raw.docs.json')
    # Where write_vectors saves the document embeddings.
    parser.add_argument('--outpath', default='vec/doc.vec.txt')
    # Lowercase all document text before counting.
    parser.add_argument('--lower', action='store_true')
    # Apply the PPMI transform to the probability matrix before the SVD.
    parser.add_argument('--ppmi', action='store_true')
    # store_false: args.no_tsne defaults to True (t-SNE on); passing
    # --no-tsne sets it False. main forwards it as tsne=args.no_tsne.
    parser.add_argument('--no-tsne', action='store_false')
    # NOTE(review): parsed but never referenced in main() — confirm whether
    # this is dead or consumed elsewhere (e.g. by data loading).
    parser.add_argument('--num-docs', type=int, default=1000)
    # Vocabulary cap; None means keep every word seen in the corpus.
    parser.add_argument('--num-words', type=int, default=None)
    # Number of singular vectors kept (embedding dimensionality).
    parser.add_argument('--dim', type=int, default=100)
    # t-SNE perplexity forwarded to emb_scatter.
    parser.add_argument('--perplexity', type=int, default=30)
    args = parser.parse_args()
    main(args)