Akoumparouli/mixtral fixes for r2.0.0rc1 (NVIDIA#9911) (NVIDIA#9933)
* nemo-ux-mixtral: use cpu init & skip init when importing; specify correct dtype

* nemo-ux-state: handle None in state_dict.keys; disable auto-grad when transforming ckpt

* add dummy SquadDataModule.reconfigure_limit_batches

* Apply isort and black reformatting

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
3 people authored Jul 30, 2024
1 parent 86bfac2 commit 7dd9378
Showing 3 changed files with 21 additions and 10 deletions.
3 changes: 3 additions & 0 deletions nemo/collections/llm/gpt/data/squad.py
@@ -124,3 +124,6 @@ def _preprocess_and_split_data(
shutil.rmtree(p)
elif '.jsonl' not in str(p.name):
p.unlink()
+
+def reconfigure_limit_batches(self):
+    return
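The squad.py change adds a no-op hook so that code paths expecting a reconfigure_limit_batches method can call it on SquadDataModule without failing; elsewhere such a hook typically rescales trainer limit_*_batches against the dataloader length, which this datamodule deliberately skips. A minimal sketch of the pattern, with a hypothetical class and caller rather than the real NeMo ones:

# Illustrative sketch only: why a no-op reconfigure_limit_batches() is enough here.
# SquadDataModuleSketch and prepare() are hypothetical stand-ins, not NeMo code.
class SquadDataModuleSketch:
    def reconfigure_limit_batches(self):
        # Dummy hook: this datamodule has nothing to rescale, so it simply returns.
        return


def prepare(datamodule):
    # Callers can invoke the hook unconditionally, without special-casing
    # datamodules that have nothing to reconfigure.
    datamodule.reconfigure_limit_batches()


prepare(SquadDataModuleSketch())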
9 changes: 8 additions & 1 deletion nemo/collections/llm/gpt/model/mixtral.py
@@ -47,6 +47,8 @@ class MixtralConfig8x7B(GPTConfig):
# rotary
rotary_percent: float = 0.5
rotary_base: float = 10000
+bf16: bool = True
+params_dtype: torch.dtype = torch.bfloat16


class MixtralModel(GPTModel):
@@ -70,7 +72,7 @@ def init(self) -> MixtralModel:
def apply(self, output_path: Path) -> Path:
from transformers import MixtralForCausalLM

-source = MixtralForCausalLM.from_pretrained(str(self))
+source = MixtralForCausalLM.from_pretrained(str(self), torch_dtype='auto', use_safetensors=True)
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
@@ -109,6 +111,7 @@ def config(self) -> MixtralConfig8x7B:

config = HfMixtralConfig.from_pretrained(str(self))
return MixtralConfig8x7B(
+bf16=getattr(config, "torch_dtype", None) == torch.bfloat16,
activation_func=F.silu,
# network
num_layers=config.num_hidden_layers,
@@ -132,6 +135,10 @@ def config(self) -> MixtralConfig8x7B:
gated_linear_unit=True,
# Vocab
make_vocab_size_divisible_by=128,
+# CPU init
+use_cpu_initialization=True,
+perform_initialization=False,
+params_dtype=getattr(config, "torch_dtype", torch.bfloat16),
)


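Taken together, the mixtral.py changes make the HF-to-NeMo import dtype-aware and cheaper: the source model is loaded in its stored precision (torch_dtype='auto' avoids the default fp32 upcast), and the target config is built with CPU initialization and random init skipped, since the weights are immediately overwritten by the converted checkpoint. A rough sketch of the dtype-propagation pattern, using a plain dataclass as an illustrative stand-in for MixtralConfig8x7B/GPTConfig:

# Hedged sketch of the dtype/CPU-init pattern in this commit; ImportConfigSketch
# and config_from_hf are illustrative stand-ins, not the real NeMo classes.
from dataclasses import dataclass

import torch


@dataclass
class ImportConfigSketch:
    bf16: bool = True
    params_dtype: torch.dtype = torch.bfloat16
    use_cpu_initialization: bool = True   # allocate weights on CPU during import
    perform_initialization: bool = False  # skip random init; values come from the checkpoint


def config_from_hf(hf_config) -> ImportConfigSketch:
    # Mirror the source checkpoint's dtype instead of assuming a fixed one.
    source_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
    return ImportConfigSketch(bf16=source_dtype == torch.bfloat16, params_dtype=source_dtype)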
19 changes: 10 additions & 9 deletions nemo/lightning/io/state.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload

import numpy as np
+import torch
from torch import nn

SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
@@ -19,11 +20,12 @@ class TransformCTX:
target_state: dict


+@torch.no_grad
def apply_transforms(
source: nn.Module,
target: TargetModuleT,
mapping: Dict[str, str],
-transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
) -> TargetModuleT:
"""
Applies a series of transformations to adapt the state dictionary of a source module to
@@ -101,9 +103,8 @@ def scale_weights(ctx):
for key, val in mapping.items():
ctx = StateDictTransform(key, val)(ctx)

-if transforms:
-    for transform in transforms:
-        ctx = transform(ctx)
+for transform in transforms:
+    ctx = transform(ctx)

_params: Dict[str, nn.Parameter] = {}
for name, param in _target.named_parameters():
@@ -144,9 +145,9 @@ def scale_weights(ctx):

_module.register_buffer(_key, val)

-keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")]
+keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
if len(keys) != 0:
raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.")
raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")

# TODO: Is this correct?
# for key in target.state_dict():
@@ -165,7 +166,7 @@ def scale_weights(ctx):


def _default_transform(inp):
-return inp.float()
+return inp


class StateDictTransform(Generic[F]):
@@ -324,7 +325,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
wildcard_matches = [[] for _ in range(pattern.count("*"))]

-for key in keys:
+for key in filter(lambda x: x is not None, keys):
match = regex_pattern.match(key)
if match:
for i, group in enumerate(match.groups()):
@@ -342,7 +343,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
output_array = np.empty(shape, dtype=object)

# Populate the array with the keys, now that we have the correct shape and ordering
-for key in keys:
+for key in filter(lambda x: x is not None, keys):
match = regex_pattern.match(key)
if match:
# Convert match groups to indices based on their position in wildcard_matches
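The state.py changes are defensive: checkpoint transformation now runs with autograd disabled, so no graph is recorded while tensors are copied, and key lists coming out of state dicts are filtered for None before regex matching and the leftover-keys check. A small self-contained sketch of both patterns; the helper names below are assumptions, not part of nemo.lightning.io.state:

# Minimal sketch of the two defensive patterns; copy_weights and usable_keys are
# hypothetical helpers, not the nemo.lightning.io.state API.
from typing import Dict, Iterable, List, Optional

import torch


@torch.no_grad()  # no autograd graph is recorded while tensors are copied
def copy_weights(target: Dict[str, torch.Tensor], source: Dict[str, torch.Tensor]) -> None:
    for name, tensor in source.items():
        target[name].copy_(tensor)


def usable_keys(keys: Iterable[Optional[str]]) -> List[str]:
    # Drop None entries and bookkeeping keys before pattern matching, mirroring
    # the filter(lambda x: x is not None, ...) calls added in this commit.
    return [k for k in keys if k is not None and not k.endswith("_extra_state")]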
