inference.py
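
"""Batched text generation helpers built on Hugging Face Transformers and
Accelerate: a stop-string `StoppingCriteria`, an `IterableDataset` that
tokenizes a list of prompts, and `hf_generate`, which runs `model.generate`
on each process and gathers prompt/answer pairs."""
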
import tqdm
import torch
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from transformers import StoppingCriteria, StoppingCriteriaList


class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns True once every generated sequence contains one of the end-of-function strings."""
        # Decode only the tokens generated after the prompt.
        decoded_generations = self.tokenizer.batch_decode(
            input_ids[:, self.start_length :]
        )
        # Generation stops once each sequence in the batch contains a stop string.
        return all(
            any(stop_string in decoded_generation for stop_string in self.eof_strings)
            for decoded_generation in decoded_generations
        )


class TokenizedDataset(IterableDataset):
    """Tokenize and preprocess the dataset, where the dataset is a list of instruction strings."""

    def __init__(self, tokenizer, dataset):
        self.tokenizer = tokenizer
        self.dataset = dataset
        # Pad all prompts to the length of the longest one so tensors are rectangular.
        self.outputs = self.tokenizer(self.dataset, padding=True, return_tensors="pt")

    def __iter__(self):
        for i in range(len(self.dataset)):
            yield {
                "input_ids": self.outputs.input_ids[i],
                "attention_mask": self.outputs.attention_mask[i],
                # Original prompt index, used to match gathered results
                # back to their prompts across processes.
                "index_prompt": torch.tensor(i, dtype=torch.int32),
            }


def hf_generate(
    accelerator,
    model,
    tokenizer,
    prompts,
    max_new_tokens,
    temperature,
    top_p,
    stop_words,
    num_beams,
    repetition_penalty,
    num_return_sequences,
    do_sample,
    forced_bos_token_id,
):
    accelerator.free_memory()
    results = []
    if not isinstance(prompts, list):
        # A single prompt (str) is wrapped into a one-element list.
        prompts = [prompts]
    tokenized_dataset = TokenizedDataset(tokenizer=tokenizer, dataset=prompts)
    dataloader = DataLoader(tokenized_dataset, batch_size=1)
    dataloader = accelerator.prepare(dataloader)
    pad_first = tokenizer.padding_side == "left"
    for step, batch in enumerate(tqdm.tqdm(dataloader)):
        with torch.no_grad():
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            index_prompt = batch["index_prompt"]
            # Generated tokens are appended after the full (padded) prompt, so
            # generation starts at input_ids.shape[1]; using the number of
            # non-pad tokens would let the stop-string check see part of the
            # prompt and stop prematurely.
            stopping_criteria = StoppingCriteriaList(
                [EndOfFunctionCriteria(input_ids.shape[1], stop_words, tokenizer)]
            )
            response = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                forced_bos_token_id=forced_bos_token_id,
            )
            # Pad sequences to a common length across processes before gathering.
            padded_responses = accelerator.pad_across_processes(
                response, dim=1, pad_index=tokenizer.pad_token_id, pad_first=pad_first
            )
            padded_attention_mask = accelerator.pad_across_processes(
                attention_mask, dim=1, pad_index=0, pad_first=pad_first
            )
            indices = accelerator.gather(index_prompt)
            answers = accelerator.gather(padded_responses)
            padded_attention_mask = accelerator.gather(padded_attention_mask)
            # One answer per process (batch_size=1 per device).
            for i in range(accelerator.num_processes):
                results.append(
                    {
                        "prompt": prompts[indices[i]],
                        "answer": tokenizer.decode(
                            answers[i], skip_special_tokens=True
                        ),
                    }
                )
    accelerator.free_memory()
    return results
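

# Minimal usage sketch (not part of the original module): wires the pieces
# above together with Hugging Face `accelerate`. The checkpoint name "gpt2"
# and all sampling parameters below are placeholder assumptions, not values
# prescribed by this repository.
if __name__ == "__main__":
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer

    accelerator = Accelerator()
    # Decoder-only models are generally left-padded for batched generation,
    # which is also what the `pad_first` logic in hf_generate expects.
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model = accelerator.prepare(model)

    results = hf_generate(
        accelerator=accelerator,
        model=model,
        tokenizer=tokenizer,
        prompts=["def fibonacci(n):"],
        max_new_tokens=64,
        temperature=0.8,
        top_p=0.95,
        stop_words=["\ndef", "\nclass", "\nif __name__"],
        num_beams=1,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        forced_bos_token_id=None,
    )
    for result in results:
        print(result["prompt"], "->", result["answer"])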