-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
generation.py
548 lines (492 loc) · 19.3 KB
/
generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer
# Choose the compute backend once at import time: CUDA is preferred, then MPS,
# and the CPU is the fallback when neither is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
# Speaker roles a chat message may carry.
Role = Literal["system", "user", "assistant"]


class _MessageBase(TypedDict):
    """Keys that every chat message must provide."""

    role: Role
    content: str


class Message(_MessageBase, total=False):
    """A single chat message.

    ``role`` and ``content`` are mandatory. ``destination`` is optional at the
    type level: this file builds messages without it (for example the
    injected empty system message) and reads it with ``.get()``, so declaring
    it as a required key did not match how messages are actually used.
    """

    destination: str  # required for model responses
class InfillingPrediction(TypedDict, total=False):
    """Result record for one infilling request.

    Declared with ``total=False``, so no key is mandatory at the type level.
    """

    # Text generated for the gap between prefix and suffix.
    generation: str
    # Prefix, generated text and suffix concatenated.
    full_text: str
    # Per-token string pieces; not required.
    tokens: List[str]
    # Per-token log-probabilities; not required.
    logprobs: List[float]
class CompletionPrediction(TypedDict, total=False):
    """Result record for one text-completion request.

    Declared with ``total=False``, so no key is mandatory at the type level.
    """

    # Generated text.
    generation: str
    # Per-token string pieces; not required.
    tokens: List[str]
    # Per-token log-probabilities; not required.
    logprobs: List[float]
class ChatPrediction(TypedDict, total=False):
    """Result record for one chat request.

    Declared with ``total=False``, so no key is mandatory at the type level.
    """

    # The assistant's reply, expressed as a chat message.
    generation: Message
    # Per-token string pieces; not required.
    tokens: List[str]
    # Per-token log-probabilities; not required.
    logprobs: List[float]
# A conversation: an ordered list of chat messages.
Dialog = List[Message]
# Markers placed around each user instruction in the chat prompt.
B_INST = "[INST]"
E_INST = "[/INST]"
# Markers placed around the system prompt when it is folded into the first
# user message.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
# Tags that must not appear inside message content supplied by the caller.
SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>", "<step>"]
# Reply content substituted when a request contains one of the special tags.
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
class Llama:
@staticmethod
def build(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int,
    max_batch_size: int,
    model_parallel_size: Optional[int] = None,
) -> "Llama":
    """Initialise the distributed state, load a checkpoint and return a Llama.

    Args:
        ckpt_dir: Directory holding the ``*.pth`` checkpoint shards and
            ``params.json``.
        tokenizer_path: Path handed to ``Tokenizer(model_path=...)``.
        max_seq_len: Forwarded to ``ModelArgs``.
        max_batch_size: Forwarded to ``ModelArgs``.
        model_parallel_size: Number of model-parallel ranks. When ``None`` it
            is taken from the ``WORLD_SIZE`` environment variable (default 1)
            or, if model parallelism was already initialised before this
            call, from the existing model-parallel group.

    Returns:
        A ``Llama`` wrapping the loaded model and tokenizer.

    Raises:
        AssertionError: If ``ckpt_dir`` contains no ``*.pth`` file, or the
            number of shards differs from the model-parallel size.
    """
    if not torch.distributed.is_initialized():
        backend = "nccl" if device == "cuda" else "gloo"
        torch.distributed.init_process_group(backend)

    if not model_parallel_is_initialized():
        if model_parallel_size is None:
            model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
        initialize_model_parallel(model_parallel_size)
    elif model_parallel_size is None:
        # Model parallelism was set up earlier (e.g. by a previous build()
        # call). Without this the size would stay None and the shard-count
        # assertion below could never pass.
        from fairscale.nn.model_parallel.initialize import (
            get_model_parallel_world_size,
        )

        model_parallel_size = get_model_parallel_world_size()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if device == "cuda":
        torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)

    # Only local rank 0 keeps printing; the other ranks discard their stdout.
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
    assert model_parallel_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
    # Each model-parallel rank loads the shard at its own index.
    ckpt_path = checkpoints[get_model_parallel_rank()]
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    # The vocabulary size comes from the tokenizer rather than params.json.
    model_args.vocab_size = tokenizer.n_words

    # support for mac
    # The default tensor type must be set before the model is constructed so
    # that its parameters are created with that type.
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
    else:
        torch.set_default_tensor_type(torch.HalfTensor)

    model = Transformer(model_args)
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return Llama(model, tokenizer)
def __init__(self, model: Transformer, tokenizer: Tokenizer):
    """Keep references to an already-built model and its tokenizer."""
    self.tokenizer = tokenizer
    self.model = model
