Skip to content

Commit

Permalink
Add SmolLM2 (#1848)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com>
Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 16, 2024
1 parent 972dee4 commit 7b26d35
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
Expand Down
76 changes: 75 additions & 1 deletion litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,10 @@ def norm_class(self) -> Type:

configs.extend(qwq)


#############
# Salamandra
#############

salamandra = [
# https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json
dict(
Expand Down Expand Up @@ -2189,4 +2189,78 @@ def norm_class(self) -> Type:
configs.append(copy)


###############
# SmolLM2
###############
smollm2 = [
# https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json
dict(
name="SmolLM2-135M{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=30,
n_head=9,
n_embd=576,
n_query_groups=3,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=1536,
rope_base=100000,
norm_eps=1e-5,
),
# https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json
dict(
name="SmolLM2-360M{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=32,
n_head=15,
n_embd=960,
n_query_groups=5,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=2560,
rope_base=100000,
norm_eps=1e-5,
),
# https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json
dict(
name="SmolLM2-1.7B{}",
hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"),
block_size=8192,
vocab_size=49152,
padded_vocab_size=49152,
n_layer=24,
n_head=32,
n_embd=2048,
n_query_groups=32,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=130000,
norm_eps=1e-5,
),
]

for c in smollm2:
for kind in ("", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)


name_to_config = {config["name"]: config for config in configs}
9 changes: 9 additions & 0 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


class SmolLM2(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
# Dataset-specific prompt styles
Expand All @@ -326,6 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"smollm2": SmolLM2,
"salamandra": Salamandra,
}

Expand Down Expand Up @@ -371,6 +378,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Qwen2_5()
if re.search(r"QwQ-.*", model_name):
return QwQ()
if re.search(r"SmolLM2.*-Instruct", model_name):
return SmolLM2()
if re.search(r"salamandra-.*-instruct", model_name):
return Salamandra()
return Default()
Expand Down
2 changes: 1 addition & 1 deletion litgpt/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
with gated_repo_catcher(repo_id, access_token):
info = repo_info(repo_id, token=access_token)
filenames = [f.rfilename for f in info.siblings]
bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"]))
bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
return bins, safetensors

Expand Down
4 changes: 3 additions & 1 deletion litgpt/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int:
raise ValueError(f"token {token!r} not found in the collection.")
return id_

def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
return False
with open(tokenizer_config_path, encoding="utf-8") as fp:
Expand All @@ -96,6 +96,8 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
# `PreTrainedTokenizerFast`
if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")):
return True
if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"):
return True
if "add_bos_token" in config:
return config["add_bos_token"]
# if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
Expand Down
61 changes: 61 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
@pytest.mark.parametrize(
Expand Down Expand Up @@ -910,6 +911,66 @@ def test_against_original_salamandra(model_name, device, dtype):
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_original_smollm2(model_name, device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@RunIf(dynamo=True)
Expand Down
7 changes: 7 additions & 0 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
Expand Down Expand Up @@ -122,6 +123,12 @@ google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
h2oai/h2o-danube2-1.8b-chat
HuggingFaceTB/SmolLM2-135M
HuggingFaceTB/SmolLM2-135M-Instruct
HuggingFaceTB/SmolLM2-360M
HuggingFaceTB/SmolLM2-360M-Instruct
HuggingFaceTB/SmolLM2-1.7B
HuggingFaceTB/SmolLM2-1.7B-Instruct
lmsys/longchat-13b-16k
lmsys/longchat-7b-16k
lmsys/vicuna-13b-v1.3
Expand Down

0 comments on commit 7b26d35

Please sign in to comment.