forked from jingyaogong/minimind
-
Notifications
You must be signed in to change notification settings - Fork 0
/
2-eval.py
183 lines (151 loc) · 6.2 KB
/
2-eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import random
import time
import numpy as np
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer
from model.LMConfig import LMConfig
warnings.filterwarnings('ignore')
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model_from = 1 # 1从权重,2用transformers
if model_from == 1:
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
model = Transformer(lm_config)
state_dict = torch.load(ckp, map_location=device)
# 处理不需要的前缀
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
for k, v in list(state_dict.items()):
if 'mask' in k:
del state_dict[k]
# 加载到模型中
model.load_state_dict(state_dict, strict=False)
else:
model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
model = model.to(device)
print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
return model, tokenizer
def setup_seed(seed):
random.seed(seed) # 设置 Python 的随机种子
np.random.seed(seed) # 设置 NumPy 的随机种子
torch.manual_seed(seed) # 设置 PyTorch 的随机种子
torch.cuda.manual_seed(seed) # 为当前 GPU 设置随机种子(如果有)
torch.cuda.manual_seed_all(seed) # 为所有 GPU 设置随机种子(如果有)
torch.backends.cudnn.deterministic = True # 确保每次返回的卷积算法是确定的
torch.backends.cudnn.benchmark = False # 关闭 cuDNN 的自动调优,避免不确定性
if __name__ == "__main__":
# -----------------------------------------------------------------------------
out_dir = 'out'
start = ""
temperature = 0.7
top_k = 16
setup_seed(1337)
# device = 'cpu'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16'
max_seq_len = 1 * 1024
lm_config = LMConfig()
lm_config.max_seq_len = max_seq_len
# 对话是否携带历史对话(当前模型没有在连续对话数据集上训练,增大历史上文基本不会有新的问答能力)
contain_history_chat = False
# -----------------------------------------------------------------------------
model, tokenizer = init_model(lm_config)
model = model.eval()
# 推送到huggingface
# model.push_to_hub("minimind")
# tokenizer.push_to_hub("minimind")
# answer_way = int(input('输入0自动测试,输入1问题测试:'))
answer_way = 0
stream = True
prompt_datas = [
'你叫什么名字啊?',
'你叫什么名字?',
'中国有哪些比较好的大学?',
'全世界最好的大学是什么?',
'你知道光速是多少吗?',
'你知道长江吗?',
'人类的血液主要由哪些成分组成?',
'第一颗人造卫星是哪个国家发射的?',
'你知道杭州有什么美食吗?',
'你知道泰山在哪里吗?',
'地球上最大的动物是什么?',
'地球自转一圈大约需要多少时间?',
'人类最早使用的金属是什么?',
'水的化学分子式是什么?',
'大气层中含量最多的气体是什么?',
'世界上最高的山峰是什么?',
'你知道世界上最深的海沟是什么吗?',
'最早发明印刷术的是哪个国家?',
'万有引力是谁提出的?',
'光合作用的主要原理是什么?',
'你知道大熊猫的主要食物是什么吗?',
'海水为什么是咸的?',
'我们平时喝的牛奶主要含有什么营养成分?',
'一星期有多少天?'
]
messages_origin = []
messages = messages_origin
i = 0
while i < len(prompt_datas):
if not contain_history_chat:
messages = messages_origin.copy()
if answer_way == 1:
prompt = input('[Q]: ')
else:
prompt = prompt_datas[i]
print(f'[Q]: {prompt}')
i += 1
messages.append({"role": "user", "content": prompt})
# print(messages)
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)[-(max_seq_len - 1):]
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
answer = new_prompt
with torch.no_grad():
res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len, temperature=temperature,
top_k=top_k, stream=stream)
print('[A]: ', end='')
try:
y = next(res_y)
except StopIteration:
print("No answer")
continue
history_idx = 0
while y != None:
answer = tokenizer.decode(y[0].tolist())
if answer and answer[-1] == '�':
try:
y = next(res_y)
except:
break
continue
# print(answer)
if not len(answer):
try:
y = next(res_y)
except:
break
continue
print(answer[history_idx:], end='', flush=True)
try:
y = next(res_y)
except:
break
history_idx = len(answer)
if not stream:
break
print('\n')
if contain_history_chat:
assistant_answer = answer.replace(new_prompt, "")
messages.append({"role": "assistant", "content": assistant_answer})