import collections

import torch
from torch.utils.data import Dataset

import tokenization

# VOCAB = 'data/vocab.txt'
# TOKENIZER = tokenization.FullTokenizer(vocab_file=VOCAB, do_lower_case=True)

def load_dataset(path, DIM=300, lower=True):
    """Load a word-embedding file (one word followed by DIM floats per line)."""
    origin_words, origin_repre = list(), list()
    all_embs = dict()
    cnt = 0
    for line in open(path, encoding='utf8'):
        cnt += 1
        row = line.strip().split(' ')
        if len(row) != DIM + 1:
            continue
        word = row[0]
        if lower:
            word = word.lower()
        if filter(word):
            continue
        emb = [float(e) for e in row[1:]]
        origin_repre.append(emb)
        origin_words.append(word)
        all_embs[word] = emb

    # add an all-zero embedding for the <unk> token
    emb = [0.0 for _ in range(DIM)]
    origin_repre.append(emb)
    origin_words.append('<unk>')
    all_embs['<unk>'] = emb

    print('loaded! Word num = {a}'.format(a=len(origin_words)))
    return {'origin_word': origin_words, 'origin_repre': origin_repre}, all_embs

def load_predict_dataset(path):
    """Load a plain word list (one word per line) for prediction."""
    origin_words, origin_repre = list(), list()
    for line in open(path, encoding='utf8'):
        word = line.strip()
        origin_repre.append(word)
        origin_words.append(word)
    print('loaded! Word num = {a}'.format(a=len(origin_words)))
    return {'origin_word': origin_words, 'origin_repre': origin_repre}

class TextData(Dataset):
    """Dataset wrapping (word, representation) pairs."""

    def __init__(self, data):
        self.origin_word = data['origin_word']
        self.origin_repre = data['origin_repre']
        # self.repre_ids = data['repre_ids']

    def __len__(self):
        return len(self.origin_word)

    def __getitem__(self, idx):
        return self.origin_word[idx], self.origin_repre[idx]

def collate_fn(batch_data, TOKENIZER, pad=0):
    """Collate a training batch: build the (currently identity) augmented words,
    convert the original embeddings to a tensor, and pad the token-id sequences."""
    batch_words, batch_origin_repre = list(zip(*batch_data))
    aug_words, aug_repre, aug_ids = list(), list(), list()
    for index in range(len(batch_words)):
        # aug_word = get_random_attack(batch_words[index])
        aug_word = batch_words[index]
        repre, repre_ids = repre_word(aug_word, TOKENIZER, id_mapping=None)
        aug_words.append(aug_word)
        aug_repre.append(repre)
        aug_ids.append(repre_ids)

    batch_words = list(batch_words) + aug_words
    batch_origin_repre = torch.FloatTensor(batch_origin_repre)

    x_lens = [len(x) for x in aug_ids]
    max_len = max(len(seq) for seq in aug_ids)
    batch_aug_repre_ids = [ids + [pad] * (max_len - len(ids)) for ids in aug_ids]
    batch_aug_repre_ids = torch.LongTensor(batch_aug_repre_ids)

    return batch_words, batch_origin_repre, batch_aug_repre_ids, x_lens

def collate_fn_predict(batch_data, TOKENIZER, rtype='mixed', pad=0):
    """Collate a prediction batch: tokenize each word, pad the id sequences,
    and build a padding mask."""
    batch_words, batch_origin_repre = list(zip(*batch_data))
    batch_repre_ids = list()
    for word in batch_words:
        repre, repre_id = repre_word(word, TOKENIZER, id_mapping=None, rtype=rtype)
        batch_repre_ids.append(repre_id)

    max_len = max(len(seq) for seq in batch_repre_ids)
    batch_repre_ids = [ids + [pad] * (max_len - len(ids)) for ids in batch_repre_ids]
    batch_repre_ids = torch.LongTensor(batch_repre_ids)
    mask = torch.ne(batch_repre_ids, pad).unsqueeze(2)
    return batch_words, batch_origin_repre, batch_repre_ids, mask

def filter(word):
    """Return True if the word should be skipped (note: shadows the builtin filter)."""
    min_len = 1
    if len(word) < min_len:
        return True
    return False

def tokenize_and_getid(word, tokenizer):
    """Tokenize a word into subwords and map them to vocabulary ids."""
    tokens = tokenizer.tokenize(tokenizer.convert_to_unicode(word))
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    return tokens, token_ids

def hash_sub_word(total, bucket):
    """Map each subword id to a hashed bucket id; id 0 and ids 100-104 keep reserved slots."""
    bucket -= 1
    id_mapping = collections.OrderedDict()
    for id in range(total):
        hashing = ((id % bucket) ^ 2) + 1  # note: ^ is bitwise XOR, not exponentiation
        # print(id, hashing)
        id_mapping[id] = hashing
    id_mapping[0] = 0
    id_mapping[100] = bucket + 2
    id_mapping[101] = bucket + 3
    id_mapping[102] = bucket + 4
    id_mapping[103] = bucket + 5
    id_mapping[104] = bucket + 6
    return id_mapping

def repre_word(word, tokenizer, id_mapping=None, rtype='mixed'):
    """Build the input representation of a word: characters, subwords, or both ('mixed')."""
    start = '[CLS]'
    sub = '[SUB]'
    end = '[SEP]'
    char_seq = list(word)
    tokens, _ = tokenize_and_getid(word, tokenizer)

    if rtype == 'mixed':
        repre = [start] + char_seq + [sub] + tokens + [end]
    elif rtype == 'char':
        repre = [start] + char_seq + [end]
    else:
        repre = [start] + tokens + [end]

    repre_ids = tokenizer.convert_tokens_to_ids(repre)
    if id_mapping:
        repre_ids = [id_mapping[r_id] for r_id in repre_ids]
    return repre, repre_ids

def load_neg_samples(path):
    """Load negative samples: each line is a word followed by its tab-separated negatives."""
    neg_samples = dict()
    for line in open(path, encoding='utf8'):
        row = line.strip().split('\t')
        neg_samples[row[0]] = row[1:]
    return neg_samples

if __name__ == '__main__':
    pass
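
# Minimal usage sketch (not part of the original module), showing how the pieces
# above are typically wired together with a DataLoader. The vocabulary path
# 'data/vocab.txt' follows the commented-out constants at the top of this file;
# 'data/words.txt' is a placeholder word list, one word per line.
#
#     from functools import partial
#     from torch.utils.data import DataLoader
#
#     tokenizer = tokenization.FullTokenizer(vocab_file='data/vocab.txt', do_lower_case=True)
#     data = load_predict_dataset('data/words.txt')
#     loader = DataLoader(TextData(data), batch_size=32,
#                         collate_fn=partial(collate_fn_predict, TOKENIZER=tokenizer, rtype='mixed'))
#     for words, origin_repre, repre_ids, mask in loader:
#         print(words[0], repre_ids.shape, mask.shape)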