inference.py
"""Main code to run inference."""
import warnings
from typing import Dict
from tqdm import tqdm
from PIL import Image
import torch
import fire
from dotenv import load_dotenv
from src.processors.paligemma_processor import PaliGemmaProcessor
from src.utils.kv_cache import KVCache
from src.models.paligemma import PaliGemmaForConditionalGeneration
from src.utils.model_loader import load_hf_model
def move_inputs_to_device(model_inputs: Dict[str, torch.Tensor],
device: str
) -> Dict[str, torch.Tensor]:
"""Move inputs to device"""
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
return model_inputs
def get_model_inputs(
processor: PaliGemmaProcessor,
prompt: str,
image_file_path: str,
device: str,
) -> Dict[str, torch.Tensor]:
"""Get model inputs"""
image = Image.open(image_file_path)
images = [image]
prompts = [prompt]
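    # The processor tokenizes the prompt and preprocesses the image, returning
    # input_ids, attention_mask and pixel_values (unpacked in test_inference below)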
model_inputs = processor(text=prompts, images=images)
model_inputs = move_inputs_to_device(model_inputs, device)
return model_inputs
def test_inference(
model: PaliGemmaForConditionalGeneration,
processor: PaliGemmaProcessor,
device: str,
prompt: str,
image_file_path: str,
max_tokens_to_generate: int,
temperature: float,
top_p: float,
do_sample: bool
) -> None:
"""Run inference"""
model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
input_ids = model_inputs["input_ids"]
attention_mask = model_inputs["attention_mask"]
pixel_values = model_inputs["pixel_values"]
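    # A single KV cache is reused across decoding steps, so each forward pass
    # only needs to process the newly generated token rather than the full sequence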
kv_cache = KVCache()
    # Generate until the EOS token appears or the token budget is exhausted
stop_token = processor.tokenizer.eos_token_id
generated_tokens = []
    for _ in tqdm(range(max_tokens_to_generate), desc="Tokens generated"):
# Get the model outputs
outputs = model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
kv_cache=kv_cache,
)
kv_cache = outputs["kv_cache"]
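        # Only the logits at the last position are needed to predict the next token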
next_token_logits = outputs["logits"][:, -1, :]
# Sample the next token
if do_sample:
# Apply temperature
next_token_logits = torch.softmax(
next_token_logits / temperature, dim=-1)
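            # Nucleus (top-p) sampling: draw from the smallest set of tokens whose
            # cumulative probability exceeds top_p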
next_token = _sample_top_p(next_token_logits, top_p)
else:
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
assert next_token.size() == (1, 1)
next_token = next_token.squeeze(0) # Remove batch dimension
generated_tokens.append(next_token)
# Stop if the stop token has been generated
if next_token.item() == stop_token:
break
        # Feed back only the newly generated token; the earlier context is held in the KV cache
input_ids = next_token.unsqueeze(-1)
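        # Grow the attention mask by one position for the token just generated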
attention_mask = torch.cat(
[attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
)
generated_tokens = torch.cat(generated_tokens, dim=-1)
# Decode the generated tokens
decoded = processor.tokenizer.decode(
generated_tokens, skip_special_tokens=True)
print("The output :",prompt + decoded)
def _sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
"""Sample token using top_p"""
# (B, vocab_size)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# (B, vocab_size)
probs_sum = torch.cumsum(probs_sort, dim=-1)
# (B, vocab_size)
    # (Subtracting "probs_sort" shifts the cumulative sum one position to the right before masking)
mask = probs_sum - probs_sort > p
    # Zero out the probabilities of all tokens outside the top-p nucleus
probs_sort[mask] = 0.0
# Redistribute the probabilities so that they sum up to 1.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
# Sample a token (its index) from the top p distribution
next_token = torch.multinomial(probs_sort, num_samples=1)
# Get the token position in the vocabulary corresponding to the sampled
# index
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def main(
model_id: str,
model_path: str,
prompt: str,
image_file_path: str,
max_tokens_to_generate: int = 100,
temperature: float = 0.8,
top_p: float = 0.9,
do_sample: bool = False,
only_cpu: bool = False,
) -> None:
"""Main process"""
device = "cpu"
if not only_cpu:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
print("Device in use:", device)
print("Loading model")
model, tokenizer = load_hf_model(model_id, model_path, device)
model = model.to(device).eval()
num_image_tokens = model.config.vision_config.num_image_tokens
image_size = model.config.vision_config.image_size
processor = PaliGemmaProcessor(tokenizer, num_image_tokens, image_size)
print("Running inference")
with torch.no_grad():
test_inference(
model,
processor,
device,
prompt,
image_file_path,
max_tokens_to_generate,
temperature,
top_p,
do_sample,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore")
print("Is there .env file:",load_dotenv())
fire.Fire(main)
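# Example invocation via fire (all values below are placeholders, not shipped defaults):
#   python inference.py \
#       --model_id="<hf-repo-id>" \
#       --model_path="<path/to/weights>" \
#       --prompt="describe this image" \
#       --image_file_path="path/to/image.jpg" \
#       --max_tokens_to_generate=100 --do_sample=True --temperature=0.8 --top_p=0.9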