Fix: loading DBRX back from saved path (#35728)
* fix dtype as dict for some models + add test

* add comment in tests
zucchini-nlp authored and ArthurZucker committed Jan 30, 2025
1 parent b17abf9 commit 163c8bb
Showing 4 changed files with 87 additions and 6 deletions.
29 changes: 25 additions & 4 deletions src/transformers/modeling_utils.py
@@ -4020,10 +4020,31 @@ def from_pretrained(
                        )
                elif hasattr(torch, torch_dtype):
                    torch_dtype = getattr(torch, torch_dtype)
                else:
                    raise ValueError(
                        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
                    )
                for sub_config_key in config.sub_configs.keys():
                    sub_config = getattr(config, sub_config_key)
                    sub_config.torch_dtype = torch_dtype
            elif isinstance(torch_dtype, torch.dtype):
                for sub_config_key in config.sub_configs.keys():
                    sub_config = getattr(config, sub_config_key)
                    sub_config.torch_dtype = torch_dtype
            elif isinstance(torch_dtype, dict):
                for key, curr_dtype in torch_dtype.items():
                    if hasattr(config, key):
                        value = getattr(config, key)
                        value.torch_dtype = curr_dtype
                # main torch dtype for modules that aren't part of any sub-config
                torch_dtype = torch_dtype.get("")
                config.torch_dtype = torch_dtype
                if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
                    torch_dtype = getattr(torch, torch_dtype)
                elif torch_dtype is None:
                    torch_dtype = torch.float32
            else:
                raise ValueError(
                    f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
                    f"for each sub-config in composite configs, but received {torch_dtype}"
                )

            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        # Check if `_keep_in_fp32_modules` is not None
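
In practice, the branch added above lets `from_pretrained` take either a single dtype or a per-sub-config mapping, with the `""` key covering modules that do not belong to any sub-config. A minimal sketch of both call shapes, mirroring the new tests below (the checkpoint id is illustrative; any composite model such as Llava behaves the same way):

import torch
from transformers import LlavaForConditionalGeneration

checkpoint = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"  # illustrative id

# Single dtype, applied to the whole model (and propagated to every sub-config)
model = LlavaForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.float16)

# Dict form: one dtype per sub-config; "" covers modules outside any sub-config,
# e.g. the multi-modal projector in Llava
model = LlavaForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"},
)
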
4 changes: 2 additions & 2 deletions src/transformers/models/dbrx/configuration_dbrx.py
@@ -57,7 +57,7 @@ def __init__(
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
@@ -109,7 +109,7 @@ def __init__(
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights

        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
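
Adding `"torch_dtype"` to the pop list matters because `from_pretrained` (see the change above) now stamps a `torch_dtype` onto every sub-config, and `save_pretrained` writes it into the saved config; on reload, the DBRX sub-config `__init__` would otherwise receive it as an unexpected kwarg and trip the `len(kwargs) != 0` check. A rough sketch of the round-trip this restores (the model id and local path are illustrative):

from transformers import AutoModelForCausalLM

model_id = "databricks/dbrx-instruct"  # illustrative; any DBRX checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)  # sub-configs get a torch_dtype here
model.save_pretrained("./dbrx-local")                   # the saved config now carries those entries
# Before this fix, the next call raised a ValueError from the sub-config kwargs check;
# with "torch_dtype" popped like the other bookkeeping keys, it loads cleanly.
model = AutoModelForCausalLM.from_pretrained("./dbrx-local")
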
6 changes: 6 additions & 0 deletions tests/test_modeling_common.py
@@ -331,6 +331,12 @@ def check_save_load(out1, out2):
                with torch.no_grad():
                    second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

                # Save and load a second time because `from_pretrained` adds a bunch of new config fields,
                # so we need to make sure those fields can be loaded back after saving.
                # Simply initializing as `model(config)` doesn't add those fields.
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_save_load(tensor1, tensor2)
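
As a standalone illustration of what the extra round-trip exercises (a sketch only, assuming `model_class` and `config` as in the surrounding test):

import tempfile

model = model_class(config)  # plain init: fields injected by `from_pretrained` are absent
with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    model = model_class.from_pretrained(tmpdirname)  # injects fields such as per-sub-config torch_dtype
    # Second round-trip: the injected fields must serialize and parse back,
    # which is exactly where the DBRX sub-configs previously choked on `torch_dtype`.
    model.save_pretrained(tmpdirname)
    model = model_class.from_pretrained(tmpdirname)
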
54 changes: 54 additions & 0 deletions tests/utils/test_modeling_utils.py
@@ -460,6 +460,60 @@ def test_model_from_config_torch_dtype_str(self):
        with self.assertRaises(ValueError):
            model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")

    def test_model_from_config_torch_dtype_composite(self):
        """
        Test that from_pretrained works with torch_dtype given as a dict, with one entry per sub-config
        in a composite config. Tiny-Llava has `torch.float32` saved as the auto dtype for all modules.
        """
        # should be able to set torch_dtype as a simple string and the model loads it correctly
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
        self.assertEqual(model.language_model.dtype, torch.float32)
        self.assertEqual(model.vision_tower.dtype, torch.float32)

        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
        self.assertEqual(model.language_model.dtype, torch.float16)
        self.assertEqual(model.vision_tower.dtype, torch.float16)

        # should be able to set torch_dtype as a dict for each sub-config
        model = LlavaForConditionalGeneration.from_pretrained(
            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
        )
        self.assertEqual(model.language_model.dtype, torch.float32)
        self.assertEqual(model.vision_tower.dtype, torch.float16)
        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)

        # should be able to set the values as torch.dtype (not str)
        model = LlavaForConditionalGeneration.from_pretrained(
            TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
        )
        self.assertEqual(model.language_model.dtype, torch.float32)
        self.assertEqual(model.vision_tower.dtype, torch.float16)
        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)

        # should be able to set the values in configs directly and pass it to `from_pretrained`
        config = copy.deepcopy(model.config)
        config.text_config.torch_dtype = torch.float32
        config.vision_config.torch_dtype = torch.bfloat16
        config.torch_dtype = torch.float16
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
        self.assertEqual(model.language_model.dtype, torch.float32)
        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)

        # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
        self.assertEqual(model.language_model.dtype, torch.float32)
        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)

        # torch.set_default_dtype() supports only float dtypes, so will fail with a non-float type
        with self.assertRaises(ValueError):
            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
            model = LlavaForConditionalGeneration.from_pretrained(
                TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
            )

    @require_torch
    def test_model_from_pretrained_meta_device(self):
        def is_on_meta(model_id, dtype):
