Combine generate() functions #1675

Merged: 40 commits merged into main from ap/combine_generage on Aug 22, 2024

Changes shown from 26 commits

Commits (40)

e9c4616
WIP
apaz-cli Aug 14, 2024
c103b16
Update tests/test_chat.py
apaz-cli Aug 14, 2024
ad2cd48
Update tests/test_chat.py
apaz-cli Aug 14, 2024
906113c
Progress.
apaz-cli Aug 15, 2024
328d81c
Cleanup.
apaz-cli Aug 15, 2024
9b33ad8
Fixed tests.
apaz-cli Aug 15, 2024
0752df9
Test Cleanup
apaz-cli Aug 15, 2024
63a3a47
More Test Cleanup
apaz-cli Aug 15, 2024
2b80cd3
Stub out test for removed function.
apaz-cli Aug 15, 2024
6bd381a
Remove extra import.
apaz-cli Aug 15, 2024
7a62774
Merge branch 'main' into ap/combine_generage
rasbt Aug 15, 2024
0a016b1
Update comments and fix tests
apaz-cli Aug 16, 2024
28454b3
Merge branch 'ap/combine_generage' of https://github.com/lightning-ai…
apaz-cli Aug 16, 2024
a4fc1c8
Merge branch 'main' into ap/combine_generage
rasbt Aug 16, 2024
4f3048e
Cleaned up tests.
apaz-cli Aug 16, 2024
f5c0094
Merge branch 'ap/combine_generage' of https://github.com/lightning-ai…
apaz-cli Aug 16, 2024
ffc8228
fix kv cache bug
rasbt Aug 16, 2024
35115a5
Wrote test_decode, fixed subtle type hints.
apaz-cli Aug 16, 2024
2d09f22
Merge branch 'ap/combine_generage' of https://github.com/lightning-ai…
apaz-cli Aug 16, 2024
5ad8416
Revert "Merge branch 'main' into ap/combine_generage"
rasbt Aug 16, 2024
8ccd293
Added extra comment.
apaz-cli Aug 16, 2024
b5409e1
Merge branch 'ap/combine_generage' of https://github.com/lightning-ai…
apaz-cli Aug 16, 2024
7e9ee03
Merge branch 'main' into ap/combine_generage
rasbt Aug 16, 2024
40f13c1
Merge branch 'main' into ap/combine_generage
rasbt Aug 19, 2024
81e1b3b
Merge branch 'main' of https://github.com/lightning-ai/litgpt into ap…
apaz-cli Aug 20, 2024
860f6cf
Fix incorrect output.
apaz-cli Aug 20, 2024
60c9619
Update litgpt/generate/base.py
apaz-cli Aug 21, 2024
0648421
Merge branch 'main' into ap/combine_generage
rasbt Aug 22, 2024
5c8223a
Update tests/test_api.py
rasbt Aug 22, 2024
461be3f
Update README.md
apaz-cli Aug 22, 2024
b7f7a63
Update README.md
apaz-cli Aug 22, 2024
f46c8dc
Update litgpt/api.py
apaz-cli Aug 22, 2024
264be3b
Update litgpt/api.py
apaz-cli Aug 22, 2024
e935d13
Update litgpt/api.py
apaz-cli Aug 22, 2024
4358fe4
Cleaned up.
apaz-cli Aug 22, 2024
fc87603
Remove accidentally pushed file.
apaz-cli Aug 22, 2024
43f73a4
Remove accidentally pushed file.
apaz-cli Aug 22, 2024
362db0b
Cleanup.
apaz-cli Aug 22, 2024
b9c633e
Fixed input_pos dtype not to depend on torch default dtype.
apaz-cli Aug 22, 2024
f7f4848
Update litgpt/generate/base.py
rasbt Aug 22, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -19,14 +19,14 @@
![cpu-tests](https://github.com/lightning-AI/lit-stablelm/actions/workflows/cpu-tests.yml/badge.svg) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lit-stablelm/blob/master/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439)](https://discord.gg/VptPCZkGNa)

<p align="center">
<a href="https://lightning.ai/">Lightning AI</a> •
<a href="#quick-start">Quick start</a> •
<a href="#choose-from-20-llms">Models</a> •
<a href="#finetune-an-llm">Finetune</a> •
<a href="#deploy-an-llm">Deploy</a> •
<a href="#all-workflows">All workflows</a> •
<a href="#state-of-the-art-features">Features</a> •
<a href="#training-recipes">Recipes (YAML)</a> •
<a href="https://lightning.ai/">Lightning AI</a> •
<a href="#tutorials">Tutorials</a>
</p>

47 changes: 18 additions & 29 deletions litgpt/api.py
@@ -252,7 +252,7 @@ def load(
def distribute(
self,
accelerator: Literal["cpu", "cuda", "auto"] = "auto",
devices: Union[int, Literal["auto"]] = "auto",
devices: Union[int, List[int]] = 1,
precision: Optional[Any] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
@@ -262,7 +262,7 @@ def distribute(
Moves the model onto specified devices for single-GPU or multi-GPU inference

accelerator: Which device type to load the model on ("cpu", "gpu", "mps", "cuda", or "auto")
devices: The number of devices (1, 2, etc.) or "auto", which uses all available devices
devices: The number of devices (1, 2, etc.) or device IDs (e.g., [0, 2] to use the first and third GPU).
quantize: Whether to quantize the model and using which method:
- bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
- bnb.int8: 8-bit quantization from bitsandbytes
@@ -274,7 +274,7 @@ def distribute(
models that wouldn't fit in a single card by partitioning the transformer blocks across
all devices and running them sequentially. Sequential generation may be slower but allows using larger models.
Note that sequential generation sets `fixed_kv_cache_size="max_model_supported"`. You can set it to a lower integer
value, `fixed_kv_cache_size=256` to reduce memory. The `fixed_kv_cache_size` value determines the maximum number
value, `fixed_kv_cache_size=256` to reduce memory memory. The `fixed_kv_cache_size` value determins the maximum number
of tokens that can be returned via `llm.generate(...)`.
fixed_kv_cache_size: If set to an integer value or "max_model_supported" is set, the kv-cache won't be resized dynamically
during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy="sequential`.
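
For context, a minimal usage sketch of the updated distribute() signature through litgpt's Python API; the model name and generation settings below are illustrative and not taken from this PR:

from litgpt.api import LLM

# Illustrative checkpoint; any model litgpt supports would work here.
llm = LLM.load("microsoft/phi-2")

# Sequential generation across two GPUs with a capped KV cache.
# `devices` may now also be a list of device IDs, e.g. devices=[0, 2].
llm.distribute(
    accelerator="cuda",
    devices=2,
    generate_strategy="sequential",
    fixed_kv_cache_size=256,  # also caps how many tokens llm.generate() can return
)

print(llm.generate("What do llamas eat?", max_new_tokens=32))
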
@@ -302,30 +302,12 @@ def distribute(
if generate_strategy in ("sequential", "tensor_parallel") and accelerator not in ("cuda", "gpu"):
raise NotImplementedError(f"generate_strategy='{generate_strategy}' is only supported for accelerator='cuda'|'gpu'.")

if devices == "auto":
if generate_strategy in ("sequential", "tensor_parallel"):
total_devices = CUDAAccelerator.auto_device_count()
else:
total_devices = 1
elif isinstance(devices, int):
use_devices = calculate_number_of_devices(devices)
total_devices = CUDAAccelerator.auto_device_count()
if use_devices > total_devices:
raise ValueError(
f"You selected more devices ({use_devices}) than available in your system ({total_devices})."
)
else:
total_devices = use_devices
num_devices = calculate_number_of_devices(devices)
Collaborator: Any reason why this was changed? Bad rebase maybe?

Contributor Author: Yeah. Not sure when it happened.

if total_devices > 1 and generate_strategy not in ("sequential", "tensor_parallel"):
raise NotImplementedError(
"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
)

else:
raise ValueError(f"devices argument must be an integer or 'auto', got {devices}")

print(f"Using {total_devices} device(s)", file=sys.stderr)
if generate_strategy is None and num_devices > 1:
raise NotImplementedError(
"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."
)

if precision is None:
precision = get_default_supported_precision(training=False)
@@ -349,7 +331,7 @@ def distribute(
)
else:
fabric = L.Fabric(
devices=total_devices,
devices=devices,
strategy="ddp",
precision=precision,
plugins=plugins
@@ -360,7 +342,7 @@ def distribute(

self.kv_cache_initialized = False
if generate_strategy is None:
with fabric.init_module(empty_init=(total_devices > 1)):
with fabric.init_module(empty_init=(num_devices > 1)):
model = GPT(self.config)
model.eval()

@@ -379,6 +361,14 @@ def distribute(
self.fixed_kv_cache_size = fixed_kv_cache_size

elif generate_strategy in ("sequential", "tensor_parallel"):
total_devices = CUDAAccelerator.auto_device_count()
if devices is not None:
if devices < total_devices:
total_devices = devices
elif devices > total_devices:
raise ValueError(f"This machine only has {total_devices} but you specified `devices={devices}`")

print(f"Using {total_devices} devices", file=sys.stderr)

with fabric.init_tensor(), torch.device("meta"):
model = GPT(self.config)
@@ -503,7 +493,6 @@ def generate(
tmp_device = self.model.mask_cache.device
self.model.clear_kv_cache()
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)

else:
for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()
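
The new devices argument accepts either a device count or an explicit list of device IDs. A small illustrative helper (an assumption, not the actual calculate_number_of_devices from litgpt/api.py, whose body is not shown in this diff) captures that interpretation:

from typing import List, Union

def count_requested_devices(devices: Union[int, List[int]]) -> int:
    # An int means "use this many devices"; a list means "use exactly these device IDs".
    return devices if isinstance(devices, int) else len(devices)

assert count_requested_devices(2) == 2
assert count_requested_devices([0, 2]) == 2
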
82 changes: 23 additions & 59 deletions litgpt/chat/base.py
@@ -61,63 +61,18 @@ def generate(
or https://huyenchip.com/2024/01/16/sampling.html#top_p
stop_tokens: If specified, stop generating any more token once one of this list is generated.
"""
T = prompt.size(0)
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

device = prompt.device
buffer_length = max((len(tokens) for tokens in stop_tokens), default=1)
yield_i = 0
input_pos = torch.arange(0, T, device=device)
tokens = []
token = prompt
for t in range(1, max_returned_tokens - T + 1):
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
tokens.append(token)
# check the stop condition
if any((l := len(st)) <= len(tokens) and all(a == b for a, b in zip(tokens[-l:], st)) for st in stop_tokens):
return
# if the buffer is full
if t - yield_i >= buffer_length:
# we know this idx is not part of stop tokens, safe to yield
yield from tokens[yield_i:t]
yield_i = t
input_pos = input_pos[-1:].add_(1)


def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.Tensor]) -> int:
tokens_generated = 0
if tokenizer.backend == "huggingface":
try:
for token in token_stream:
fabric.print(tokenizer.decode(token), end="", flush=True)
tokens_generated += 1
except KeyboardInterrupt:
# support stopping generation
return tokens_generated
elif tokenizer.backend == "sentencepiece":
# sentencepiece does not support decoding token-by-token because it adds spaces based on the surrounding tokens
# meaning that we need to decode everything each time
so_far = torch.tensor([], dtype=torch.long, device=fabric.device)
decoded_so_far = ""
try:
for token in token_stream:
so_far = so_far.to(device=token.device)
so_far = torch.cat((so_far, token.view(-1)))
decoded_new = tokenizer.decode(so_far)
fabric.print(decoded_new[len(decoded_so_far) :], end="", flush=True)
decoded_so_far = decoded_new
tokens_generated += 1
except KeyboardInterrupt:
# support stopping generation
return tokens_generated
else:
raise NotImplementedError(tokenizer.backend)
return tokens_generated
from litgpt.generate.base import generate_fn
return generate_fn(
include_prompt=False,
include_eos=False,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=stop_tokens
)


def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
@@ -133,13 +88,22 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
model.max_seq_length = max_returned_tokens
model.set_kv_cache(batch_size=1, device=fabric.device)

y = generate(
y: Iterator[torch.Tensor] = generate(
model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
)
token_generator: Iterator[str] = tokenizer.decode_stream(y)

fabric.print(">> Reply: ", end="")

t0 = time.perf_counter()
tokens_generated = decode(fabric, tokenizer, y)

tokens_generated = 0
for tok in token_generator:
tokens_generated += 1
fabric.print(tok, end="", flush=True)

t = time.perf_counter() - t0

for block in model.transformer.h:
block.attn.kv_cache.reset_parameters()
fabric.print(
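
The chat loop above now streams text through tokenizer.decode_stream(y) instead of the removed decode() helper. For reference, the pattern the old sentencepiece branch implemented, re-decoding the growing prefix and yielding only the new suffix, looks roughly like this standalone sketch (the decode callable stands in for tokenizer.decode and is an assumption):

from typing import Callable, Iterator, List

def stream_decode(token_stream: Iterator[int], decode: Callable[[List[int]], str]) -> Iterator[str]:
    # Re-decode the full prefix at every step so tokenizers that add spaces based on
    # surrounding tokens (e.g. sentencepiece) still produce correct incremental text.
    so_far: List[int] = []
    decoded_so_far = ""
    for token in token_stream:
        so_far.append(token)
        decoded_new = decode(so_far)
        yield decoded_new[len(decoded_so_far):]
        decoded_so_far = decoded_new
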
122 changes: 95 additions & 27 deletions litgpt/generate/base.py
@@ -4,7 +4,7 @@
import time
from pathlib import Path
from pprint import pprint
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Tuple, List, Union, Iterator
import warnings

import lightning as L
@@ -79,6 +79,87 @@ def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: A
return next.to(dtype=x.dtype)


@torch.inference_mode()
def generate_fn(
model: GPT,
prompt: torch.Tensor,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
stop_tokens: Tuple[List[int], ...] = (),
include_prompt: bool,
include_eos: bool,
) -> Iterator[torch.Tensor]:
prompt_size = prompt.size(0)
device = prompt.device

assert max_returned_tokens > prompt_size, f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}."
assert max_returned_tokens > prompt_size, f"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}."
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

# Yield the prompt if include_prompt is True
if include_prompt:
yield prompt

stop_progress = [0] * len(stop_tokens)
yielded_idx = 0

# Generate output tokens.
# The first token generated is the prefill token.
# The input_pos for this token is the width of the entire prompt.
# For subsequent iterations, it's the index in the context for the token that we're generating.
tokens = []
token = prompt
prefill_token = True
input_pos = torch.arange(0, prompt_size, device=device)
for current_idx in range(max_returned_tokens - prompt_size):

# Generate the token
token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p)
tokens.append(token)
int_token = token.item()

# Check for stop sequences
# For each stop sequence, we keep a running total of how many are matched in stop_progress.
# If the current token matches the next token in the stop sequence, we increment the
# running total and hold off on yielding the token.
for i, seq in enumerate(stop_tokens):
if int_token == seq[stop_progress[i]]:
stop_progress[i] += 1
if stop_progress[i] == len(seq):
if include_eos:
yield from tokens[yielded_idx:]
return
else:
stop_progress[i] = 0

# Yield tokens that are not part of a stop sequence in progress.
# If there are no stop sequences, then that's all of them.
if stop_tokens:
safe_idx = len(tokens) - max(stop_progress)
else:
safe_idx = current_idx + 1 # include the token just generated

if yielded_idx < safe_idx:
y_tokens = tokens[yielded_idx : safe_idx]
yield from y_tokens
yielded_idx = safe_idx

# Update input_pos for the next iteration.
if prefill_token:
prefill_token = False
input_pos = torch.tensor([prompt_size], device=device)
else:
input_pos = input_pos.add_(1)

# Yield any remaining tokens
if yielded_idx < len(tokens):
yield from tokens[yielded_idx:]
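
To make the buffering behaviour of generate_fn concrete, here is a small self-contained sketch of the same stop-sequence matching over plain integer tokens (no model involved; the helper name and the example sequence are illustrative):

from typing import Iterator, List, Tuple

def stream_until_stop(tokens: Iterator[int], stop_tokens: Tuple[List[int], ...]) -> Iterator[int]:
    # Yield tokens as they arrive, but never emit part of a completed stop sequence.
    generated: List[int] = []
    stop_progress = [0] * len(stop_tokens)
    yielded_idx = 0
    for tok in tokens:
        generated.append(tok)
        for i, seq in enumerate(stop_tokens):
            if tok == seq[stop_progress[i]]:
                stop_progress[i] += 1
                if stop_progress[i] == len(seq):
                    return  # full stop sequence matched; held-back tokens are dropped
            else:
                stop_progress[i] = 0
        # Only yield tokens that cannot still be part of a stop sequence in progress.
        safe_idx = len(generated) - max(stop_progress) if stop_tokens else len(generated)
        if yielded_idx < safe_idx:
            yield from generated[yielded_idx:safe_idx]
            yielded_idx = safe_idx

# e.g. list(stream_until_stop(iter([5, 1, 2, 9, 7, 8]), ([7, 8],))) yields [5, 1, 2, 9]
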


@torch.inference_mode()
def generate(
model: GPT,
@@ -118,33 +199,20 @@ def generate(
eos_id: If specified, stop generating any more token once the <eos> token is triggered.
include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
"""
T = prompt.size(0)
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

device = prompt.device
if include_prompt:
tokens = [prompt]
else:
tokens = []
input_pos = torch.tensor([T], device=device)
token = next_token(
model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
for _ in range(2, max_returned_tokens - T + 1):
token = next_token(
model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
).clone()
tokens.append(token)
if token == eos_id:
break
input_pos = input_pos.add_(1)
return torch.cat(tokens)
token_list = list(generate_fn(
include_prompt=include_prompt,
include_eos=True,
model=model,
prompt=prompt,
max_returned_tokens=max_returned_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_tokens=(([eos_id],) if eos_id is not None else ())
))

return torch.cat(token_list) if not len(token_list) == 0 else torch.Tensor()
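
For reference, a minimal sketch of driving the consolidated generate() from user code; model and tokenizer construction are elided, the helper name is illustrative, and the model's KV cache is assumed to be set up already (e.g. via model.set_kv_cache(batch_size=1)) as done elsewhere in litgpt:

import torch
from litgpt.generate.base import generate

@torch.inference_mode()
def answer(model, tokenizer, prompt_text: str, max_new_tokens: int = 50) -> str:
    # Assumes model.set_kv_cache(...) was called and model.max_seq_length is large enough.
    encoded = tokenizer.encode(prompt_text, device=next(model.parameters()).device)
    output = generate(
        model,
        encoded,
        max_returned_tokens=encoded.size(0) + max_new_tokens,
        temperature=0.8,
        eos_id=tokenizer.eos_id,
        include_prompt=False,
    )
    return tokenizer.decode(output)
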


@torch.inference_mode()