
Commit

Fixes an issue where the LitGPT Python API was consuming too much memory (#1590)
rasbt authored Jul 17, 2024
1 parent ef51b9a commit a9b758f
Showing 2 changed files with 27 additions and 4 deletions.
litgpt/api.py: 11 changes (9 additions, 2 deletions)
@@ -163,8 +163,9 @@ def load(
         with fabric.init_module(empty_init=(num_devices > 1)):
             model = GPT(config)

-        with fabric.init_tensor():
-            model.set_kv_cache(batch_size=1)
+        # This should be set if we add a compile feature later
+        # with fabric.init_tensor():
+        #     model.set_kv_cache(batch_size=1)

         model.eval()
         model = fabric.setup_module(model)
@@ -178,6 +179,7 @@ def load(
             prompt_style=prompt_style, checkpoint_dir=checkpoint_dir, fabric=fabric,
         )

+    @torch.inference_mode()
     def generate(
         self,
         prompt: str,
@@ -221,6 +223,11 @@ def generate(
         prompt_length = input_ids.size(0)
         max_returned_tokens = prompt_length + max_new_tokens

+        first_turn = self.model.mask_cache is None
+        if first_turn or max_returned_tokens > self.model.max_seq_length:
+            self.model.max_seq_length = max_returned_tokens
+            self.model.set_kv_cache(batch_size=1, device=self.fabric.device)
+
         self.model.eval()

         if calculate_number_of_devices(self.devices) > 1:
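Taken together, the api.py changes stop allocating the KV cache eagerly in load() and let generate() create or resize it on demand, sized to prompt_length + max_new_tokens rather than the model's full context window, which is the likely source of the memory overhead the commit title describes. Below is a minimal sketch of that lazy-allocation pattern; the ensure_kv_cache helper and the bare model argument are hypothetical stand-ins for illustration, not part of LitGPT's API, though the attribute names (mask_cache, max_seq_length, set_kv_cache) mirror the ones the diff touches.

import torch


def ensure_kv_cache(model, prompt_length: int, max_new_tokens: int, device: torch.device) -> None:
    # Sketch only: allocate (or grow) the KV cache right before generation,
    # sized to the tokens this call can actually return rather than the
    # model's full context window.
    max_returned_tokens = prompt_length + max_new_tokens
    first_turn = model.mask_cache is None  # no cache allocated yet
    if first_turn or max_returned_tokens > model.max_seq_length:
        model.max_seq_length = max_returned_tokens
        model.set_kv_cache(batch_size=1, device=device)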
tests/test_api.py: 20 changes (18 additions, 2 deletions)
@@ -95,8 +95,24 @@ def test_llm_load_random_init(tmp_path):
         init="random",
         tokenizer_dir=Path(tmp_path/"EleutherAI/pythia-14m")
     )
-    text = llm.generate("text", max_new_tokens=10)
-    assert len(text.split(" ")) > 5
+
+    input_text = "some text text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15
+
+    # The following tests that generate works with different prompt lengths
+    # after the KV cache has been set
+
+    input_text = "some text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15
+
+    input_text = "some text text text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15


 def test_llm_load_hub_init(tmp_path):
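For context, here is a hedged usage sketch of the Python API that this test exercises. The checkpoint name and prompt are illustrative, and the exact download behavior of LLM.load depends on the installed LitGPT version; the point is that, after this commit, KV-cache memory is committed per generate() call rather than at load time.

from litgpt import LLM

# Loading no longer pre-allocates a KV cache for the model's full context window.
llm = LLM.load("EleutherAI/pythia-14m")

# generate() sets (or resizes) the KV cache based on prompt length + max_new_tokens,
# so short generations stay cheap.
text = llm.generate("What do llamas eat?", max_new_tokens=15)
print(text)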