@torch.inference_mode()
def generate(
    self,
    prompt_tokens: List[List[int]],
    max_gen_len: int,
    temperature: float = 0.6,
    top_p: float = 0.9,
    logprobs: bool = False,
    echo: bool = False,
    stop_token: Optional[int] = None,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
    """Generate continuations for a batch of tokenised prompts.

    Args:
        prompt_tokens: One list of token ids per prompt.
        max_gen_len: Maximum number of tokens to generate per prompt.
        temperature: Sampling temperature; ``0`` (or less) selects the
            arg-max token instead of sampling.
        top_p: Nucleus threshold passed to ``sample_top_p``.
        logprobs: When true, also return per-token log-probabilities.
        echo: When true, the returned sequences start with the prompt.
        stop_token: Token id that ends a sequence; defaults to the
            tokenizer's ``eos_id``.

    Returns:
        ``(tokens, logprobs)`` where ``tokens`` holds one id list per prompt,
        cut before the first generated stop token, and ``logprobs`` is the
        matching list of log-probabilities or ``None`` when not requested.

    Raises:
        AssertionError: If the batch exceeds ``max_batch_size`` or a prompt
            is longer than ``max_seq_len``.
    """
    if stop_token is None:
        stop_token = self.tokenizer.eos_id
    params = self.model.params
    bsz = len(prompt_tokens)
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= params.max_seq_len
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

    # Token buffer: prompts are left-aligned, everything else starts as padding.
    pad_id = self.tokenizer.pad_id
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
    if logprobs:
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

    prev_pos = 0
    stop_reached = torch.tensor([False] * bsz, device=device)
    # True wherever the buffer holds a prompt token.
    input_text_mask = tokens != pad_id
    for cur_pos in range(min_prompt_len, total_len):
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(next_probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        if logprobs:
            # Score only after the chosen token has been written: before
            # that, position cur_pos still holds pad_id for generated
            # positions, which ignore_index turns into a log-prob of 0.
            token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1 : cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )
        # A stop token only counts when it was generated, not read from a prompt.
        stop_reached |= (~input_text_mask[:, cur_pos]) & (next_token == stop_token)
        prev_pos = cur_pos
        if all(stop_reached):
            break

    logprob_rows = token_logprobs.tolist() if logprobs else None
    out_tokens, out_logprobs = [], []
    for i, toks in enumerate(tokens.tolist()):
        prompt_len = len(prompt_tokens[i])
        # cut to max gen len
        start = 0 if echo else prompt_len
        end = prompt_len + max_gen_len
        toks = toks[start:end]
        seq_logprobs = logprob_rows[i][start:end] if logprobs else None
        # cut to stop token if present; search only the generated part so a
        # stop token inside an echoed prompt does not truncate the output.
        gen_offset = prompt_len - start
        try:
            stop_idx = toks.index(stop_token, gen_offset)
        except ValueError:
            stop_idx = None
        if stop_idx is not None:
            toks = toks[:stop_idx]
            if seq_logprobs is not None:
                seq_logprobs = seq_logprobs[:stop_idx]
        out_tokens.append(toks)
        out_logprobs.append(seq_logprobs)
    return (out_tokens, out_logprobs if logprobs else None)
def text_completion(
    self,
    prompts: List[str],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
    echo: bool = False,
) -> List[CompletionPrediction]:
    """Complete each text prompt and decode the result.

    Args:
        prompts: Texts to continue.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        max_gen_len: Maximum number of generated tokens; defaults to the
            model's ``max_seq_len - 1``.
        logprobs: When true, each result also carries ``tokens`` and
            ``logprobs``.
        echo: Forwarded to ``generate``.

    Returns:
        One prediction dict per prompt, in input order.
    """
    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1

    # Prompts are encoded with a BOS marker and no EOS marker.
    encoded_prompts = []
    for text in prompts:
        encoded_prompts.append(self.tokenizer.encode(text, bos=True, eos=False))

    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=encoded_prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
        echo=echo,
    )

    if not logprobs:
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    assert generation_logprobs is not None
    completions = []
    for toks, toks_logprobs in zip(generation_tokens, generation_logprobs):
        completions.append(
            {
                "generation": self.tokenizer.decode(toks),
                "tokens": [self.tokenizer.token_piece(x) for x in toks],
                "logprobs": toks_logprobs,
            }
        )
    return completions
