misc bug fixes: 1) model loader, 2) Yi chat template
guocuimi committed Nov 23, 2023
1 parent 68854da commit 376875c
Showing 18 changed files with 102 additions and 48 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -60,7 +60,7 @@ docker run -it --gpus=all --net=host --shm-size=1g \

| Models | Tensor Parallel | Quantization | Chat API | HF models examples |
| :--------: | :-------------: | :----------: | :------: | :---------------------------:|
- | Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K), [casperhansen/yi-6b-awq](https://huggingface.co/casperhansen/yi-6b-awq), [TheBloke/Yi-34B-GPTQ](https://huggingface.co/TheBloke/Yi-34B-GPTQ) |
+ | Yi | Yes | Yes | Yes |[01-ai/Yi-6B](https://huggingface.co/01-ai/Yi-6B), [01-ai/Yi-34B-Chat-4bits](https://huggingface.co/01-ai/Yi-34B-Chat-4bits), [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K) |
| Llama2 | Yes | Yes | Yes | [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), [TheBloke/Llama-2-13B-chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-13B-chat-GPTQ), [TheBloke/Llama-2-70B-AWQ](https://huggingface.co/TheBloke/Llama-2-70B-AWQ) |
| Aquila | Yes | Yes | Yes | [BAAI/Aquila-7B](https://huggingface.co/BAAI/Aquila-7B), [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) |
| Bloom | Yes | Yes | No | [bigscience/bloom](https://huggingface.co/bigscience/bloom) |
@@ -69,7 +69,7 @@ docker run -it --gpus=all --net=host --shm-size=1g \
| GPT2 | Yes | Yes | No | [gpt2](https://huggingface.co/gpt2)|
| InternLM | Yes | Yes | Yes | [internlm/internlm-7b](https://huggingface.co/internlm/internlm-7b) |
| Mistral | Yes | Yes | Yes | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) |
- | MPT | Yes | Yes | No | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |
+ | MPT | Yes | Yes | Yes | [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b) |

If your model is not included in the supported list, we are more than willing to assist you. Please feel free to create a request for adding a new model on [GitHub Issues](https://github.com/vectorch-ai/ScaleLLM/issues).

4 changes: 2 additions & 2 deletions src/models/CMakeLists.txt
@@ -9,11 +9,11 @@ cc_library(
input_parameters.h
model_registry.h
causal_lm.h
- dialog.h
+ conversation.h
SRCS
model_registry.cpp
causal_lm.cpp
- dialog.cpp
+ conversation.cpp
DEPS
:common
:layers
8 changes: 4 additions & 4 deletions src/models/dialog.cpp → src/models/conversation.cpp
@@ -1,4 +1,4 @@
#include "dialog.h"
#include "conversation.h"

#include <cstdint>
#include <optional>
@@ -9,7 +9,7 @@
namespace llm {

// add all messages to the conversation one by one
- void Dialog::add_message(Role role, const std::string_view& message) {
+ void Conversation::add_message(Role role, const std::string_view& message) {
switch (role) {
case Role::User:
if (last_message_role_ == Role::User) {
@@ -42,8 +42,8 @@ void Dialog::add_message(Role role, const std::string_view& message) {
}
}

- // generate prompt from dialogs
- std::optional<std::string> Llama2Dialog::get_prompt() const {
+ // generate prompt from Conversation
+ std::optional<std::string> Llama2Conversation::get_prompt() const {
// at least one user message
if (messages_.size() % 2 == 0) {
return std::nullopt;
10 changes: 5 additions & 5 deletions src/models/dialog.h → src/models/conversation.h
@@ -7,12 +7,12 @@

namespace llm {

- // dialog only supports 'system', 'user' and 'assistant' roles.
+ // Conversation only supports 'system', 'user' and 'assistant' roles.
// start with system message, then user and assistant message. (u/a/u/a/u...)
- class Dialog {
+ class Conversation {
public:
- Dialog() = default;
- virtual ~Dialog() = default;
+ Conversation() = default;
+ virtual ~Conversation() = default;

enum class Role : int8_t {
User = 0,
@@ -36,7 +36,7 @@
};

// dialog conversation for llama2 model
- class Llama2Dialog final : public Dialog {
+ class Llama2Conversation final : public Conversation {
public:
// generate prompt from dialogs
std::optional<std::string> get_prompt() const override;
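The rename above turns `Dialog` into a `Conversation` base class whose subclasses override `get_prompt()`. A minimal driver is sketched below as an illustration only; it assumes a `Role::System` enumerator (implied by the class comment but not fully visible in this diff) and the header path introduced by this commit.

```cpp
#include <iostream>

#include "models/conversation.h"  // renamed from models/dialog.h in this commit

int main() {
  // Llama2Conversation overrides get_prompt() with the llama2 template.
  llm::Llama2Conversation conv;

  // Role::System is assumed from the class comment; the enum is only
  // partially visible in this diff.
  conv.add_message(llm::Conversation::Role::System,
                   "You are a helpful assistant.");
  conv.add_message(llm::Conversation::Role::User, "What is 2 + 2?");

  // get_prompt() returns std::nullopt when the message count is even,
  // i.e. when the conversation does not end on a user turn.
  if (const auto prompt = conv.get_prompt()) {
    std::cout << *prompt << '\n';
  }
  return 0;
}
```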
7 changes: 4 additions & 3 deletions src/models/huggingface/aquila.h
@@ -9,7 +9,7 @@
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/args.h"
#include "models/dialog.h"
#include "models/conversation.h"
#include "models/input_parameters.h"
#include "models/model_registry.h"

@@ -363,7 +363,7 @@ class AquilaForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(AquilaForCausalLM);

- class AquilaDialog final : public Dialog {
+ class AquilaConversation final : public Conversation {
public:
// generate prompt from dialogs
// Prompt template for Aquila:
@@ -395,12 +395,13 @@ class AquilaDialog final : public Dialog {

// register the model to make it available
REGISTER_CAUSAL_MODEL(aquila, AquilaForCausalLM);
- REGISTER_DIALOG(aquila, AquilaDialog);
+ REGISTER_CONVERSATION_TEMPLATE(aquila, AquilaConversation);
REGISTER_MODEL_ARGS(aquila, [&] {
// example config:
// https://huggingface.co/BAAI/Aquila-7B/blob/main/config.json.
// set default values for args explicitly with values from:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/configuration_mistral.py#L104
LOAD_ARG_OR(model_type, "model_type", "aquila");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 100008);
LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
1 change: 1 addition & 0 deletions src/models/huggingface/bloom.h
@@ -430,6 +430,7 @@ REGISTER_CAUSAL_MODEL(bloom, BloomForCausalLM);
REGISTER_MODEL_ARGS(bloom, [&] {
// example config:
// https://huggingface.co/bigscience/bloom/blob/main/config.json
LOAD_ARG_OR(model_type, "model_type", "bloom");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 250880);
LOAD_ARG_OR(hidden_size, "n_embed", 14336);
1 change: 1 addition & 0 deletions src/models/huggingface/gpt2.h
@@ -365,6 +365,7 @@ TORCH_MODULE(GPT2ForCausalLM);
// register the model to make it available
REGISTER_CAUSAL_MODEL(gpt2, GPT2ForCausalLM);
REGISTER_MODEL_ARGS(gpt2, [&] {
LOAD_ARG_OR(model_type, "model_type", "gpt2");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 50257);
LOAD_ARG_OR(hidden_size, "n_embd", 768);
1 change: 1 addition & 0 deletions src/models/huggingface/gpt_j.h
@@ -351,6 +351,7 @@ REGISTER_MODEL_ARGS(gptj, [&] {
// https://huggingface.co/EleutherAI/gpt-j-6b/blob/main/config.json set
// default values for args explicitly with values from:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py#L98
LOAD_ARG_OR(model_type, "model_type", "gptj");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 50400);
LOAD_ARG_OR(hidden_size, "n_embd", 4096);
1 change: 1 addition & 0 deletions src/models/huggingface/gpt_neox.h
@@ -403,6 +403,7 @@ REGISTER_MODEL_ARGS(gpt_neox, [&] {
// https://huggingface.co/EleutherAI/gpt-neox-20b/blob/main/config.json set
// set default values for args explicitly with values from:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/configuration_gpt_neox.py#L106
LOAD_ARG_OR(model_type, "model_type", "gpt_neox");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 50432);
LOAD_ARG_OR(hidden_size, "hidden_size", 6144);
8 changes: 6 additions & 2 deletions src/models/huggingface/internlm.h
@@ -350,7 +350,7 @@ class InternlmForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(InternlmForCausalLM);

- class InternlmDialog final : public Dialog {
+ class InternlmDialog final : public Conversation {
public:
// generate prompt from dialogs
// Prompt template:
@@ -384,8 +384,9 @@

// register the model to make it available
REGISTER_CAUSAL_MODEL(internlm, InternlmForCausalLM);
- REGISTER_DIALOG(internlm, InternlmDialog);
+ REGISTER_CONVERSATION_TEMPLATE(internlm, InternlmDialog);
REGISTER_MODEL_ARGS(internlm, [&] {
LOAD_ARG_OR(model_type, "model_type", "internlm");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 103168);
LOAD_ARG_OR(hidden_size, "hidden_size", 5120);
@@ -399,6 +400,9 @@ REGISTER_MODEL_ARGS(internlm, [&] {
LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);

// stop token ids: [1, 103028]
LOAD_ARG_OR(stop_token_ids, "", std::unordered_set<int32_t>({1, 103028}));
});

} // namespace llm::hf
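Note the empty JSON key in the new `stop_token_ids` line: the config is never consulted, so the default set is effectively hard-coded. A plausible shape of the macro — an assumption for illustration; the real definition lives in the repo's args headers, which this diff does not show — is:

```cpp
// Hypothetical sketch of LOAD_ARG_OR: read `json_key` from the model's
// config.json, falling back to `default_value` when the key is absent.
// An empty key never matches, so the default always wins.
#define LOAD_ARG_OR(arg_name, json_key, default_value) \
  args->arg_name() = json.value_or((json_key), (default_value))
```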
5 changes: 3 additions & 2 deletions src/models/huggingface/llama.h
@@ -10,7 +10,7 @@
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/args.h"
#include "models/dialog.h"
#include "models/conversation.h"
#include "models/input_parameters.h"
#include "models/model_registry.h"

@@ -365,13 +365,14 @@ TORCH_MODULE(LlamaForCausalLM);

// register the causal model
REGISTER_CAUSAL_MODEL(llama, LlamaForCausalLM);
- REGISTER_DIALOG(llama, Llama2Dialog);
+ REGISTER_CONVERSATION_TEMPLATE(llama, Llama2Conversation);
// register the model args
// example config:
// https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json set
// default values for args explicitly with values from:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/configuration_llama.py#L112
REGISTER_MODEL_ARGS(llama, [&] {
LOAD_ARG_OR(model_type, "model_type", "llama");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 32000);
LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
5 changes: 3 additions & 2 deletions src/models/huggingface/mistral.h
@@ -361,7 +361,7 @@ class MistralForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(MistralForCausalLM);

- class MistralDialog final : public Dialog {
+ class MistralConversation final : public Conversation {
public:
// generate prompt from dialogs
// Prompt template:
@@ -398,8 +398,9 @@

// register the model to make it available
REGISTER_CAUSAL_MODEL(mistral, MistralForCausalLM);
- REGISTER_DIALOG(mistral, MistralDialog);
+ REGISTER_CONVERSATION_TEMPLATE(mistral, MistralConversation);
REGISTER_MODEL_ARGS(mistral, [&] {
LOAD_ARG_OR(model_type, "model_type", "mistral");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 32000);
LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
36 changes: 36 additions & 0 deletions src/models/huggingface/mpt.h
@@ -454,8 +454,41 @@ class MPTForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(MPTForCausalLM);


class MPTConversation final : public Conversation {
public:
// Prompt template:
// <|im_start|>system\n {system_message} <|im_end|>\n
// <|im_start|>user\n {message} <|im_end|>\n
// <|im_start|>assistant\n
std::optional<std::string> get_prompt() const override {
// at least one user message
if (messages_.size() % 2 == 0) {
return std::nullopt;
}

std::stringstream ss;
if (!system_message_.empty()) {
ss << "<|im_start|>system\n" << system_message_ << "<|im_end|>\n";
}

// then user and assistant message pairs (u/a/u/a/u...)
for (size_t i = 0; i < messages_.size(); ++i) {
const char* role = (i % 2) == 0 ? "user" : "assistant";
ss << "<|im_start|>" << role << "\n" << messages_[i] << "<|im_end|>\n";
}
// end with assistant message
ss << "<|im_start|>assistant\n";
return ss.str();
}
};

// register the causal model
REGISTER_CAUSAL_MODEL(mpt, MPTForCausalLM);
REGISTER_CONVERSATION_TEMPLATE(mpt, MPTConversation);

REGISTER_MODEL_ARGS(mpt, [&] {
LOAD_ARG_OR(model_type, "model_type", "mpt");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 50368);
LOAD_ARG_OR(hidden_size, "d_model", 2048);
@@ -476,6 +509,9 @@ REGISTER_MODEL_ARGS(mpt, [&] {
json.value_or<int64_t>("expansion_ratio", 4);
return expansion_ratio * args->hidden_size();
});

// stop token ids: [0, 50278]
LOAD_ARG_OR(stop_token_ids, "", std::unordered_set<int32_t>({0, 50278}));
});

} // namespace llm::hf
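With `MPTConversation` registered, MPT gains the Chat API support reflected in the README change above. A hypothetical smoke test for the ChatML-style template — not part of the commit; it assumes the inherited `add_message`/`Role` API shown earlier — might read:

```cpp
#include <cassert>
#include <string>

#include "models/huggingface/mpt.h"

int main() {
  llm::hf::MPTConversation conv;
  conv.add_message(llm::Conversation::Role::User, "Hello!");

  const auto prompt = conv.get_prompt();
  assert(prompt.has_value());

  // one user turn renders as an <|im_start|>user block...
  assert(prompt->find("<|im_start|>user\nHello!<|im_end|>\n") !=
         std::string::npos);
  // ...and the prompt ends with an open assistant turn (22 chars)
  // for the model to complete.
  assert(prompt->rfind("<|im_start|>assistant\n") == prompt->size() - 22);
  return 0;
}
```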
13 changes: 8 additions & 5 deletions src/models/huggingface/yi.h
@@ -11,7 +11,7 @@
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/args.h"
#include "models/dialog.h"
#include "models/conversation.h"
#include "models/input_parameters.h"
#include "models/model_registry.h"

@@ -364,7 +364,7 @@ class YiForCausalLMImpl : public torch::nn::Module {
};
TORCH_MODULE(YiForCausalLM);

- class YiDialog final : public Dialog {
+ class YiConversation final : public Conversation {
public:
// generate prompt from dialogs
// https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json#L60
@@ -378,7 +378,9 @@
}

std::stringstream ss;
- // Sounds Yi model doesn't support system message?
+ if (!system_message_.empty()) {
+   ss << "<|im_start|>system\n" << system_message_ << "<|im_end|>\n";
+ }

// then user and assistant message pairs (u/a/u/a/u...)
for (size_t i = 0; i < messages_.size(); ++i) {
@@ -393,11 +395,12 @@

// register the causal model
REGISTER_CAUSAL_MODEL(Yi, YiForCausalLM);
- REGISTER_DIALOG(Yi, YiDialog);
+ REGISTER_CONVERSATION_TEMPLATE(Yi, YiConversation);
// register the model args
// example config:
// https://huggingface.co/01-ai/Yi-6B/blob/main/config.json
REGISTER_MODEL_ARGS(Yi, [&] {
LOAD_ARG_OR(model_type, "model_type", "Yi");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 64000);
LOAD_ARG_OR(hidden_size, "hidden_size", 4096);
@@ -412,7 +415,7 @@ REGISTER_MODEL_ARGS(Yi, [&] {
LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
LOAD_ARG_OR(rope_theta, "rope_theta", 5000000.0f);
LOAD_ARG_OR(rope_scaling, "rope_scaling", 1.0f);

// stop token ids: "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
LOAD_ARG_OR(stop_token_ids, "", std::unordered_set<int32_t>({2, 6, 7, 8}));
});
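The Yi fix replaces the old "doesn't support system message?" guess with a proper system block. Assuming the elided user/assistant loop mirrors the MPT template above, a conversation with a system message and one user turn should now render roughly as follows (the strings are hypothetical placeholders, not from the commit):

```cpp
// Expected rendering of YiConversation::get_prompt(), reconstructed
// from the code above; message contents are placeholders.
const char* expected_prompt =
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "Hello!<|im_end|>\n"
    "<|im_start|>assistant\n";
```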
4 changes: 2 additions & 2 deletions src/models/llama.h
@@ -10,7 +10,7 @@
#include "layers/normalization.h"
#include "memory/kv_cache.h"
#include "models/args.h"
#include "models/dialog.h"
#include "models/conversation.h"
#include "models/input_parameters.h"
#include "models/model_registry.h"

@@ -362,7 +362,7 @@ TORCH_MODULE(LlamaForCausalLM);

// register the model to make it available
REGISTER_CAUSAL_MODEL(llama2, LlamaForCausalLM);
- REGISTER_DIALOG(llama2, Llama2Dialog);
+ REGISTER_CONVERSATION_TEMPLATE(llama2, Llama2Conversation);

REGISTER_MODEL_ARGS(llama2, [&] {
LOAD_ARG_OR(dtype, "torch_dtype", torch::toString(torch::kBFloat16));
17 changes: 10 additions & 7 deletions src/models/model_registry.cpp
@@ -52,13 +52,15 @@ void ModelRegistry::register_quant_args_loader(const std::string& name,
}
}

- void ModelRegistry::register_dialog_factory(const std::string& name,
-                                             DialogFactory factory) {
+ void ModelRegistry::register_conversation_template(
+     const std::string& name,
+     ConversationTemplate factory) {
ModelRegistry* instance = get_instance();
- if (instance->model_registry_[name].dialog_factory != nullptr) {
-   GLOG(WARNING) << "dialog factory for " << name << "already registered.";
+ if (instance->model_registry_[name].conversation_template != nullptr) {
+   GLOG(WARNING) << "conversation template for " << name
+                 << " already registered.";
} else {
-   instance->model_registry_[name].dialog_factory = factory;
+   instance->model_registry_[name].conversation_template = factory;
}
}

@@ -77,9 +79,10 @@ QuantArgsLoader ModelRegistry::get_quant_args_loader(const std::string& name) {
return instance->model_registry_[name].quant_args_loader;
}

- DialogFactory ModelRegistry::get_dialog_factory(const std::string& name) {
+ ConversationTemplate ModelRegistry::get_conversation_template(
+     const std::string& name) {
  ModelRegistry* instance = get_instance();
- return instance->model_registry_[name].dialog_factory;
+ return instance->model_registry_[name].conversation_template;
}

} // namespace llm
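With the registry now storing a `ConversationTemplate` per model, callers resolve the chat template by model type and build prompts through it. The sketch below assumes `ConversationTemplate` is a callable factory returning `std::unique_ptr<Conversation>`; the actual alias is defined in model_registry.h, which this commit does not show.

```cpp
#include <memory>
#include <string>

#include "models/conversation.h"
#include "models/model_registry.h"

// Hedged sketch: resolve a model's chat template and render one turn.
void build_prompt_for(const std::string& model_type) {
  auto factory = llm::ModelRegistry::get_conversation_template(model_type);
  if (factory == nullptr) {
    return;  // no chat template registered (e.g. bloom, gpt2)
  }
  std::unique_ptr<llm::Conversation> conv = factory();
  conv->add_message(llm::Conversation::Role::User, "Hi there");
  if (const auto prompt = conv->get_prompt()) {
    // hand *prompt to the tokenizer / engine
  }
}
```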