Skip to content

Commit

Permalink
Feat/chatml add system message (#1117)
Browse files Browse the repository at this point in the history
* add system message to template

* readme update

* added code to register new system message

* register chatml template for test

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
3 people authored Jan 25, 2024
1 parent ee0b5f6 commit 98b4762
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 13 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
Expand Down
7 changes: 7 additions & 0 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import register_chatml_template

LOG = logging.getLogger("axolotl.cli.preprocess")

Expand All @@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
return_remaining_strings=True
)

if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)

if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
Expand Down
8 changes: 7 additions & 1 deletion src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.prompt_strategies.sharegpt import register_chatml_template
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")
Expand All @@ -37,7 +38,12 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl:
if cfg.chat_template == "chatml" and cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {cfg.default_system_message}"
)
register_chatml_template(cfg.default_system_message)

dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
Expand Down
21 changes: 12 additions & 9 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2

register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message="You are a helpful assistant.",
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",

def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=["<|im_start|>user", "<|im_start|>assistant"],
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
)


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def chat_templates(user_choice: str):

templates = {
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
}

if user_choice in templates:
Expand Down
8 changes: 7 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,13 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

if cfg.chat_template:
tokenizer.chat_template = chat_templates(cfg.chat_template)
chat_template_string = chat_templates(cfg.chat_template)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
)

tokenizer.chat_template = chat_template_string
else:
LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference."
Expand Down
7 changes: 6 additions & 1 deletion tests/prompt_strategies/test_sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
from axolotl.prompt_strategies.sharegpt import (
SimpleShareGPTPromptTokenizingStrategy,
register_chatml_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_chatml_template()


@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
Expand Down

0 comments on commit 98b4762

Please sign in to comment.