-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparaphrase_gen_util.py
48 lines (37 loc) · 1.46 KB
/
paraphrase_gen_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from collections import Counter
import torch
import re
# Torch device string: 'cuda' when torch reports CUDA as available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else "cpu"
def build_bigrams(input_ids):
    """Return the consecutive pairs (bigrams) of a token-id sequence.

    Args:
        input_ids: Token ids, either an object exposing ``.tolist()``
            (e.g. a 1-D ``torch.Tensor``) or any plain iterable of ids.

    Returns:
        A list of ``(ids[i], ids[i + 1])`` tuples in order of appearance.
        The list is empty when fewer than two ids are given.
    """
    # Convert once up front: slicing a tensor and calling .tolist() for every
    # position costs one tensor op (and a device sync on CUDA) per bigram.
    ids = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids)
    return list(zip(ids, ids[1:]))
# One numbered list item: a line starting with digits, one marker character
# out of ``. ) ] * · :`` and a space.  The captured group is the item text plus
# any following lines that do not themselves start with digits + marker.
# Raw string: the non-raw form relied on the invalid escapes ``\]`` and ``\*``.
_LIST_ITEM_RE = re.compile(
    r"^[0-9]+[.)\]\*·:] (.*(?:\n(?![0-9]+[.)\]\*·:]).*)*)",
    re.MULTILINE,
)


def extract_list(text):
    """Extract the item texts of a numbered list embedded in ``text``.

    An item starts on a line of the form ``<digits><marker><space>`` where the
    marker is one of ``. ) ] * · :``.  Lines that follow and do not begin with
    ``<digits><marker>`` are treated as a continuation of the same item.

    Args:
        text: The string to scan.

    Returns:
        A list with one string per item, without the leading number and
        marker; continuation lines are kept, joined by ``"\\n"``.  Empty when
        ``text`` contains no numbered item.
    """
    return _LIST_ITEM_RE.findall(text)
def compare_bigram_overlap(input_bigram, para_bigram):
    """Count the input bigrams that also occur in the paraphrase.

    Every occurrence of a bigram in ``input_bigram`` is counted if that bigram
    appears at least once in ``para_bigram``; how often it appears in the
    paraphrase does not matter.

    Args:
        input_bigram: Iterable of hashable bigrams from the source sentence.
        para_bigram: Iterable of hashable bigrams from the paraphrase.

    Returns:
        The number of (possibly repeated) input bigrams shared with the
        paraphrase; ``0`` when there is nothing in common.
    """
    input_counts = Counter(input_bigram)
    # Only membership is needed on the paraphrase side, so a set is enough.
    para_seen = set(para_bigram)
    return sum(
        count for bigram, count in input_counts.items() if bigram in para_seen
    )
def accept_by_bigram_overlap(sent, para_sents, tokenizer, max_length_ratio=1.5):
    """Choose the candidate paraphrase sharing the fewest bigrams with ``sent``.

    Each candidate is tokenized and its token bigrams are compared against
    those of ``sent``.  Candidates longer than ``max_length_ratio`` times the
    token length of ``sent`` are ignored.  Among the rest, the first candidate
    with the strictly smallest overlap wins.

    Args:
        sent: The original sentence.
        para_sents: Indexable sequence of candidate paraphrase strings.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt')``
            whose result exposes ``input_ids``; the first row is used
            (presumably a Hugging Face tokenizer -- confirm against callers).
        max_length_ratio: Upper bound on candidate length relative to the
            token length of ``sent``.  Defaults to the previously hard-coded
            ``1.5``.

    Returns:
        The selected paraphrase, or ``para_sents[0]`` when no candidate
        passes the length filter with an overlap below ``len(input_ids)``.

    Raises:
        IndexError: If ``para_sents`` is empty.
    """
    def tokenize(text):
        # The ids are only measured and turned into Python tuples, so they
        # are left where the tokenizer put them instead of being copied to
        # the GPU and straight back.
        return tokenizer(text, return_tensors='pt').input_ids[0]

    input_ids = tokenize(sent)
    input_bigram = build_bigrams(input_ids)
    max_len = max_length_ratio * len(input_ids)

    # Fallback when nothing qualifies; also the source of the IndexError on
    # empty input.
    paraphrased = para_sents[0]
    min_overlap = len(input_ids)
    for para in para_sents:
        para_ids = tokenize(para)
        if len(para_ids) > max_len:
            # Too long to be accepted, so skip the bigram work entirely.
            continue
        overlap = compare_bigram_overlap(input_bigram, build_bigrams(para_ids))
        if overlap < min_overlap:
            min_overlap = overlap
            paraphrased = para
    return paraphrased