From 2dc0be818f54297c275b5fff9a10e69a7cf6f881 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 22 Dec 2024 11:47:56 -0500 Subject: [PATCH] Added ChatML inheritance for better typing compatibility --- litgpt/prompts.py | 53 +++++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index a438b60ca5..6e32c4e7eb 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -280,18 +280,31 @@ def apply(self, prompt: str, **kwargs: str) -> str: class ChatML(PromptStyle): - def __init__(self, model_name: str): - self.model_name = model_name - self.system_messages = { - "qwen2.5": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", - "qwen2.5-math": "Please reason step by step, and put your final answer within \\boxed{}.", - "qwq": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", - "smollm2": "You are a helpful AI assistant named SmolLM, trained by Hugging Face", - "salamandra": "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit." - } + def __init__(self, system_message: str): + self.system_message = system_message def apply(self, prompt: str, **kwargs: str) -> str: - return f"<|im_start|>system\n{self.system_messages[self.model_name]}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + return f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + +class Qwen2_5(ChatML): + def __init__(self): + super().__init__("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.") + +class Qwen2_5_Math(ChatML): + def __init__(self): + super().__init__("Please reason step by step, and put your final answer within \\boxed{}.") + +class QwQ(ChatML): + def __init__(self): + super().__init__("You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.") + +class SmolLM2(ChatML): + def __init__(self): + super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face") + +class Salamandra(ChatML): + def __init__(self): + super().__init__("I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit.") # Maps prompt style names to PromptStyle classes @@ -317,11 +330,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: "gemma": Gemma, "llama3": Llama3, "olmo": OLMo, - "qwen2.5": lambda: ChatML("qwen2.5"), - "qwen2.5-math": lambda: ChatML("qwen2.5-math"), - "qwq": lambda: ChatML("qwq"), - "smollm2": lambda: ChatML("smollm2"), - "salamandra": lambda: ChatML("salamandra"), + "qwen2.5": Qwen2_5, + "qwen2.5-math": Qwen2_5_Math, + "qwq": QwQ, + "smollm2": SmolLM2, + "salamandra": Salamandra, } @@ -361,15 +374,15 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: if re.search(r"OLMo.*-hf", model_name): return OLMo() if re.search(r"Qwen2\.5-Math-.*", model_name): - return ChatML("qwen2.5-math") + return Qwen2_5_Math() if re.search(r"Qwen2\.5-.*", model_name): - return ChatML("qwen2.5") + return Qwen2_5() if re.search(r"QwQ-.*", model_name): - return ChatML("qwq") + return QwQ() if re.search(r"SmolLM2.*-Instruct", model_name): - return ChatML("smollm2") + return SmolLM2() if re.search(r"salamandra-.*-instruct", model_name): - return ChatML("salamandra") + return Salamandra() return Default()