diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index aaa62423c..13c5b4ab5 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -30,6 +30,7 @@ from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -234,7 +235,8 @@ def do_inference_gradio(
     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    # default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    default_tokens: Dict[str, str] = {}

     for token, symbol in default_tokens.items():
         # If the token isn't already specified in the config, add it
@@ -242,10 +244,13 @@ def do_inference_gradio(
             tokenizer.add_special_tokens({token: symbol})

     prompter_module = None
+    chat_template_str = None
     if prompter:
         prompter_module = getattr(
             importlib.import_module("axolotl.prompters"), prompter
         )
+    elif cfg.chat_template:
+        chat_template_str = chat_templates(cfg.chat_template)

     model = model.to(cfg.device, dtype=cfg.torch_dtype)

@@ -259,7 +264,24 @@ def generate(instruction):
             )
         else:
             prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        if chat_template_str:
+            batch = tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                return_tensors="pt",
+                add_special_tokens=True,
+                add_generation_prompt=True,
+                chat_template=chat_template_str,
+                tokenize=True,
+                return_dict=True,
+            )
+        else:
+            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()
         with torch.no_grad():
@@ -282,6 +304,7 @@ def generate(instruction):
         streamer = TextIteratorStreamer(tokenizer)
         generation_kwargs = {
             "inputs": batch["input_ids"].to(cfg.device),
+            "attention_mask": batch["attention_mask"].to(cfg.device),
             "generation_config": generation_config,
             "streamer": streamer,
         }
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index f4cd25783..7c3e437f8 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1417,6 +1417,8 @@ def build(self, total_num_steps):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
+            if self.cfg.wandb_name:
+                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
@@ -1574,6 +1576,12 @@ def build(self, total_num_steps):
         )
         training_args = self.hook_post_create_training_args(training_args)

+        # unset run_name so wandb sets up experiment names
+        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
+            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
+                None
+            )
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
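
Note on the two files above: when no --prompter is passed but cfg.chat_template is set, gradio inference now renders the user turn through the tokenizer's chat template and forwards the resulting attention_mask into generate(). Below is a minimal standalone sketch of that path; the checkpoint name is a placeholder, and it assumes a transformers version whose apply_chat_template supports return_dict=True.

# Standalone sketch of the chat-template inference path added above.
# Assumptions: "my-org/my-chat-model" is a placeholder checkpoint whose
# tokenizer ships a chat template.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("my-org/my-chat-model")
model = AutoModelForCausalLM.from_pretrained("my-org/my-chat-model")

# Render a single user turn through the chat template; return_dict=True
# yields both input_ids and attention_mask from one call.
batch = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Why is the sky blue?"}],
    add_generation_prompt=True,  # end with the assistant header so the model replies
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
)

model.eval()
with torch.no_grad():
    output = model.generate(
        inputs=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        generation_config=GenerationConfig(max_new_tokens=128),
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))

Returning a dict from apply_chat_template is what makes the new attention_mask entry in generation_kwargs possible; without it only input_ids would be available.
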
- "message_field_content": ds_cfg.get("message_field_content", "value"), + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( "message_field_training_detail", diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 458bacdb1..221785508 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1017,12 +1017,20 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha): return neftune_noise_alpha @model_validator(mode="after") - def check(self): + def check_rl_beta(self): if self.dpo_beta and not self.rl_beta: self.rl_beta = self.dpo_beta del self.dpo_beta return self + @model_validator(mode="after") + def check_simpo_warmup(self): + if self.rl == "simpo" and self.warmup_ratio: + raise ValueError( + "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" + ) + return self + @model_validator(mode="before") @classmethod def check_frozen(cls, data):