diff --git a/README.md b/README.md
index d64b60804f..6092db7ad8 100644
--- a/README.md
+++ b/README.md
@@ -128,6 +128,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
 | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
 | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
+| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
 | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
 | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
 | Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
diff --git a/litgpt/config.py b/litgpt/config.py
index de7fd5df2a..b218df849c 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -149,12 +149,21 @@ def mlp_class(self) -> Type:
     @property
     def norm_class(self) -> Type:
         # `self.norm_class_name` cannot be the type to keep the config serializable
-        if self.norm_class_name == "RMSNorm":
-            from functools import partial
+        from functools import partial
+
+        if self.norm_class_name == "RMSNorm":
+            from litgpt.model import RMSNorm
 
             return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
+
+        if self.norm_class_name == "LayerNorm" and "OLMo" in self.name:
+            # this makes it equivalent to `torch.nn.functional.layer_norm`
+            # that is used by OLMo
+            # Table 5 caption in the OLMo paper shows this - https://aclanthology.org/2024.acl-long.841
+            return partial(torch.nn.LayerNorm, elementwise_affine=False)
+
         return getattr(torch.nn, self.norm_class_name)
@@ -722,6 +731,64 @@ def norm_class(self) -> Type:
         rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
 )
+
+#################
+# Allen AI OLMo
+#################
+olmo = [
+    # https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json
+    dict(
+        name="OLMo-1B-hf",
+        hf_config=dict(org="allenai", name="OLMo-1B-hf"),
+        vocab_size=50280,
+        padded_vocab_size=50304,
+        block_size=2048,
+        n_embd=2048,
+        n_layer=16,
+        n_head=16,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="LayerNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=8192,
+    ),
+    # https://huggingface.co/allenai/OLMo-7B-hf/blob/main/config.json
+    dict(
+        name="OLMo-7B-hf",
+        hf_config=dict(org="allenai", name="OLMo-7B-hf"),
+        vocab_size=50280,
+        padded_vocab_size=50304,
+        block_size=2048,
+        n_layer=32,
+        n_head=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="LayerNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=11008,
+    ),
+    # https://huggingface.co/allenai/OLMo-7B-Instruct-hf/blob/main/config.json
+    dict(
+        name="OLMo-7B-Instruct-hf",
+        hf_config=dict(org="allenai", name="OLMo-7B-Instruct-hf"),
+        vocab_size=50280,
+        padded_vocab_size=50304,
+        block_size=2048,
+        n_layer=32,
+        n_head=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="LayerNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=11008,
+    ),
+]
+
+configs.extend(olmo)
+
 ###############
 # Google Gemma
 ###############
diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index c0b9c2c282..09fb86676c 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -276,6 +276,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
+class OLMo(PromptStyle):
+    def apply(self, prompt: str, **kwargs: str) -> str:
+        return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"
+
+
 # Maps prompt style names to PromptStyle classes
 prompt_styles: Dict[str, Type[PromptStyle]] = {
     # Dataset-specific prompt styles
@@ -298,6 +303,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
     "tinyllama": TinyLlama,
     "gemma": Gemma,
     "llama3": Llama3,
+    "olmo": OLMo,
 }
@@ -334,6 +340,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return TinyLlama()
     if re.search(r"(Code)?Gemma.*-it", model_name):
         return Gemma()
+    if re.search(r"OLMo.*-hf", model_name):
+        return OLMo()
     return Default()
diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py
index 0eab0705ba..a81c59aa2d 100644
--- a/litgpt/tokenizer.py
+++ b/litgpt/tokenizer.py
@@ -130,7 +130,10 @@ def encode(
         if eos and (not tokens or tokens[-1] != self.eos_id):
             tokens = tokens + [self.eos_id]
-
+        # if the processor misbehaves and adds the `eos` token no matter what
+        elif tokens and tokens[-1] == self.eos_id:
+            tokens = tokens[:-1]
+
         if max_length > 0:
             tokens = tokens[:max_length]
         return torch.tensor(tokens, dtype=torch.int, device=device)
diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py
index 4ee478d68d..f2e0b48459 100644
--- a/tests/test_convert_lit_checkpoint.py
+++ b/tests/test_convert_lit_checkpoint.py
@@ -14,6 +14,7 @@
 from transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
 
 from litgpt import GPT, Config
 from litgpt.scripts.convert_lit_checkpoint import (
@@ -192,6 +193,48 @@ def test_against_mixtral():
     theirs_y = theirs_model(x)["logits"]
     torch.testing.assert_close(ours_y, theirs_y)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
+def test_against_olmo(model_name):
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = OlmoConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        intermediate_size=ours_config.intermediate_size,
+        num_hidden_layers=ours_config.n_layer,
+        num_attention_heads=ours_config.n_head,
+        num_key_value_heads=ours_config.n_query_groups,
+        max_position_embeddings=T,
+        attention_bias=ours_config.bias,
+        rope_theta=ours_config.rope_base,
+        tie_word_embeddings=(model_name == "OLMo-1B-hf"),
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    ours_model = GPT(ours_config)
+    # tie weights
+    ours_model.lm_head.weight = ours_model.transformer.wte.weight
+    ours_state_dict = ours_model.state_dict()
+    theirs_state_dict = {}
+    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=(model_name == "OLMo-1B-hf"))
+    theirs_model = OlmoForCausalLM(theirs_config)
+    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
+    assert not keys.unexpected_keys
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"]
+    torch.testing.assert_close(ours_y, theirs_y)
 
 
 @torch.inference_mode()
 def test_against_original_open_llama_3b():
diff --git a/tests/test_model.py b/tests/test_model.py
index 4cff66446b..f2ec330f14 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -27,6 +27,7 @@
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
 
 import litgpt.config as config_module
 from litgpt.model import batched_index_copy_
@@ -551,6 +552,63 @@ def test_against_hf_mixtral():
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_olmo(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = OlmoConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        intermediate_size=ours_config.intermediate_size,
+        num_hidden_layers=ours_config.n_layer,
+        num_attention_heads=ours_config.n_head,
+        num_key_value_heads=ours_config.n_query_groups,
+        max_position_embeddings=T,
+        attention_bias=ours_config.bias,
+        rope_theta=ours_config.rope_base,
+        tie_word_embeddings=(model_name == "OLMo-1B-hf"),
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = OlmoForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
 
 
 @torch.inference_mode()
 @pytest.mark.parametrize(
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index 64cf7e45e6..9ab0041357 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -27,6 +27,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
 | Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
 | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
+| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
 | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
 | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
 | Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
@@ -54,6 +55,9 @@ litgpt download list
 
 The output is shown below:
 ```
+allenai/OLMo-1B-hf
+allenai/OLMo-7B-hf
+allenai/OLMo-7B-Instruct-hf
 codellama/CodeLlama-13b-hf
 codellama/CodeLlama-13b-Instruct-hf
 codellama/CodeLlama-13b-Python-hf