Combine generate() functions #1675

Merged: 40 commits, Aug 22, 2024

Commits (changes from all commits)
e9c4616 WIP (apaz-cli, Aug 14, 2024)
c103b16 Update tests/test_chat.py (apaz-cli, Aug 14, 2024)
ad2cd48 Update tests/test_chat.py (apaz-cli, Aug 14, 2024)
906113c Progress. (apaz-cli, Aug 15, 2024)
328d81c Cleanup. (apaz-cli, Aug 15, 2024)
9b33ad8 Fixed tests. (apaz-cli, Aug 15, 2024)
0752df9 Test Cleanup (apaz-cli, Aug 15, 2024)
63a3a47 More Test Cleanup (apaz-cli, Aug 15, 2024)
2b80cd3 Stub out test for removed function. (apaz-cli, Aug 15, 2024)
6bd381a Remove extra import. (apaz-cli, Aug 15, 2024)
7a62774 Merge branch 'main' into ap/combine_generage (rasbt, Aug 15, 2024)
0a016b1 Update comments and fix tests (apaz-cli, Aug 16, 2024)
28454b3 Merge branch 'ap/combine_generage' of https://github.com/lightning-ai… (apaz-cli, Aug 16, 2024)
a4fc1c8 Merge branch 'main' into ap/combine_generage (rasbt, Aug 16, 2024)
4f3048e Cleaned up tests. (apaz-cli, Aug 16, 2024)
f5c0094 Merge branch 'ap/combine_generage' of https://github.com/lightning-ai… (apaz-cli, Aug 16, 2024)
ffc8228 fix kv cache bug (rasbt, Aug 16, 2024)
35115a5 Wrote test_decode, fixed subtle type hints. (apaz-cli, Aug 16, 2024)
2d09f22 Merge branch 'ap/combine_generage' of https://github.com/lightning-ai… (apaz-cli, Aug 16, 2024)
5ad8416 Revert "Merge branch 'main' into ap/combine_generage" (rasbt, Aug 16, 2024)
8ccd293 Added extra comment. (apaz-cli, Aug 16, 2024)
b5409e1 Merge branch 'ap/combine_generage' of https://github.com/lightning-ai… (apaz-cli, Aug 16, 2024)
7e9ee03 Merge branch 'main' into ap/combine_generage (rasbt, Aug 16, 2024)
40f13c1 Merge branch 'main' into ap/combine_generage (rasbt, Aug 19, 2024)
81e1b3b Merge branch 'main' of https://github.com/lightning-ai/litgpt into ap… (apaz-cli, Aug 20, 2024)
860f6cf Fix incorrect output. (apaz-cli, Aug 20, 2024)
60c9619 Update litgpt/generate/base.py (apaz-cli, Aug 21, 2024)
0648421 Merge branch 'main' into ap/combine_generage (rasbt, Aug 22, 2024)
5c8223a Update tests/test_api.py (rasbt, Aug 22, 2024)
461be3f Update README.md (apaz-cli, Aug 22, 2024)
b7f7a63 Update README.md (apaz-cli, Aug 22, 2024)
f46c8dc Update litgpt/api.py (apaz-cli, Aug 22, 2024)
264be3b Update litgpt/api.py (apaz-cli, Aug 22, 2024)
e935d13 Update litgpt/api.py (apaz-cli, Aug 22, 2024)
4358fe4 Cleaned up. (apaz-cli, Aug 22, 2024)
fc87603 Remove accidentally pushed file. (apaz-cli, Aug 22, 2024)
43f73a4 Remove accidentally pushed file. (apaz-cli, Aug 22, 2024)
362db0b Cleanup. (apaz-cli, Aug 22, 2024)
b9c633e Fixed input_pos dtype not to depend on torch default dtype. (apaz-cli, Aug 22, 2024)
f7f4848 Update litgpt/generate/base.py (rasbt, Aug 22, 2024)
82 changes: 23 additions & 59 deletions litgpt/chat/base.py
@@ -61,63 +61,18 @@ def generate(
or https://huyenchip.com/2024/01/16/sampling.html#top_p
stop_tokens: If specified, stop generating any more tokens once one of the sequences in this list is generated.
"""
T = prompt.size(0)
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

device = prompt.device
buffer_length = max((len(tokens) for tokens in stop_tokens), default=1)
yield_i = 0
input_pos = torch.arange(0, T, device=device)
tokens = []
token = prompt
for t in range(1, max_returned_tokens - T + 1):
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
tokens.append(token)
# check the stop condition
if any((l := len(st)) <= len(tokens) and all(a == b for a, b in zip(tokens[-l:], st)) for st in stop_tokens):
return
# if the buffer is full
if t - yield_i >= buffer_length:
# we know this idx is not part of stop tokens, safe to yield
yield from tokens[yield_i:t]
yield_i = t
input_pos = input_pos[-1:].add_(1)


def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.Tensor]) -> int:
tokens_generated = 0
if tokenizer.backend == "huggingface":
try:
for token in token_stream:
fabric.print(tokenizer.decode(token), end="", flush=True)
tokens_generated += 1
except KeyboardInterrupt:
# support stopping generation
return tokens_generated
elif tokenizer.backend == "sentencepiece":
# sentencepiece does not support decoding token-by-token because it adds spaces based on the surrounding tokens
# meaning that we need to decode everything each time
so_far = torch.tensor([], dtype=torch.long, device=fabric.device)
decoded_so_far = ""
try:
for token in token_stream:
so_far = so_far.to(device=token.device)
so_far = torch.cat((so_far, token.view(-1)))
decoded_new = tokenizer.decode(so_far)
fabric.print(decoded_new[len(decoded_so_far) :], end="", flush=True)
decoded_so_far = decoded_new
tokens_generated += 1
except KeyboardInterrupt:
# support stopping generation
return tokens_generated
else:
raise NotImplementedError(tokenizer.backend)
return tokens_generated
from litgpt.generate.base import generate_fn
return generate_fn(
include_prompt=False,
include_eos=False,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens
)


def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
@@ -133,13 +88,22 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
model.max_seq_length = max_returned_tokens
model.set_kv_cache(batch_size=1, device=fabric.device)

y = generate(
y: Iterator[torch.Tensor] = generate(
model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
)
token_generator: Iterator[str] = tokenizer.decode_stream(y)

fabric.print(">> Reply: ", end="")

t0 = time.perf_counter()
tokens_generated = decode(fabric, tokenizer, y)

tokens_generated = 0
for tok in token_generator:
tokens_generated += 1
fabric.print(tok, end="", flush=True)

t = time.perf_counter() - t0

for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
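The chat path now streams token tensors from `generate` and text fragments from `tokenizer.decode_stream`, as `process_prompt` above shows. Below is a minimal sketch of that flow, assuming a loaded `GPT` model whose KV cache has already been set (as `process_prompt` does with `model.set_kv_cache`) and a litgpt `Tokenizer`; the sampling values are illustrative, not taken from the PR.

import time

import torch

from litgpt.chat.base import generate  # the slimmed-down wrapper shown above
from litgpt.tokenizer import Tokenizer


def stream_reply(model, tokenizer: Tokenizer, encoded_prompt: torch.Tensor,
                 max_returned_tokens: int, stop_tokens) -> int:
    # generate() yields token tensors; decode_stream() turns them into text fragments.
    token_stream = generate(
        model, encoded_prompt, max_returned_tokens,
        temperature=0.8, top_k=50, top_p=1.0, stop_tokens=stop_tokens,
    )
    t0 = time.perf_counter()
    tokens_generated = 0
    for fragment in tokenizer.decode_stream(token_stream):
        print(fragment, end="", flush=True)
        tokens_generated += 1
    print(f"\n[{tokens_generated} tokens in {time.perf_counter() - t0:.2f} s]")
    return tokens_generated

Interrupting with Ctrl-C stops the stream cleanly because `decode_stream` catches `KeyboardInterrupt` and simply returns.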
121 changes: 94 additions & 27 deletions litgpt/generate/base.py
@@ -4,7 +4,7 @@
import time
from pathlib import Path
from pprint import pprint
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Tuple, List, Union, Iterator
import warnings

import lightning as L
@@ -79,6 +79,86 @@ def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: A
return next.to(dtype=x.dtype)


@torch.inference_mode()
def generate_fn(
model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool,
include_eos: bool,
) -> Iterator[torch.Tensor]:
prompt_size = prompt.size(0)
device = prompt.device

assert max_returned_tokens > prompt_size, f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}."
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

# Yield the prompt if include_prompt is True
if include_prompt:
yield prompt

stop_progress = [0] * len(stop_tokens)
yielded_idx = 0

# Generate output tokens.
# The first token generated is the prefill token.
# The input_pos for this token is the width of the entire prompt.
# For subsequent iterations, it's the index in the context for the token that we're generating.
tokens = []
token = prompt
prefill_token = True
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
for current_idx in range(max_returned_tokens - prompt_size):

# Generate the token
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
tokens.append(token)
int_token = token.item()

# Check for stop sequences
# For each stop sequence, we keep a running total of how many are matched in stop_progress.
# If the current token matches the next token in the stop sequence, we increment the
# running total and hold off on yielding the token.
for i, seq in enumerate(stop_tokens):
if int_token == seq[stop_progress[i]]:
stop_progress[i] += 1
if stop_progress[i] == len(seq):
if include_eos:
yield from tokens[yielded_idx:]
return
else:
stop_progress[i] = 0

# Yield tokens that are not part of a stop sequence in progress.
# If there are no stop sequences, then that's all of them.
if stop_tokens:
safe_idx = len(tokens) - max(stop_progress)
else:
safe_idx = current_idx + 1 # include the token just generated

if yielded_idx < safe_idx:
y_tokens = tokens[yielded_idx : safe_idx]
yield from y_tokens
yielded_idx = safe_idx

# Update input_pos for the next iteration.
if prefill_token:
prefill_token = False
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
else:
input_pos.add_(1)

# Yield any remaining tokens
if yielded_idx < len(tokens):
yield from tokens[yielded_idx:]


@torch.inference_mode()
def generate(
model: GPT,
@@ -118,33 +198,20 @@ def generate(
eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.
include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
"""
T = prompt.size(0)
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

device = prompt.device
if include_prompt:
tokens = [prompt]
else:
tokens = []
input_pos = torch.tensor([T], device=device)
token = next_token(
model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
for _ in range(2, max_returned_tokens - T + 1):
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
if token == eos_id:
break
input_pos = input_pos.add_(1)
return torch.cat(tokens)
token_list = list(generate_fn(
include_prompt=include_prompt,
include_eos=True,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=(([eos_id],) if eos_id is not None else ())
))

return torch.cat(token_list) if token_list else torch.Tensor()


@torch.inference_mode()
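The core of the new `generate_fn` is its stop-sequence handling: tokens that could still be the prefix of a stop sequence are buffered, everything else is yielded immediately, and a full match ends generation. The following self-contained sketch shows that logic with plain integers instead of tensors; it is illustrative only (the real function also drives `next_token`, `input_pos`, and the KV cache) and mirrors the `include_eos=False` behaviour.

from typing import Iterator, List, Tuple


def stream_with_stop(token_ids: List[int],
                     stop_tokens: Tuple[List[int], ...]) -> Iterator[int]:
    stop_progress = [0] * len(stop_tokens)  # how much of each stop sequence is matched
    buffered: List[int] = []
    yielded_idx = 0
    for tok in token_ids:
        buffered.append(tok)
        for i, seq in enumerate(stop_tokens):
            if tok == seq[stop_progress[i]]:
                stop_progress[i] += 1
                if stop_progress[i] == len(seq):
                    return  # full stop sequence matched: end without yielding it
            else:
                stop_progress[i] = 0
        # Only tokens that cannot belong to an in-progress match are safe to emit.
        safe_idx = len(buffered) - max(stop_progress) if stop_tokens else len(buffered)
        if yielded_idx < safe_idx:
            yield from buffered[yielded_idx:safe_idx]
            yielded_idx = safe_idx
    yield from buffered[yielded_idx:]  # no stop sequence was hit: flush the tail

For example, list(stream_with_stop([5, 7, 1, 2, 3], ([2, 3],))) yields [5, 7, 1]: the 2 is held back as a possible stop prefix and is never emitted once the 3 completes the match.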
29 changes: 28 additions & 1 deletion litgpt/tokenizer.py
@@ -2,7 +2,7 @@

import json
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, Iterable, Iterator

import torch

@@ -136,3 +136,30 @@ def decode(self, tensor: torch.Tensor) -> str:
dummy_token = self.processor.decode([dummy_token_id])
return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :]
return self.processor.decode(tokens)

def decode_stream(self, token_stream: Iterable[torch.Tensor]) -> Iterator[str]:
if self.backend == "huggingface":
try:
for token in token_stream:
yield self.decode(token)
except KeyboardInterrupt:
return
elif self.backend == "sentencepiece":
# TODO: Is there a way to not have to do this?
# This may actually affect our tokens per second.

# sentencepiece does not support decoding token-by-token because it adds spaces based on the surrounding tokens
Review comment (Collaborator): @Andrei-Aksionov reimplemented the tokenizer pipeline and may have ideas here

Review comment (Collaborator): Interesting. I think I didn't test the hack fix in decode method (with a dummy_token_id) for SentencePiece tokenizer. So maybe now the logic below is not needed.
# meaning that we need to decode everything each time
so_far = torch.tensor([], dtype=torch.long)
decoded_so_far = ""
try:
for token in token_stream:
so_far = so_far.to(device=token.device)
so_far = torch.cat((so_far, token.view(-1)))
decoded_new = self.decode(so_far)
yield decoded_new[len(decoded_so_far) :]
decoded_so_far = decoded_new
except KeyboardInterrupt:
return
else:
raise NotImplementedError(self.backend)
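For the sentencepiece backend, `decode_stream` re-decodes the whole accumulated sequence on every step and yields only the new suffix, since the detokenizer places spaces based on surrounding tokens. Here is the same pattern sketched with a stand-in `decode` callable (hypothetical, just to show the string-diffing idea; any detokenizer that needs the full sequence would work the same way).

from typing import Callable, Iterable, Iterator, List


def incremental_decode(token_stream: Iterable[int],
                       decode: Callable[[List[int]], str]) -> Iterator[str]:
    so_far: List[int] = []
    decoded_so_far = ""
    for token in token_stream:
        so_far.append(token)
        decoded_new = decode(so_far)  # re-decode the whole prefix every step
        yield decoded_new[len(decoded_so_far):]  # emit only the newly added text
        decoded_so_far = decoded_new

As a toy check, "".join(incremental_decode([1, 2, 3], lambda ids: " ".join(map(str, ids)))) produces "1 2 3", the same as decoding the full list at once.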
8 changes: 6 additions & 2 deletions tests/test_api.py
@@ -16,9 +16,11 @@
calculate_number_of_devices,
benchmark_dict_to_markdown_table
)

from litgpt.scripts.download import download_from_hub



@pytest.fixture
def mock_llm():
llm = MagicMock(spec=LLM)
@@ -117,7 +119,8 @@ def test_llm_load_hub_init(tmp_path):
assert len(text_1) > 0

text_2 = llm.generate("text", max_new_tokens=10, top_k=1, stream=True)
assert text_1 == "".join(list(text_2))
text_2 = "".join(list(text_2))
assert text_1 == text_2, (text_1, text_2)


def test_model_not_initialized(tmp_path):
@@ -236,7 +239,8 @@ def test_quantization_is_applied(tmp_path):
model="EleutherAI/pythia-14m",
)
llm.distribute(devices=1, quantize="bnb.nf4", precision="bf16-true")
assert "NF4Linear" in str(type(llm.model.lm_head))
strtype = str(type(llm.model.lm_head))
assert "NF4Linear" in strtype, strtype


@RunIf(min_cuda_gpus=1)
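The updated test checks that streaming and non-streaming generation agree under greedy sampling (top_k=1). A minimal usage sketch of the streaming API that the test exercises, assuming the checkpoint has been downloaded; the prompt and token budget are illustrative.

from litgpt import LLM

llm = LLM.load("EleutherAI/pythia-14m")  # small checkpoint also used by the tests
text_full = llm.generate("What do llamas eat?", max_new_tokens=20, top_k=1)
stream = llm.generate("What do llamas eat?", max_new_tokens=20, top_k=1, stream=True)
text_streamed = "".join(stream)  # stream=True yields string fragments
assert text_full == text_streamed  # greedy decoding keeps both runs identical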