diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 09b3277c7d..42268dc45f 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -279,31 +279,19 @@ def apply(self, prompt: str, **kwargs: str) -> str: return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n" -class Qwen2_5(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - -class Qwen2_5_Math(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "Please reason step by step, and put your final answer within \\boxed{}." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - -class QwQ(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - -class Salamandra(PromptStyle): - def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023." - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - +class ChatML(PromptStyle): + def __init__(self, model_name: str): + self.model_name = model_name + self.system_messages = { + "qwen2.5": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.", + "qwen2.5-math": "Please reason step by step, and put your final answer within \\boxed{}.", + "qwq": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.", + "smollm2": "You are a helpful AI assistant named SmolLM, trained by Hugging Face", + "salamandra": "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit." + } -class SmolLM2(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: - system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face" - return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + return f"<|im_start|>system\n{self.system_messages[self.model_name]}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" # Maps prompt style names to PromptStyle classes @@ -329,11 +317,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: "gemma": Gemma, "llama3": Llama3, "olmo": OLMo, - "qwen2.5": Qwen2_5, - "qwen2.5-math": Qwen2_5_Math, - "qwq": QwQ, - "smollm2": SmolLM2, - "salamandra": Salamandra, + "qwen2.5": ChatML("qwen2.5"), + "qwen2.5-math": ChatML("qwen2.5-math"), + "qwq": ChatML("qwq"), + "smollm2": ChatML("smollm2"), + "salamandra": ChatML("salamandra"), } @@ -373,15 +361,15 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: if re.search(r"OLMo.*-hf", model_name): return OLMo() if re.search(r"Qwen2\.5-Math-.*", model_name): - return Qwen2_5_Math() + return ChatML("qwen2.5-math") if re.search(r"Qwen2\.5-.*", model_name): - return Qwen2_5() + return ChatML("qwen2.5") if re.search(r"QwQ-.*", model_name): - return QwQ() + return ChatML("qwq") if re.search(r"SmolLM2.*-Instruct", model_name): - return SmolLM2() + return ChatML("smollm2") if re.search(r"salamandra-.*-instruct", model_name): - return Salamandra() + return ChatML("salamandra") return Default()