def text_infilling(
    self,
    prefixes: List[str],
    suffixes: List[str],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
    suffix_first: bool = False,
) -> List[InfillingPrediction]:
    """Generate the text that belongs between each prefix/suffix pair.

    Args:
        prefixes: Text preceding each gap.
        suffixes: Text following each gap; paired with ``prefixes`` by
            position.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        max_gen_len: Maximum number of generated tokens; defaults to the
            model's ``max_seq_len - 1``.
        logprobs: When true, each result also carries ``tokens`` and
            ``logprobs``.
        suffix_first: Forwarded to ``infilling_prompt_tokens`` to select the
            suffix-prefix-middle prompt layout.

    Returns:
        One prediction per pair, holding the generated text and
        ``full_text`` (prefix + generation + suffix).

    Raises:
        AssertionError: If the tokenizer has no ``eot_id``.
    """
    assert self.tokenizer.eot_id is not None
    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1

    prompt_tokens = []
    for prefix, suffix in zip(prefixes, suffixes):
        prompt_tokens.append(
            infilling_prompt_tokens(
                self.tokenizer, prefix, suffix, suffix_first=suffix_first
            )
        )

    # Infilling stops at the tokenizer's EOT token and never echoes the prompt.
    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
        echo=False,
        stop_token=self.tokenizer.eot_id,
    )
    generations = [self.tokenizer.decode_infilling(t) for t in generation_tokens]

    if not logprobs:
        return [
            {
                "generation": generation,
                "full_text": prefix + generation + suffix,
            }
            for prefix, suffix, generation in zip(prefixes, suffixes, generations)
        ]

    assert generation_logprobs is not None
    results = []
    rows = zip(
        prefixes,
        suffixes,
        generations,
        generation_tokens,
        generation_logprobs,
    )
    for prefix, suffix, generation, toks, toks_logprobs in rows:
        results.append(
            {
                "generation": generation,
                "logprobs": toks_logprobs,
                "tokens": [self.tokenizer.token_piece(x) for x in toks],
                "full_text": prefix + generation + suffix,
            }
        )
    return results
