-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparaphrase_openai.py
91 lines (84 loc) · 2.42 KB
/
paraphrase_openai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import openai
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
from datasets import load_from_disk, Dataset
import argparse
from transformers import AutoTokenizer
from paraphrase_gen_util import accept_by_bigram_overlap, extract_list
import time
# Placeholder credential — replace with a real API key (better: read from an
# environment variable instead of hard-coding a secret in source).
openai.api_key = "your/key"
def gen_prompt(sent, context):
    """Build the single-paraphrase prompt: prior context plus the target sentence."""
    return "Previous context: {} \n Current sentence to paraphrase: {}".format(context, sent)
def gen_bigram_prompt(sent, context):
    """Build the 20-way paraphrase prompt (numbered-list output) for *sent* given *context*."""
    return "Previous context: {} \n Paraphrase in 20 different ways and return a numbered list : {}".format(context, sent)
def query_openai(prompt):
    """Send *prompt* to gpt-3.5-turbo and return the reply text.

    Retries forever on rate-limit and transient API errors (as the original
    did), but rate-limit waits now back off exponentially (5s, 10s, ...,
    capped at 60s) instead of hammering the endpoint every 5 seconds.

    Parameters:
        prompt: str — the user message to send.
    Returns:
        str — the assistant message content of the first choice.
    """
    backoff = 5.0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=1,
                max_tokens=256,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            return response.choices[0].message.content
        except openai.error.RateLimitError:
            time.sleep(backoff)
            # Double the wait on consecutive rate limits, capped at one minute.
            backoff = min(backoff * 2, 60.0)
        except openai.error.APIError:
            # Transient server-side error: brief pause, then retry;
            # reset the rate-limit backoff since this is a different failure.
            time.sleep(2)
            backoff = 5.0
def query_openai_bigram(prompt):
    """Send *prompt* to gpt-3.5-turbo-16k (long-context, large output budget)
    and return the reply text. Used for the 20-way numbered-list paraphrase
    prompts, which need the bigger max_tokens.

    Retries forever on rate-limit and transient API errors (as the original
    did), but rate-limit waits now back off exponentially (5s, 10s, ...,
    capped at 60s) instead of retrying at a fixed interval.

    Parameters:
        prompt: str — the user message to send.
    Returns:
        str — the assistant message content of the first choice.
    """
    backoff = 5.0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo-16k",
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                temperature=1,
                max_tokens=4096,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            return response.choices[0].message.content
        except openai.error.RateLimitError:
            time.sleep(backoff)
            # Double the wait on consecutive rate limits, capped at one minute.
            backoff = min(backoff * 2, 60.0)
        except openai.error.APIError:
            # Transient server-side error: brief pause, then retry;
            # reset the rate-limit backoff since this is a different failure.
            time.sleep(2)
            backoff = 5.0
def paraphrase_openai_(texts, tokenizer, bigram=False):
    """Paraphrase each text sentence-by-sentence via the OpenAI API.

    Parameters:
        texts: iterable of str — documents to paraphrase.
        tokenizer: tokenizer passed through to accept_by_bigram_overlap.
        bigram: if True, request 20 candidate paraphrases per sentence and
            keep the one selected by accept_by_bigram_overlap; otherwise ask
            for a single paraphrase.
    Returns:
        (new_texts, paras) where new_texts[i] is the sentence list of the
        i-th input and paras[i] is the paraphrased sentences joined by spaces.
    """
    sentence_lists = []
    joined_paraphrases = []
    for text in tqdm(texts, desc="Tokenizer"):
        sentences = sent_tokenize(text)
        paraphrased = []
        for idx, sentence in enumerate(sentences):
            # NOTE(review): the preceding context is a *list* of sentences and
            # is interpolated into the prompt as its Python repr — presumably
            # intentional upstream; preserved as-is.
            preceding = sentences[:idx]
            if bigram:
                reply = query_openai_bigram(gen_bigram_prompt(sentence, preceding))
                candidates = extract_list(reply)
                if len(candidates) < 20:
                    # Model returned fewer than 20 candidates: log and skip
                    # this sentence (it is dropped from the paraphrase).
                    print(reply)
                    print(candidates)
                    continue
                paraphrased.append(accept_by_bigram_overlap(sentence, candidates, tokenizer))
            else:
                paraphrased.append(query_openai(gen_prompt(sentence, preceding)))
        sentence_lists.append(sentences)
        joined_paraphrases.append(" ".join(paraphrased))
    return sentence_lists, joined_paraphrases