
Commit

Apply isort and black reformatting
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa committed Oct 22, 2024
1 parent 8d1d6aa commit c040e20
Showing 6 changed files with 25 additions and 25 deletions.
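The diffs below are mechanical style changes: isort puts imports back into alphabetical order, and black normalizes layout, most visibly by dropping the spaces around '=' in keyword arguments that have no type annotation and by re-wrapping long statements. A minimal sketch of that keyword-default rule (hypothetical function, not from this commit):

# Hypothetical example of the black rule applied in hf_dataset.py below.
# Before black:
#     def make_loader(batch_size = 2, shuffle = True): ...
# After black, plain keyword defaults lose the spaces around '=':
def make_loader(batch_size=2, shuffle=True):
    return {"batch_size": batch_size, "shuffle": shuffle}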
16 changes: 8 additions & 8 deletions examples/llm/sft/hf.py
@@ -24,14 +24,14 @@
class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
return DataLoader(
- dataset,
- num_workers=self.num_workers,
- pin_memory=self.pin_memory,
- persistent_workers=self.persistent_workers,
- collate_fn=dataset.collate_fn,
- batch_size=self.micro_batch_size,
- **kwargs,
- )
+ dataset,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ persistent_workers=self.persistent_workers,
+ collate_fn=dataset.collate_fn,
+ batch_size=self.micro_batch_size,
+ **kwargs,
+ )


def squad(tokenizer) -> pl.LightningDataModule:
2 changes: 1 addition & 1 deletion nemo/collections/llm/__init__.py
@@ -21,10 +21,10 @@
from nemo.collections.llm.gpt.data import (
DollyDataModule,
FineTuningDataModule,
+ HfDatasetDataModule,
MockDataModule,
PreTrainingDataModule,
SquadDataModule,
- HfDatasetDataModule,
)
from nemo.collections.llm.gpt.data.api import dolly, mock, squad
from nemo.collections.llm.gpt.model import (
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -14,16 +14,16 @@

from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
+ from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.gpt.data.squad import SquadDataModule
- from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule

__all__ = [
"FineTuningDataModule",
"SquadDataModule",
"DollyDataModule",
"MockDataModule",
"PreTrainingDataModule",
"HfDatasetDataModule"
"HfDatasetDataModule",
]
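Both __init__.py hunks are isort/black housekeeping: the HfDatasetDataModule import moves to its alphabetical position and the last __all__ entry gains a trailing comma; nothing exported changes. A usage sketch of the resulting import surface (assuming a NeMo installation that includes these modules):

from nemo.collections.llm.gpt.data import (
    DollyDataModule,
    FineTuningDataModule,
    HfDatasetDataModule,
    MockDataModule,
    PreTrainingDataModule,
    SquadDataModule,
)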
23 changes: 10 additions & 13 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -21,14 +21,14 @@ class HfDatasetDataModule(pl.LightningDataModule):
def __init__(
self,
dataset,
- num_workers = 2,
- pin_memory = True,
- persistent_workers = True,
- micro_batch_size = 2,
- global_batch_size = 2,
- pad_token_id = 0,
- use_mcore_sampler = False,
- mcore_dataloader_type = 'cyclic',
+ num_workers=2,
+ pin_memory=True,
+ persistent_workers=True,
+ micro_batch_size=2,
+ global_batch_size=2,
+ pad_token_id=0,
+ use_mcore_sampler=False,
+ mcore_dataloader_type='cyclic',
) -> None:
super().__init__()
assert pad_token_id is not None
@@ -56,10 +56,8 @@ def extract_key_from_dicts(batch, key):

def pad_within_micro(batch, pad_token_id):
max_len = max(map(len, batch))
- return [
- item + [pad_token_id] * (max_len - len(item))
- for item in batch
- ]
+ return [item + [pad_token_id] * (max_len - len(item)) for item in batch]
+
return {
key: batchify(
torch.LongTensor(
@@ -103,4 +101,3 @@ def train_dataloader(self, collate_fn=None):
rank=rank,
world_size=world_size,
)

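The second hf_dataset.py hunk only collapses the list comprehension in pad_within_micro onto one line; its behavior is identical. A self-contained sketch of what that helper does (example values are illustrative):

def pad_within_micro(batch, pad_token_id):
    # Pad every sequence in the micro-batch to the length of the longest one.
    max_len = max(map(len, batch))
    return [item + [pad_token_id] * (max_len - len(item)) for item in batch]


# Sequences of lengths 3, 1, and 2 are right-padded to length 3 with pad_token_id=0:
print(pad_within_micro([[5, 6, 7], [8], [9, 10]], pad_token_id=0))
# [[5, 6, 7], [8, 0, 0], [9, 10, 0]]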
@@ -64,6 +64,7 @@ def configure_model(self):
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype='auto')
else:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_config(config)
self.model.train()
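The configure_model hunk appears to add only a blank line after the local import; the surrounding branch builds a randomly initialized model from its config instead of loading pretrained weights. A minimal sketch of that transformers pattern (the model name is an arbitrary example, not taken from this commit):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")       # example checkpoint name (assumption)
model = AutoModelForCausalLM.from_config(config)  # architecture only, random weights
model.train()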
4 changes: 3 additions & 1 deletion nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -267,7 +267,9 @@ def __init__(
def connect(self, model: pl.LightningModule) -> None:
super().connect(model)

- assert not 'HfAutoModelForCausalLM' in type(model).__name__, "Cannot use HfAutoModelForCausalLM with MegatronParallel"
+ assert (
+ not 'HfAutoModelForCausalLM' in type(model).__name__
+ ), "Cannot use HfAutoModelForCausalLM with MegatronParallel"

_maybe_mcore_config = _strategy_lib.set_model_parallel_attributes(model, self.parallelism)
if _maybe_mcore_config:
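The megatron_strategy.py change is black re-wrapping an assert that exceeded the line-length limit; the condition and message are unchanged. A generic sketch of that wrapping, written here with the equivalent "not in" operator (names are placeholders):

# Hypothetical long assert, before black (one line):
#     assert not 'SomeLongClassName' in type(obj).__name__, "Cannot use SomeLongClassName here"
obj = object()
assert (
    'SomeLongClassName' not in type(obj).__name__
), "Cannot use SomeLongClassName here"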
