diff --git a/docs/model_support.md b/docs/model_support.md
index d85ea6e43..463ea64c0 100644
--- a/docs/model_support.md
+++ b/docs/model_support.md
@@ -42,6 +42,7 @@
loading multiple peft models, you can have them share the base model weights by
setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model
worker; see the example commands after the model list below.
+- [FlagAlpha/Llama2-Chinese-13b-Chat](https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat)
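+
+For example, the newly listed model can be tried with FastChat's standard serving commands. This is a sketch: flags such as `--num-gpus` depend on your setup, and the peft adapter path is hypothetical.
+
+```bash
+# Chat with the new model from the terminal:
+python3 -m fastchat.serve.cli --model-path FlagAlpha/Llama2-Chinese-13b-Chat
+
+# Share base weights across several peft model workers (hypothetical path):
+PEFT_SHARE_BASE_WEIGHTS=true python3 -m fastchat.serve.model_worker \
+    --model-path /path/to/your-peft-adapter
+```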
## How to support a new model
diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index 65055fdc9..8d5baa1f7 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -936,6 +936,21 @@ def get_conv_template(name: str) -> Conversation:
)
)
+# Llama2-Chinese default template
+# source: https://huggingface.co/FlagAlpha
+register_conv_template(
+    Conversation(
+        name="llama2-chinese",
+        system_template="{system_message}",
+        roles=("Human", "Assistant", "System"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="\n",
+        sep2="\n",
+        stop_str="",
+    )
+)
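+
+# For reference, the template above renders a two-turn exchange roughly as
+# follows (illustrative sketch of ADD_COLON_TWO with sep == sep2 == "\n";
+# not verbatim output):
+#
+#   <system message>
+#   Human: <first user message>
+#   Assistant: <first model reply>
+#   Human: <second user message>
+#   Assistant: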
if __name__ == "__main__":
print("Vicuna template:")
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index d25d79a54..cea4b07be 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -1382,6 +1382,31 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("aquila-chat")
+class Llama2ChineseAdapter(BaseModelAdapter):
+    """The model adapter for FlagAlpha/Llama2-Chinese SFT models"""
+
+    def match(self, model_path: str):
+        return "llama2-chinese" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            revision=revision,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            **from_pretrained_kwargs,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("llama2-chinese")
+
+
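+# Resolution sketch (illustrative, not part of this change): once registered
+# below, any model path containing "llama2-chinese" should resolve to this
+# adapter via the module's helper, e.g.
+#
+#   from fastchat.model.model_adapter import get_model_adapter
+#   adapter = get_model_adapter("FlagAlpha/Llama2-Chinese-13b-Chat")
+#   # adapter is a Llama2ChineseAdapter instance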
# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(PeftModelAdapter)
@@ -1432,6 +1457,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(QwenChatAdapter)
register_model_adapter(AquilaChatAdapter)
register_model_adapter(BGEAdapter)
+register_model_adapter(Llama2ChineseAdapter)
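+# Note (assumption): no earlier adapter matches this path pattern, since
+# Llama2Adapter matches "llama-2" (with a hyphen), so "llama2-chinese" paths
+# reach this adapter instead of the BaseModelAdapter fallback below.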
# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
diff --git a/fastchat/model/model_registry.py b/fastchat/model/model_registry.py
index d27d12447..5003f634d 100644
--- a/fastchat/model/model_registry.py
+++ b/fastchat/model/model_registry.py
@@ -248,3 +248,9 @@ def get_model_info(name: str) -> ModelInfo:
"https://huggingface.co/Qwen/Qwen-7B-Chat",
"Qwen is a multi-language large-scale language model (LLM), developed by Damo Academy.",
)
+register_model_info(
+    ["Llama2-Chinese-13b-Chat", "Llama2-Chinese-13B"],
+    "Llama2-Chinese",
+    "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat",
+    "Llama2-Chinese is a Llama 2 model fine-tuned for Chinese, developed by FlagAlpha.",
+)
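+
+# Lookup sketch (illustrative): get_model_info matches these aliases exactly,
+# e.g.
+#   get_model_info("Llama2-Chinese-13b-Chat").link
+#   # -> "https://huggingface.co/FlagAlpha/Llama2-Chinese-13b-Chat"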