Skip to content

Commit

Permalink
lint after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Oct 2, 2024
1 parent c0f437c commit 1e993f7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/axolotl/monkeypatch/attention/mllama.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code

from typing import Optional, Tuple

import torch
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math
Expand Down
21 changes: 15 additions & 6 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,29 +119,38 @@ def normalize_config(cfg):

if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model

model_config = load_model_config(cfg)

cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)

cfg.is_multimodal = (
hasattr(model_config, "model_type") and model_config.model_type in [ "llava", "mllama"]
or any(multimodal_name in cfg.base_model.lower() for multimodal_name in [ "pixtral", ])
or cfg.is_multimodal
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
"pixtral",
]
)
or cfg.is_multimodal
)
if cfg.is_multimodal:
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config

cfg.model_config_type = model_config.model_type
cfg.model_config_type = model_config.model_type

# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == ["llama", "mllama_text_model"])
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
)
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
Expand Down

0 comments on commit 1e993f7

Please sign in to comment.