dialogue_multi_trun.py
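"""
Multi-turn dialogue demo for DUTIR-BioNLP/Taiyi-LLM.

Runs an interactive console loop: each user turn and model reply is appended
to a running token-id history, so the model conditions on the whole
conversation on every call. Type "结束对话" ("end dialogue") to quit.
"""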
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

model_name = "DUTIR-BioNLP/Taiyi-LLM"
device = 'cuda:0'  # single-GPU inference; edit to run on another device
# Load the model in float16 and place it directly on the target device
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map=device
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
logging.disable(logging.WARNING)  # silence tokenizer/generation warnings

# The tokenizer shipped with this checkpoint only defines an end-of-document
# token (eod_id), so it is reused as the pad/bos/eos token.
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.eod_id
tokenizer.eos_token_id = tokenizer.eod_id
# Sampling parameters
max_new_tokens = 500
top_p = 0.9
temperature = 0.3
repetition_penalty = 1.0

# Start the interactive dialogue loop
history_max_len = 1000  # feed at most this many trailing history tokens to the model
history_token_ids = None
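# History layout after each turn:
#   <bos> q1 <eos> a1 <bos> q2 <eos> a2 ...
# Each user turn is wrapped in bos/eos markers and the raw reply tokens are
# appended right after it.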
while True:
    user_input = input("用户提问:")  # "User question:"
    if user_input == "结束对话":  # typing "end dialogue" quits
        history_token_ids = None  # clear the conversation history
        print("对话已经结束")  # "The dialogue has ended"
        break

    # Encode the new user turn and append it to the history
    input_ids = tokenizer(user_input, return_tensors="pt", add_special_tokens=False).input_ids
    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long)
    user_input_ids = torch.cat([bos_token_id, input_ids, eos_token_id], dim=1)
    if history_token_ids is None:
        history_token_ids = user_input_ids
    else:
        history_token_ids = torch.cat((history_token_ids, user_input_ids), dim=1)

    # Truncate to the most recent history_max_len tokens and generate
    model_input_ids = history_token_ids[:, -history_max_len:].to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=model_input_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p,
            temperature=temperature, repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id
        )

    # Keep only the newly generated tokens and append them to the history
    model_input_ids_len = model_input_ids.size(1)
    response_ids = outputs[:, model_input_ids_len:]
    history_token_ids = torch.cat((history_token_ids, response_ids.cpu()), dim=1)
    response = tokenizer.batch_decode(response_ids)
    print("Taiyi回答:" + response[0].strip().replace(tokenizer.eos_token, ""))  # "Taiyi answers:"