diff --git a/litgpt/api.py b/litgpt/api.py
index 9969ad5927..a66da5a618 100644
--- a/litgpt/api.py
+++ b/litgpt/api.py
@@ -163,8 +163,9 @@ def load(
         with fabric.init_module(empty_init=(num_devices > 1)):
             model = GPT(config)
 
-        with fabric.init_tensor():
-            model.set_kv_cache(batch_size=1)
+        # This should be set if we add a compile feature later
+        # with fabric.init_tensor():
+        #     model.set_kv_cache(batch_size=1)
 
         model.eval()
         model = fabric.setup_module(model)
@@ -178,6 +179,7 @@ def load(
             prompt_style=prompt_style, checkpoint_dir=checkpoint_dir, fabric=fabric,
         )
 
+    @torch.inference_mode()
     def generate(
         self,
         prompt: str,
@@ -221,6 +223,11 @@ def generate(
         prompt_length = input_ids.size(0)
         max_returned_tokens = prompt_length + max_new_tokens
 
+        first_turn = self.model.mask_cache is None
+        if first_turn or max_returned_tokens > self.model.max_seq_length:
+            self.model.max_seq_length = max_returned_tokens
+            self.model.set_kv_cache(batch_size=1, device=self.fabric.device)
+
         self.model.eval()
 
         if calculate_number_of_devices(self.devices) > 1:
diff --git a/tests/test_api.py b/tests/test_api.py
index c155858cc3..d15fe958e8 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -95,8 +95,24 @@ def test_llm_load_random_init(tmp_path):
         init="random",
         tokenizer_dir=Path(tmp_path/"EleutherAI/pythia-14m")
     )
-    text = llm.generate("text", max_new_tokens=10)
-    assert len(text.split(" ")) > 5
+
+    input_text = "some text text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15
+
+    # The following tests that generate works with different prompt lengths
+    # after the kv cache was set
+
+    input_text = "some text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15
+
+    input_text = "some text text text"
+    output_text = llm.generate(input_text, max_new_tokens=15)
+    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))
+    assert ln <= 15
 
 
 def test_llm_load_hub_init(tmp_path):
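
A minimal usage sketch of the behavior this diff introduces, mirroring the test above: the KV cache is no longer allocated eagerly in load(); generate() builds it on the first call (when mask_cache is None) and rebuilds it only when a longer sequence is requested. The model name and prompts here are illustrative, not part of the diff.

from litgpt.api import LLM

llm = LLM.load("EleutherAI/pythia-14m")  # illustrative checkpoint, as used in the test

# First call: mask_cache is None, so generate() sets max_seq_length and creates the KV cache.
out_1 = llm.generate("some text", max_new_tokens=15)

# Second call with a longer prompt: max_returned_tokens exceeds max_seq_length,
# so the cache is re-created at the larger size; otherwise it is reused as-is.
out_2 = llm.generate("some text text text", max_new_tokens=15)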