diff --git a/docs/model_support.md b/docs/model_support.md
index d85ea6e43..463ea64c0 100644
--- a/docs/model_support.md
+++ b/docs/model_support.md
@@ -42,6 +42,7 @@
 loading multiple peft models, you can have them share the base model weights by
 setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model
 worker.
+- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat)
 
 ## How to support a new model
 
diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 65055fdc9..8d5baa1f7 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -936,6 +936,21 @@ def get_conv_template(name: str) -> Conversation:
     )
 )
 
+# Llama2-Chinese default template
+# source: https://huggingface.co/FlagAlpha
+register_conv_template(
+    Conversation(
+        name="llama2-chinese",
+        system_message="{system_message}",
+        roles=("Human", "Assistant", "System"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="\n",
+        sep2="\n",
+        stop_str="",
+    )
+)
 
 if __name__ == "__main__":
     print("Vicuna template:")
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index d25d79a54..cea4b07be 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1382,6 +1382,31 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("aquila-chat")
 
 
+class Llama2ChineseAdapter(BaseModelAdapter):
+    """The model adapter for FlagAlpha/Llama2-Chinese SFT models"""
+
+    def match(self, model_path: str):
+        return "llama2-chinese" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            revision=revision,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("llama2-chinese")
+
+
 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
 register_model_adapter(PeftModelAdapter)
@@ -1432,6 +1457,7 @@
 register_model_adapter(QwenChatAdapter)
 register_model_adapter(AquilaChatAdapter)
 register_model_adapter(BGEAdapter)
+register_model_adapter(Llama2ChineseAdapter)
 
 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index d27d12447..5003f634d 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -248,3 +248,9 @@ def get_model_info(name: str) -> ModelInfo:
     "https://huggingface.co/Qwen/Qwen-7B-Chat",
     "Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.",
 )
+register_model_info(
+    ["Llama2-Chinese-13b-Chat", "LLama2-Chinese-13B"],
+    "Llama2-Chinese",
+    "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat",
+    "Llama2-Chinese is a multi-language large-scale language model (LLM), developed by FlagAlpha.",
+)
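A quick sketch of what the new `llama2-chinese` template produces (not part of the patch). With `SeparatorStyle.ADD_COLON_TWO`, `get_prompt()` emits the system message followed by `role: message` turns, alternating `sep` and `sep2` as terminators; with both set to `"\n"`, every turn simply ends in a newline. Assuming this patch is applied and `ADD_COLON_TWO` semantics are unchanged:

```python
# Sketch only: render a one-turn prompt with the template added above.
from fastchat.conversation import get_conv_template

conv = get_conv_template("llama2-chinese")
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "你好")  # roles[0] == "Human"; "你好" = "Hello"
conv.append_message(conv.roles[1], None)    # leave the "Assistant" turn open
print(conv.get_prompt())
# Expected output:
# You are a helpful assistant.
# Human: 你好
# Assistant:
```

Note that with `sep` and `sep2` both `"\n"`, `ADD_COLON_TWO` degenerates to a single-separator format; the distinct `sep2` only matters for templates that terminate assistant turns differently (e.g. with an EOS token).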
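Because `BaseModelAdapter` is registered last, the new adapter is matched first by substring on the model path. A dispatch sketch, assuming `get_model_adapter` and `load_model` keep their current signatures in `fastchat/model/model_adapter.py`:

```python
# Sketch only: adapter selection and loading for the new model.
from fastchat.model.model_adapter import get_model_adapter, load_model

adapter = get_model_adapter("FlagAlpha/Llama2-Chinese-13b-Chat")
print(type(adapter).__name__)
# -> Llama2ChineseAdapter, because "llama2-chinese" is a substring of the
#    lowercased path and the adapter is registered before BaseModelAdapter

# load_model() routes through the adapter's load_model() and
# get_default_conv_template() under the hood.
model, tokenizer = load_model(
    "FlagAlpha/Llama2-Chinese-13b-Chat",
    device="cuda",
    num_gpus=1,
)
```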
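Finally, the `model_registry.py` entry makes the model's display metadata resolvable by name in the web UI. A lookup sketch, assuming `get_model_info` and the `ModelInfo` fields are unchanged:

```python
# Sketch only: resolve the registered metadata by model name.
from fastchat.model.model_registry import get_model_info

info = get_model_info("Llama2-Chinese-13b-Chat")
print(info.link)  # https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat
```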