-
Notifications
You must be signed in to change notification settings - Fork 10
/
build_vocab.py
90 lines (72 loc) · 2.29 KB
/
build_vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import re
import pickle
from collections import Counter
import nltk
class Vocab:
    '''Two-way mapping between word strings and consecutive integer ids.'''

    def __init__(self):
        # w2i: word -> id, i2w: id -> word, ix: the next id to hand out.
        self.w2i = {}
        self.i2w = {}
        self.ix = 0

    def add_word(self, word):
        '''Register `word` under the next free id; known words are ignored.'''
        if word in self.w2i:
            return
        index = self.ix
        self.w2i[word] = index
        self.i2w[index] = word
        self.ix = index + 1

    def __call__(self, word):
        '''Return the id of `word`, or the id of '<unk>' for unknown words.

        Raises KeyError when `word` is unknown and '<unk>' was never added.
        '''
        key = word if word in self.w2i else '<unk>'
        return self.w2i[key]

    def __len__(self):
        '''Return the number of registered words.'''
        return len(self.w2i)
def build_vocab(mode_list=('factual', 'humorous'), min_count=2):
    '''Build a Vocab from the training captions of the given modes.

    The four special tokens '<pad>', '<s>', '</s>' and '<unk>' are added
    first, in that order, followed by the words of each mode in the order
    the modes are listed.

    Args:
        mode_list: iterable of mode names; each one is passed to
            extract_captions(mode=...). The default is an immutable tuple
            so no mutable object is shared between calls.
        min_count: minimum number of occurrences a word needs in the
            'factual' captions to be kept. Words from every other mode
            are added regardless of frequency.

    Returns:
        The populated Vocab instance.
    '''
    vocab = Vocab()

    # Special tokens are registered before any caption word.
    for token in ('<pad>', '<s>', '</s>', '<unk>'):
        vocab.add_word(token)

    for mode in mode_list:
        words = nltk.tokenize.word_tokenize(extract_captions(mode=mode))
        if mode == 'factual':
            # Only the factual captions are filtered by frequency.
            counts = Counter(words)
            words = [word for word, cnt in counts.items() if cnt >= min_count]
        for word in words:
            vocab.add_word(word)

    return vocab
def extract_captions(mode='factual'):
    '''Read one training caption file and return it as a single string.

    Every line has its '.' characters removed and surrounding whitespace
    stripped; the lines are then joined with single spaces and the result
    is lower-cased.

    Args:
        mode: 'factual' reads data/factual_train.txt and additionally
            removes the '<digits>.jpg#<digits>' image tag from each line;
            'humorous' reads data/humor/funny_train.txt; any other value
            falls back to data/romantic/romantic_train.txt.

    Returns:
        The cleaned, lower-cased caption text ('' for an empty file).

    Raises:
        OSError: if the caption file cannot be opened.
    '''
    image_tag = None
    if mode == 'factual':
        path = "data/factual_train.txt"
        # The dot is escaped so that only a literal ".jpg#" is matched.
        # NOTE(review): only the digit run directly before ".jpg" is
        # removed, so a filename stem ending in non-digits would stay in
        # the text -- confirm against the data format.
        image_tag = re.compile(r'\d*\.jpg#\d*')
    elif mode == 'humorous':
        path = "data/humor/funny_train.txt"
    else:
        path = "data/romantic/romantic_train.txt"

    cleaned = []
    with open(path, 'r') as f:
        for line in f:
            if image_tag is not None:
                line = image_tag.sub('', line)
            cleaned.append(line.replace('.', '').strip())

    # Joining once avoids rebuilding the growing string on every line.
    return ' '.join(cleaned).strip().lower()
if __name__ == '__main__':
    # Script entry point: build the vocabulary, report its size, save it.
    vocab = build_vocab(mode_list=['factual', 'humorous'])
    print(len(vocab))
    # NOTE(review): pickle stores a class by its module path, and when this
    # file is run as a script that module is `__main__`; code that loads
    # data/vocab.pkl presumably needs `Vocab` resolvable there -- confirm
    # against the loader.
    with open('data/vocab.pkl', 'wb') as f:
        pickle.dump(vocab, f)