Commit d7eea2f
validation fixes 20240923 (#1925)
* validation fixes 20240923

* fix run name for wandb and defaults for chat template fields

* fix gradio inference with llama chat template
winglian authored Sep 24, 2024
1 parent 7b9f669 commit d7eea2f
Showing 4 changed files with 44 additions and 5 deletions.
27 changes: 25 additions & 2 deletions src/axolotl/cli/__init__.py
@@ -30,6 +30,7 @@
 from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -234,18 +235,22 @@ def do_inference_gradio(
     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    # default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    default_tokens: Dict[str, str] = {}

     for token, symbol in default_tokens.items():
         # If the token isn't already specified in the config, add it
         if not (cfg.special_tokens and token in cfg.special_tokens):
             tokenizer.add_special_tokens({token: symbol})

     prompter_module = None
+    chat_template_str = None
     if prompter:
         prompter_module = getattr(
             importlib.import_module("axolotl.prompters"), prompter
         )
+    elif cfg.chat_template:
+        chat_template_str = chat_templates(cfg.chat_template)

     model = model.to(cfg.device, dtype=cfg.torch_dtype)

@@ -259,7 +264,24 @@ def generate(instruction):
             )
         else:
             prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        if chat_template_str:
+            batch = tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                return_tensors="pt",
+                add_special_tokens=True,
+                add_generation_prompt=True,
+                chat_template=chat_template_str,
+                tokenize=True,
+                return_dict=True,
+            )
+        else:
+            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()
         with torch.no_grad():
@@ -282,6 +304,7 @@ def generate(instruction):
         streamer = TextIteratorStreamer(tokenizer)
         generation_kwargs = {
             "inputs": batch["input_ids"].to(cfg.device),
+            "attention_mask": batch["attention_mask"].to(cfg.device),
             "generation_config": generation_config,
             "streamer": streamer,
         }
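Note on the gradio hunks above: `apply_chat_template(..., tokenize=True, return_dict=True)` returns a `BatchEncoding` carrying both `input_ids` and `attention_mask`, which is why the generation kwargs now thread `attention_mask` through as well. A minimal standalone sketch of the same pattern (the checkpoint name is illustrative, not pinned by this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative checkpoint, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

batch = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,  # append the assistant header so the model continues as assistant
    return_tensors="pt",
    tokenize=True,
    return_dict=True,  # BatchEncoding with input_ids AND attention_mask
)

model.eval()
with torch.no_grad():
    output = model.generate(
        inputs=batch["input_ids"],
        attention_mask=batch["attention_mask"],  # avoids pad-token ambiguity warnings during generation
        max_new_tokens=32,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))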
8 changes: 8 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -1417,6 +1417,8 @@ def build(self, total_num_steps):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
+            if self.cfg.wandb_name:
+                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
@@ -1574,6 +1576,12 @@ def build(self, total_num_steps):
         )
         training_args = self.hook_post_create_training_args(training_args)

+        # unset run_name so wandb sets up experiment names
+        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
+            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
+                None
+            )
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
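Note on the run_name hunks: `TrainingArguments.__post_init__` in transformers falls back to `run_name = output_dir` when no run name is passed, so the equality check above is a proxy for "the user never set one", and resetting it to None lets the wandb callback generate its own experiment name. A small sketch of that fallback (behavior as in recent transformers releases; exact version may vary):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="./outputs/demo")
assert args.run_name == "./outputs/demo"  # fallback: run_name := output_dir

args = TrainingArguments(output_dir="./outputs/demo", run_name="my-experiment")
assert args.run_name == "my-experiment"  # explicit names (e.g. cfg.wandb_name) survive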
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -375,8 +375,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     prompter_params = {
         "tokenizer": tokenizer,
         "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
-        "message_field_role": ds_cfg.get("message_field_role", "from"),
-        "message_field_content": ds_cfg.get("message_field_content", "value"),
+        "message_field_role": ds_cfg.get("message_field_role", "role"),
+        "message_field_content": ds_cfg.get("message_field_content", "content"),
         "message_field_training": ds_cfg.get("message_field_training", None),
         "message_field_training_detail": ds_cfg.get(
             "message_field_training_detail",
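Note on the new defaults: the old "from"/"value" defaults assumed ShareGPT-style records, while datasets used with chat_template typically follow the OpenAI-style messages schema, which the new defaults match without extra config. Illustrative rows only (not from the repo):

# Old defaults matched ShareGPT-style records:
sharegpt_row = {
    "conversations": [
        {"from": "human", "value": "Hi there"},
        {"from": "gpt", "value": "Hello!"},
    ]
}

# New defaults match OpenAI-style records out of the box:
openai_row = {
    "messages": [
        {"role": "user", "content": "Hi there"},
        {"role": "assistant", "content": "Hello!"},
    ]
}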
10 changes: 9 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1017,12 +1017,20 @@ def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
         return neftune_noise_alpha

     @model_validator(mode="after")
-    def check(self):
+    def check_rl_beta(self):
         if self.dpo_beta and not self.rl_beta:
             self.rl_beta = self.dpo_beta
             del self.dpo_beta
         return self

+    @model_validator(mode="after")
+    def check_simpo_warmup(self):
+        if self.rl == "simpo" and self.warmup_ratio:
+            raise ValueError(
+                "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
+            )
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_frozen(cls, data):
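Note on the validator: pydantic's `@model_validator(mode="after")` runs on the fully constructed model, so a cross-field constraint like simpo-vs-warmup_ratio fails fast at config-load time rather than mid-training. A stripped-down, hypothetical sketch of the same pattern (pydantic v2; the model class is a stand-in, not the axolotl config model):

from typing import Optional

from pydantic import BaseModel, model_validator


class RLConfig(BaseModel):  # hypothetical stand-in for the axolotl input config
    rl: Optional[str] = None
    warmup_ratio: Optional[float] = None
    warmup_steps: Optional[int] = None

    @model_validator(mode="after")
    def check_simpo_warmup(self):
        # All fields are populated by the time this runs, so cross-field checks are safe.
        if self.rl == "simpo" and self.warmup_ratio:
            raise ValueError(
                "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
            )
        return self


RLConfig(rl="simpo", warmup_steps=10)     # passes
# RLConfig(rl="simpo", warmup_ratio=0.1)  # raises ValueError at load time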