def chat_completion(
    self,
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
) -> List[ChatPrediction]:
    """Produce the assistant's next reply for each dialog.

    When the tokenizer defines a ``step_id`` the call is delegated to
    ``_chat_completion_turns``. Otherwise each dialog is rendered with the
    ``[INST]`` / ``<<SYS>>`` markers: a leading system message is folded into
    the first user message, completed user/assistant pairs are encoded with
    BOS and EOS, and the final user message is encoded with BOS only.

    Args:
        dialogs: Conversations to continue; each must end with a user message.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        max_gen_len: Maximum number of generated tokens; defaults to the
            model's ``max_seq_len - 1``.
        logprobs: When true, each result also carries ``tokens`` and
            ``logprobs``.

    Returns:
        One prediction per dialog. If any message of a dialog contains a
        special tag, the reply content is replaced by ``UNSAFE_ERROR``.

    Raises:
        AssertionError: If roles do not alternate user/assistant or the last
            message is not from the user.
    """
    if self.tokenizer.step_id is not None:
        return self._chat_completion_turns(
            dialogs=dialogs,
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
        )

    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1

    prompt_tokens = []
    unsafe_requests = []
    for dialog in dialogs:
        contains_special_tag = any(
            [tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]
        )
        unsafe_requests.append(contains_special_tag)

        if dialog[0]["role"] == "system":
            # Fold the system prompt into the message that follows it.
            merged_first = {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
            dialog = [merged_first] + dialog[2:]  # type: ignore

        assert all([m["role"] == "user" for m in dialog[::2]]) and all(
            [m["role"] == "assistant" for m in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )

        # Completed exchanges: one BOS ... EOS segment per user/assistant pair.
        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            dialog_tokens.extend(
                self.tokenizer.encode(
                    f"{B_INST} {prompt['content'].strip()} {E_INST} {answer['content'].strip()} ",
                    bos=True,
                    eos=True,
                )
            )

        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        # The pending user message is left open (no EOS) for the model to answer.
        dialog_tokens += self.tokenizer.encode(
            f"{B_INST} {dialog[-1]['content'].strip()} {E_INST}",
            bos=True,
            eos=False,
        )
        prompt_tokens.append(dialog_tokens)

    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
    )

    if not logprobs:
        return [
            {
                "generation": {  # type: ignore
                    "role": "assistant",
                    "content": UNSAFE_ERROR if unsafe else self.tokenizer.decode(t),
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]

    assert generation_logprobs is not None
    predictions = []
    rows = zip(generation_tokens, generation_logprobs, unsafe_requests)
    for toks, toks_logprobs, unsafe in rows:
        predictions.append(
            {
                "generation": {  # type: ignore
                    "role": "assistant",
                    "content": UNSAFE_ERROR if unsafe else self.tokenizer.decode(toks),
                },
                "tokens": [self.tokenizer.token_piece(x) for x in toks],
                "logprobs": toks_logprobs,
            }
        )
    return predictions
def _chat_completion_turns(
    self,
    dialogs: List[Dialog],
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_gen_len: Optional[int] = None,
    logprobs: bool = False,
) -> List[ChatPrediction]:
    """Chat completion for tokenizers that define a step token.

    Each dialog is rendered by ``dialog_prompt_tokens`` (an empty system
    message is prepended when the dialog does not start with one) and
    generation stops at the tokenizer's ``step_id``.

    Args:
        dialogs: Conversations to continue.
        temperature: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        max_gen_len: Maximum number of generated tokens; defaults to the
            model's ``max_seq_len - 1``.
        logprobs: When true, each result also carries ``tokens`` and
            ``logprobs``.

    Returns:
        One prediction per dialog whose message has role ``assistant`` and
        destination ``user``. If any message of a dialog contains a special
        tag, the reply content is replaced by ``UNSAFE_ERROR``.

    Raises:
        RuntimeError: If the tokenizer has no ``step_id``.
    """
    step_id = self.tokenizer.step_id
    if step_id is None:
        raise RuntimeError("Model not suitable for chat_completion_step()")
    if max_gen_len is None:
        max_gen_len = self.model.params.max_seq_len - 1

    prompt_tokens = []
    unsafe_requests = []
    for dialog in dialogs:
        contains_special_tag = any(
            [tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]
        )
        unsafe_requests.append(contains_special_tag)

        # Insert system message if not provided
        if dialog[0]["role"] != "system":
            dialog = [{"role": "system", "content": ""}] + dialog  # type: ignore

        prompt_tokens.append(dialog_prompt_tokens(self.tokenizer, dialog))

    generation_tokens, generation_logprobs = self.generate(
        prompt_tokens=prompt_tokens,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
        logprobs=logprobs,
        stop_token=step_id,
    )

    if not logprobs:
        return [
            {
                "generation": {
                    "role": "assistant",
                    "destination": "user",
                    "content": UNSAFE_ERROR if unsafe else self.tokenizer.decode(t),
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]

    assert generation_logprobs is not None
    predictions = []
    rows = zip(generation_tokens, generation_logprobs, unsafe_requests)
    for toks, toks_logprobs, unsafe in rows:
        predictions.append(
            {
                "generation": {
                    "role": "assistant",
                    "destination": "user",
                    "content": UNSAFE_ERROR if unsafe else self.tokenizer.decode(toks),
                },
                "tokens": [self.tokenizer.token_piece(x) for x in toks],
                "logprobs": toks_logprobs,
            }
        )
    return predictions
def sample_top_p(probs, p):
    """Draw one index per row using top-p (nucleus) sampling.

    The probabilities are sorted in descending order; an entry is dropped when
    the mass accumulated *before* it already exceeds ``p``. The survivors are
    renormalised and one of them is sampled.

    Args:
        probs: Probabilities over the last dimension. The input tensor itself
            is not modified (sorting returns new tensors).
        p: Cumulative-probability threshold.

    Returns:
        The sampled indices into the original (unsorted) last dimension, with
        a trailing dimension of size 1.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Subtracting the entry itself gives the mass strictly ahead of it, so the
    # most likely entry is always kept.
    outside_nucleus = (cumulative - sorted_probs) > p
    sorted_probs[outside_nucleus] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    picked = torch.multinomial(sorted_probs, num_samples=1)
    # Translate positions in the sorted order back to original indices.
    return torch.gather(sorted_idx, -1, picked)
def infilling_prompt_tokens(
    tokenizer: Tokenizer,
    pre: str,
    suf: str,
    suffix_first: bool = False,
) -> List[int]:
    """
    Build the token sequence for one infilling problem.

    By default the layout is prefix-suffix-middle; with `suffix_first` set it
    is suffix-prefix-middle. The prefix is encoded with `tokenizer.encode`
    (no BOS/EOS) and the suffix with `tokenizer.encode_infilling`.

    Raises AssertionError if the tokenizer lacks a prefix, middle or suffix id.
    """
    assert tokenizer.prefix_id is not None
    assert tokenizer.middle_id is not None
    assert tokenizer.suffix_id is not None

    if not suffix_first:
        # format as "<PRE> {pre} <SUF>{suf} <MID>"
        prefix_part = tokenizer.encode(pre, bos=False, eos=False)
        suffix_part = tokenizer.encode_infilling(suf)
        return (
            [tokenizer.bos_id, tokenizer.prefix_id]
            + prefix_part
            + [tokenizer.suffix_id]
            + suffix_part
            + [tokenizer.middle_id]
        )

    # format as "<PRE> <SUF>{suf} <MID> {pre}"
    suffix_part = tokenizer.encode_infilling(suf)
    prefix_part = tokenizer.encode(pre, bos=False, eos=False)
    return (
        [tokenizer.bos_id, tokenizer.prefix_id, tokenizer.suffix_id]
        + suffix_part
        + [tokenizer.middle_id]
        + prefix_part
    )
def dialog_prompt_tokens(tokenizer: Tokenizer, dialog: Dialog) -> List[int]:
    """
    Encode a multi-turn dialog into prompt tokens.

    The dialog is expected to open with a system message, followed by
    alternating user and assistant messages, and to end with a user message.
    Only the alternation from index 1 onwards and the final role are checked;
    the role of the first message is not validated here.

    Every message becomes a "Source: <role>" header (plus a "Destination: ..."
    line when the message has one), an optional body, and the tokenizer's
    step token. The prompt ends with the header of the assistant's reply to
    the user, left open for generation.
    """
    assert tokenizer.step_id is not None
    assert all([m["role"] == "user" for m in dialog[1::2]]) and all(
        [m["role"] == "assistant" for m in dialog[2::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"

    def _encode(text: str) -> List[int]:
        # No part of the dialog carries its own BOS/EOS marker.
        return tokenizer.encode(text, bos=False, eos=False)

    # Format context
    dialog_tokens: List[int] = [tokenizer.bos_id]
    for message in dialog:
        header_lines = [f"Source: {message['role'].strip()}"]
        if message.get("destination") is not None:
            header_lines.append(f"Destination: {message['destination'].strip()}")
        dialog_tokens += _encode(" " + "\n".join(header_lines))

        # Messages with empty content contribute a header only.
        if message["content"]:
            dialog_tokens += _encode("\n\n " + message["content"].strip())

        dialog_tokens += [tokenizer.step_id]

    # Start of reply
    reply_header = " " + "\n".join(("Source: assistant", "Destination: user"))
    dialog_tokens += _encode(reply_header)
    dialog_tokens += _encode("\n\n ")
    return dialog_tokens