Skip to content

Commit

Permalink
ChatML prompt template
Browse files Browse the repository at this point in the history
  • Loading branch information
ysjprojects committed Dec 22, 2024
1 parent 7e12d64 commit 1667f29
Showing 1 changed file with 21 additions and 33 deletions.
54 changes: 21 additions & 33 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,31 +279,19 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"


class Qwen2_5(PromptStyle):
    """Prompt style that wraps a user prompt in ChatML turns with the Qwen2.5 system message."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The result ends with an opened ``assistant`` turn. Additional keyword
        arguments are accepted but not used.
        """
        system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

class Qwen2_5_Math(PromptStyle):
    """Prompt style that wraps a user prompt in ChatML turns with the Qwen2.5-Math system message."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The system turn asks for step-by-step reasoning with the final answer
        inside ``\\boxed{}``. The result ends with an opened ``assistant``
        turn. Additional keyword arguments are accepted but not used.
        """
        # Plain (non-f) string: the braces after \boxed are literal characters.
        system_message = "Please reason step by step, and put your final answer within \\boxed{}."
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

class QwQ(PromptStyle):
    """Prompt style that wraps a user prompt in ChatML turns with the QwQ system message."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The result ends with an opened ``assistant`` turn. Additional keyword
        arguments are accepted but not used.
        """
        system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

class Salamandra(PromptStyle):
    """Prompt style that wraps a user prompt in ChatML turns with the Salamandra system message."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The system message carries the same self-description in English,
        Spanish and Catalan, separated by newlines. The result ends with an
        opened ``assistant`` turn. Additional keyword arguments are accepted
        but not used.
        """
        system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023."
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

class ChatML(PromptStyle):
    """Prompt style shared by models that use the ChatML conversation layout.

    The system message is selected by ``model_name`` from a fixed table; the
    surrounding ``<|im_start|>`` / ``<|im_end|>`` framing is the same for
    every supported model.
    """

    def __init__(self, model_name: str):
        """Store ``model_name`` and build the per-model system message table.

        Args:
            model_name: Key into ``system_messages`` (e.g. ``"qwen2.5"``).
                It is not validated here; an unknown key surfaces as a
                ``KeyError`` when ``apply`` is called.
        """
        self.model_name = model_name
        self.system_messages = {
            "qwen2.5": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
            "qwen2.5-math": "Please reason step by step, and put your final answer within \\boxed{}.",
            "qwq": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
            "smollm2": "You are a helpful AI assistant named SmolLM, trained by Hugging Face",
            # The closing Catalan sentence is part of the message emitted by the
            # dedicated Salamandra prompt style; keep the two texts identical so
            # switching to ChatML does not change the prompt.
            "salamandra": "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023.",
        }

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The result ends with an opened ``assistant`` turn. Additional keyword
        arguments are accepted but not used.

        Raises:
            KeyError: If ``self.model_name`` is not a key of
                ``self.system_messages``.
        """
        return f"<|im_start|>system\n{self.system_messages[self.model_name]}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


class SmolLM2(PromptStyle):
    """Prompt style that wraps a user prompt in ChatML turns with the SmolLM2 system message."""

    def apply(self, prompt: str, **kwargs: str) -> str:
        """Return ``prompt`` formatted as a system/user ChatML exchange.

        The result ends with an opened ``assistant`` turn. Additional keyword
        arguments are accepted but not used.
        """
        system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


# Maps prompt style names to PromptStyle classes
Expand All @@ -329,11 +317,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"smollm2": SmolLM2,
"salamandra": Salamandra,
"qwen2.5": ChatML("qwen2.5"),
"qwen2.5-math": ChatML("qwen2.5-math"),
"qwq": ChatML("qwq"),
"smollm2": ChatML("smollm2"),
"salamandra": ChatML("salamandra"),
}


Expand Down Expand Up @@ -373,15 +361,15 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
if re.search(r"OLMo.*-hf", model_name):
return OLMo()
if re.search(r"Qwen2\.5-Math-.*", model_name):
return Qwen2_5_Math()
return ChatML("qwen2.5-math")
if re.search(r"Qwen2\.5-.*", model_name):
return Qwen2_5()
return ChatML("qwen2.5")
if re.search(r"QwQ-.*", model_name):
return QwQ()
return ChatML("qwq")
if re.search(r"SmolLM2.*-Instruct", model_name):
return SmolLM2()
return ChatML("smollm2")
if re.search(r"salamandra-.*-instruct", model_name):
return Salamandra()
return ChatML("salamandra")
return Default()


Expand Down

0 comments on commit 1667f29

Please sign in to comment.