
Commit

Merge branch 'main' into ap/combine_generage
rasbt authored Aug 15, 2024
2 parents 6bd381a + 8b6789e commit 7a62774
Showing 6 changed files with 657 additions and 66 deletions.
117 changes: 97 additions & 20 deletions litgpt/api.py
@@ -4,34 +4,41 @@
from pathlib import Path
import sys
import time
from typing import Any, List, Literal, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union, Tuple

from tqdm import tqdm
import torch
import lightning as L
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.accelerators import CUDAAccelerator


from litgpt.model import GPT
from litgpt.config import name_to_config, Config
from litgpt.tokenizer import Tokenizer
from litgpt.generate.sequentially import sequential
from litgpt.generate.tp import tensor_parallel
from litgpt.generate.base import generate as generate_fn
from litgpt.chat.base import generate as stream_generate_fn
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.prompts import (
load_prompt_style,
has_prompt_style,
save_prompt_style,
PromptStyle
)
from litgpt.utils import (
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_nvlink_connectivity,
chunked_cross_entropy,
copy_config_files,
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
save_config,
)


class LLM:
class LLM(torch.nn.Module):
def __init__(
self,
model: GPT,
@@ -45,7 +52,8 @@ def __init__(
kv_cache_initialized: bool = False,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None
) -> None:
self._model = model
super().__init__()
self.model = model
self.preprocessor = preprocessor
self.devices = devices
self.prompt_style = prompt_style
@@ -67,16 +75,77 @@ def __init__(
text = llm.generate("What do Llamas eat?", top_k=1)
print(text)
"""

@property
def model(self):
if self._model is None:
raise AttributeError("The model is not initialized yet; use the .distribute() method to initialize the model.")
return self._model
def tokenizer(self):
return self.preprocessor.tokenizer

@model.setter
def model(self, content):
self._model = content
def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)

def forward(
self,
input_ids: torch.Tensor,
target_ids: Optional[torch.Tensor] = None,
loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
logits = self.model(input_ids)
if target_ids is not None:
if loss_fn is None:
loss_fn = chunked_cross_entropy
loss = loss_fn(logits[..., :-1, :], target_ids[..., 1:])
return logits, loss
else:
return logits

def trainer_setup(self, trainer_ckpt: Optional[Path] = None) -> None:
"""Initializes the model checkpoint for PyTorch Lightning Trainer contexts"""
self.model = GPT(self.config)

if trainer_ckpt is not None:
# strip the leading object-name prefix from the state_dict keys
state_dict = torch.load(trainer_ckpt, weights_only=True)["state_dict"]
first_key = next(iter(state_dict))
prefix = first_key.split(".")[0] + "."
keys_to_modify = [key for key in state_dict if key.startswith(prefix)]
for key in keys_to_modify:
new_key = key.replace(prefix, "", 1)
state_dict[new_key] = state_dict.pop(key)

self.load_state_dict(state_dict, strict=True)

elif self.checkpoint_dir is not None:
state_dict = torch.load(self.checkpoint_dir / "lit_model.pth", weights_only=False)
self.load_state_dict(state_dict, strict=False)

else:
raise ValueError(
"No checkpoint found. Either provide a valid path via `trainer_ckpt` "
"or ensure that `self.checkpoint_dir` points to a folder containing a `lit_model.pth` weight file."
)

def save(self, out_dir: Optional[Path] = None, prompt_style: Optional[PromptStyle] = None) -> None:
out_dir = Path(out_dir)
save_path = out_dir / "lit_model.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)

if prompt_style is None:
prompt_style = PromptStyle.from_config(self.config)
if self.fabric is None:
torch.save(self.state_dict(), save_path)
else:
self.fabric.save(save_path, self.state_dict())

if self.fabric is None or self.fabric.global_rank == 0:
# If initializing a model with random weights, the checkpoint dir can be None
if self.checkpoint_dir is not None:
copy_config_files(Path(self.checkpoint_dir), save_path.parent)
else:
save_config(self.config, out_dir)

save_prompt_style(prompt_style, save_path.parent)

@classmethod
def load(
@@ -107,7 +176,7 @@ def load(
allowed_init = {"pretrained", "random"}

if init == "pretrained":
checkpoint_dir = auto_download_checkpoint(model_name=model, access_token=access_token)
checkpoint_dir = auto_download_checkpoint(model_name=model, access_token=access_token, ignore_tokenizer_files=tokenizer_dir is not None)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

elif init == "random":
@@ -118,7 +187,7 @@ def load(
print(f"Model name {model} is not supported.\n")
available_models = "\n".join(sorted(name_to_config))
print(f"Available values:\n{available_models}")
quit()
return

else:
raise ValueError(f"Invalid init option: {init}. Must be one of {allowed_init}")
@@ -187,7 +256,7 @@ def distribute(
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
generate_strategy: Optional[Literal["sequential", "tensor_parallel"]] = None,
fixed_kv_cache_size: Union[int, Literal["max_model_supported"], None] = None
):
) -> None:
"""
Moves the model onto specified devices for single-GPU or multi-GPU inference
@@ -401,19 +470,27 @@ def generate(
At the moment, this setting is slower and may use more memory than the non-streaming version.
We plan to resolve this in the future.
"""
assert self.model is not None
if self.model is None:
raise AttributeError(
"The model is not initialized yet; use the .distribute() "
"or .trainer_setup() method to initialize the model."
)
input_ids = self._text_to_token_ids(prompt)
prompt_length = input_ids.size(0)
max_returned_tokens = prompt_length + max_new_tokens

if not self.kv_cache_initialized:
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=self.fabric.device)
if self.fabric is not None:
device = self.fabric.device
else:
device = self.preprocessor.device
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)
self.kv_cache_initialized = True

# Dynamically grow the kv cache size if necessary
if self.fixed_kv_cache_size is None and self.prev_generated_seq_length < max_returned_tokens:
self.model.clear_kv_cache()
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=self.fabric.device)
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)

else:
for block in self.model.transformer.h:
@@ -493,7 +570,7 @@ def benchmark(self, **kwargs):
benchmark_dict["Seconds to first token"] = time_to_first_token
benchmark_dict["Tokens generated"] = self.preprocessor.encode(outputs).size(0) - self._text_to_token_ids(kwargs.get("prompt")).size(0)
benchmark_dict["Inference speed in tokens/sec"] = benchmark_dict["Tokens generated"] / benchmark_dict["Seconds total"]
if self.fabric.device.type == "cuda":
if self.fabric is not None and self.fabric.device.type == "cuda":
benchmark_dict["Total GPU memory allocated in GB"] = torch.cuda.max_memory_allocated() / 1e9

return outputs, benchmark_dict
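
With this change, LLM subclasses torch.nn.Module and gains forward(), state_dict()/load_state_dict(), trainer_setup(), and save(), so the wrapper can be driven directly by a PyTorch Lightning Trainer. Below is a minimal sketch of that usage; the LitLLM wrapper class, the batch layout, and the learning rate are illustrative assumptions, not part of this commit.

import torch
import lightning as L
from litgpt.api import LLM

class LitLLM(L.LightningModule):
    # Hypothetical wrapper: trains the (now nn.Module) LLM with the Lightning Trainer.
    def __init__(self):
        super().__init__()
        # distribute=None skips .distribute(); weights are loaded via trainer_setup()
        self.llm = LLM.load(model="EleutherAI/pythia-14m", distribute=None)
        self.llm.trainer_setup()

    def training_step(self, batch, batch_idx):
        # forward(input_ids, target_ids) returns (logits, loss); the inputs double
        # as targets here, mirroring test_forward_method in tests/test_api.py
        _, loss = self.llm(batch["input_ids"], target_ids=batch["input_ids"])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.llm.parameters(), lr=3e-4)  # assumed learning rate
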
23 changes: 17 additions & 6 deletions litgpt/utils.py
@@ -80,14 +80,25 @@ def reset_parameters(module: nn.Module) -> None:
mod.reset_parameters()


def check_valid_checkpoint_dir(checkpoint_dir: Path, model_filename: str = "lit_model.pth", verbose: bool = True, raise_error: bool = False) -> None:
def check_valid_checkpoint_dir(
checkpoint_dir: Path,
model_filename: str = "lit_model.pth",
verbose: bool = True,
raise_error: bool = False,
ignore_tokenizer_files: bool = False
) -> None:

files = {
model_filename: (checkpoint_dir / model_filename).is_file(),
"model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
or (checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
}
if not ignore_tokenizer_files:
files.update({
"tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or
(checkpoint_dir / "tokenizer.model").is_file(),
"tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
})

if checkpoint_dir.is_dir():
if all(files.values()):
# we're good
@@ -574,12 +585,12 @@ def check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_71
return size


def auto_download_checkpoint(model_name, access_token=None):
def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False):
from litgpt.scripts.download import download_from_hub # moved here due to circular import issue

checkpoint_dir = extend_checkpoint_dir(Path(model_name))
try:
check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True)
check_valid_checkpoint_dir(checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files)
except FileNotFoundError as e:
if access_token is None:
access_token = os.getenv("HF_TOKEN")
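
The new ignore_tokenizer_files flag lets a checkpoint directory pass validation without tokenizer.json/tokenizer.model and tokenizer_config.json; api.py uses it when a separate tokenizer_dir is supplied to LLM.load(). A small sketch of calling the helpers directly, with the model name only as an example:

from pathlib import Path
from litgpt.utils import auto_download_checkpoint, check_valid_checkpoint_dir

# With ignore_tokenizer_files=True, only lit_model.pth and model_config.yaml
# are required; the tokenizer file checks are skipped entirely.
checkpoint_dir = auto_download_checkpoint("EleutherAI/pythia-14m", ignore_tokenizer_files=True)
check_valid_checkpoint_dir(Path(checkpoint_dir), ignore_tokenizer_files=True)
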
87 changes: 77 additions & 10 deletions tests/test_api.py
@@ -1,13 +1,18 @@
from pathlib import Path
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from collections import OrderedDict
import os
from pathlib import Path

import pytest
import re
import torch
from unittest.mock import MagicMock
from tests.conftest import RunIf

from litgpt.api import LLM, calculate_number_of_devices
from litgpt.scripts.download import download_from_hub
from tests.conftest import RunIf



@pytest.fixture
@@ -118,10 +123,11 @@ def test_model_not_initialized(tmp_path):
init="pretrained",
distribute=None
)
with pytest.raises(AttributeError, match=re.escape("The model is not initialized yet; use the .distribute() method to initialize the model.")):
llm.model

with pytest.raises(AttributeError, match=re.escape("The model is not initialized yet; use the .distribute() method to initialize the model.")):
s = (
"The model is not initialized yet; use the .distribute() "
"or .trainer_setup() method to initialize the model."
)
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("text")

llm = LLM.load(
@@ -130,10 +136,11 @@ def test_model_not_initialized(tmp_path):
init="random",
distribute=None
)
with pytest.raises(AttributeError, match=re.escape("The model is not initialized yet; use the .distribute() method to initialize the model.")):
llm.model

with pytest.raises(AttributeError, match=re.escape("The model is not initialized yet; use the .distribute() method to initialize the model.")):
s = (
"The model is not initialized yet; use the .distribute() "
"or .trainer_setup() method to initialize the model."
)
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("text")


@@ -180,6 +187,23 @@ def test_sequential_tp_cpu(tmp_path):
)


def test_initialization_for_trainer(tmp_path):
llm = LLM.load(
model="EleutherAI/pythia-14m",
distribute=None
)
s = (
"The model is not initialized yet; use the .distribute() "
"or .trainer_setup() method to initialize the model."
)
with pytest.raises(AttributeError, match=re.escape(s)):
llm.generate("hello world")

llm.trainer_setup()
llm.model.to(llm.preprocessor.device)
assert isinstance(llm.generate("hello world"), str)


@RunIf(min_cuda_gpus=1)
def test_quantization_is_applied(tmp_path):
llm = LLM.load(
@@ -219,3 +243,46 @@ def test_returned_benchmark_dir(tmp_path):

text, bench_d = llm.benchmark(prompt="hello world", stream=True)
assert isinstance(bench_d["Inference speed in tokens/sec"], float)


def test_state_dict(tmp_path):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
assert isinstance(llm.state_dict(), OrderedDict)
assert llm.state_dict()['lm_head.weight'].shape == torch.Size([50304, 128])


def test_save_method(tmp_path):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)

target_dir = "saved_model"
llm.save(target_dir)

expected_files = [
"config.json",
"generation_config.json",
"lit_model.pth",
"model_config.yaml",
"prompt_style.yaml",
"tokenizer_config.json",
"tokenizer.json"
]

files_in_directory = os.listdir(target_dir)
for file_name in expected_files:
assert file_name in files_in_directory, f"{file_name} is missing from {target_dir}"


def test_forward_method(tmp_path):
llm = LLM.load(
model="EleutherAI/pythia-14m",
)
inputs = torch.ones(6, 128, dtype=torch.int64).to(next(llm.model.parameters()).device)

assert llm(inputs).shape == torch.Size([6, 128, 50304])
logits, loss = llm(inputs, target_ids=inputs)
assert logits.shape == torch.Size([6, 128, 50304])
assert isinstance(loss.item(), float)
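
The new save() method writes lit_model.pth together with the config, tokenizer, and prompt-style files (as exercised by test_save_method above), so the output directory can be reloaded with LLM.load(). A brief sketch; the output path is an arbitrary example:

from litgpt.api import LLM

llm = LLM.load(model="EleutherAI/pythia-14m")
llm.save("out/my_copy")  # arbitrary example path
llm_reloaded = LLM.load(model="out/my_copy")
print(llm_reloaded.generate("What do Llamas eat?", max_new_tokens=30))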
(Diffs for the remaining 3 changed files were not loaded.)
