Skip to content

Commit

Permalink
fix for protected model_ namespace with pydantic (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Feb 28, 2024
1 parent 3a5a2d2 commit 6b3b271
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 22 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
model_revision:
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
Expand All @@ -573,7 +573,7 @@ is_qwen_derived_model:
is_mistral_derived_model:

# optional overrides to the base model configuration
model_config_overrides:
overrides_of_model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
Expand Down
12 changes: 6 additions & 6 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
)

# figure out if the model is falcon
Expand All @@ -140,7 +140,7 @@ def normalize_config(cfg):
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
)

cfg.is_mistral_derived_model = (
Expand All @@ -153,7 +153,7 @@ def normalize_config(cfg):
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.model_type and "mistral" in cfg.model_type.lower())
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
)

cfg.is_qwen_derived_model = (
Expand Down Expand Up @@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)

if cfg.gptq and cfg.model_revision:
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"model_revision is not supported for GPTQ models. "
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
+ "point to its path, and remove revision_of_model from the config."
)

# if cfg.sample_packing and cfg.sdp_attention:
Expand Down
26 changes: 19 additions & 7 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
return noisy_embedding_alpha


class RemappedParameters(BaseModel):
    """
    Config parameters renamed to avoid pydantic's protected ``model_`` namespace.

    Pydantic reserves attribute names starting with ``model_``; each field here
    keeps the legacy config key available via its ``alias`` so existing user
    configs (``model_config``, ``model_type``, ``model_revision``) still load.
    """

    # legacy key "model_config": ad-hoc overrides applied onto the loaded
    # HF model configuration (see load_model_config in utils/models.py)
    overrides_of_model_config: Optional[Dict[str, Any]] = Field(
        default=None, alias="model_config"
    )
    # legacy key "model_type": transformers model class name to load with,
    # e.g. "AutoModelForCausalLM"
    type_of_model: Optional[str] = Field(default=None, alias="model_type")
    # legacy key "model_revision": HuggingFace Hub revision (branch/tag/commit)
    # forwarded as the `revision` kwarg when fetching the model
    revision_of_model: Optional[str] = Field(default=None, alias="model_revision")


class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset"""

Expand Down Expand Up @@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
model_type: Optional[str] = Field(default=None)
model_revision: Optional[str] = None
trust_remote_code: Optional[bool] = None

model_config_overrides: Optional[Dict[str, Any]] = None

@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
Expand Down Expand Up @@ -362,11 +368,17 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
):
"""wrapper of all config options"""

class Config:
    """Pydantic model configuration enabling alias support.

    ``populate_by_name = True`` lets fields declared with an ``alias``
    (see the remapped ``model_*`` parameters) be populated by either the
    alias or the new attribute name.
    """

    populate_by_name = True

strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
Expand Down Expand Up @@ -550,11 +562,11 @@ def check_pretraining_w_group_by_length(cls, data):
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
    """Reject configs that combine GPTQ quantization with a pinned revision.

    The revision cannot be resolved automatically for GPTQ repos, so the
    user is directed to download the correct branch manually.

    :param data: raw input mapping (runs in "before" mode, prior to model
        construction).
    :raises ValueError: when both ``gptq`` and ``revision_of_model`` are set.
    :return: the unmodified input mapping.
    """
    # NOTE(review): checks the remapped key name only; assumes the legacy
    # "model_revision" key has already been remapped upstream — confirm.
    if data.get("gptq") and data.get("revision_of_model"):
        raise ValueError(
            "revision_of_model is not supported for GPTQ models. "
            + "Please download the model from HuggingFace Hub manually for correct branch, "
            + "point to its path, and remove revision_of_model from the config."
        )
    return data

Expand Down
14 changes: 7 additions & 7 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def load_model_config(cfg):
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {}
if cfg.model_revision:
config_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model

try:
model_config = AutoConfig.from_pretrained(
Expand All @@ -104,8 +104,8 @@ def load_model_config(cfg):
)
raise err

if cfg.model_config_overrides:
for key, val in cfg.model_config_overrides.items():
if cfg.overrides_of_model_config:
for key, val in cfg.overrides_of_model_config.items():
setattr(model_config, key, val)

check_model_config(cfg, model_config)
Expand Down Expand Up @@ -272,7 +272,7 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.model_type
model_type = cfg.type_of_model
model_config = load_model_config(cfg)

# TODO refactor as a kwarg
Expand Down Expand Up @@ -426,8 +426,8 @@ def load_model(
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]

if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
model_kwargs["revision"] = cfg.revision_of_model
if cfg.gptq:
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
Expand Down
42 changes: 42 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import os
import warnings
from typing import Optional

import pytest
Expand All @@ -14,6 +15,8 @@
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars

warnings.filterwarnings("error")


@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
Expand Down Expand Up @@ -190,6 +193,45 @@ def test_lr_as_float(self, minimal_cfg):

assert new_cfg.learning_rate == 0.00005

def test_model_config_remap(self, minimal_cfg):
    """Legacy `model_config` key should land in `overrides_of_model_config`."""
    overrides = DictDefault({"model_config": {"model_type": "mistral"}})
    validated = validate_config(overrides | minimal_cfg)
    assert validated.overrides_of_model_config["model_type"] == "mistral"

def test_model_type_remap(self, minimal_cfg):
    """Legacy `model_type` key should land in `type_of_model`."""
    overrides = DictDefault({"model_type": "AutoModelForCausalLM"})
    validated = validate_config(overrides | minimal_cfg)
    assert validated.type_of_model == "AutoModelForCausalLM"

def test_model_revision_remap(self, minimal_cfg):
    """Legacy `model_revision` key should land in `revision_of_model`."""
    overrides = DictDefault({"model_revision": "main"})
    validated = validate_config(overrides | minimal_cfg)
    assert validated.revision_of_model == "main"

def test_qlora(self, minimal_cfg):
base_cfg = (
DictDefault(
Expand Down

0 comments on commit 6b3b271

Please sign in to comment.