Commit 073b06e

Removing un-used ModelConfig class (NVIDIA#9389)
Co-authored-by: Chen Cui <chcui@nvidia.com>
2 people authored and JesusPaz committed Jun 18, 2024
1 parent 6d65e5c commit 073b06e
Showing 2 changed files with 3 additions and 36 deletions.
6 changes: 1 addition & 5 deletions nemo/collections/llm/gpt/model/base.py
@@ -8,7 +8,6 @@
 from torch.optim import Optimizer
 
 from nemo.lightning import get_vocab_size, io
-from nemo.lightning.base import ModelConfig
 from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
 
 if TYPE_CHECKING:
@@ -18,7 +17,7 @@
 
 
 @dataclass
-class GPTConfig(TransformerConfig, ModelConfig):
+class GPTConfig(TransformerConfig):
     # From megatron.core.models.gpt.gpt_model.GPTModel
     fp16_lm_cross_entropy: bool = False
     parallel_output: bool = True
@@ -126,9 +125,6 @@ def training_loss_reduction(self) -> MaskedTokenLossReduction:
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
         return MaskedTokenLossReduction(validation_step=True)
 
-    def copy(self) -> "GPTModel":
-        return self.__class__(self.config, self.tokenizer)
-
 
 def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
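
After this change GPTConfig subclasses only Megatron's TransformerConfig, and a model is built by passing the config (plus a tokenizer) to GPTModel, which is also what the removed copy() helper did via self.__class__(self.config, self.tokenizer). A minimal usage sketch under those assumptions; the field values and the tokenizer argument below are illustrative, not taken from this diff:

from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel

# Illustrative values only: num_layers, hidden_size and num_attention_heads are
# inherited from megatron.core's TransformerConfig, while fp16_lm_cross_entropy
# and parallel_output are the GPTConfig fields shown in this diff.
config = GPTConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    fp16_lm_cross_entropy=False,
    parallel_output=True,
)

# The removed copy() returned self.__class__(self.config, self.tokenizer), i.e.
# GPTModel is constructed from a config plus a tokenizer; `tokenizer` here is a
# placeholder assumed to come from elsewhere (e.g. the data module).
model = GPTModel(config, tokenizer)
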
33 changes: 2 additions & 31 deletions nemo/lightning/base.py
@@ -1,15 +1,13 @@
 import gc
-import inspect
 import os
 from pathlib import Path
-from typing import Generic, Optional, Type, TypeVar
+from typing import Optional
 
 import torch
 import torch.distributed
-from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning import Trainer
 from torch import nn
 
-from nemo.lightning import io
 
 DEFAULT_NEMO_CACHE_HOME = Path.home() / ".cache" / "nemo"
 NEMO_CACHE_HOME = Path(os.getenv("NEMO_HOME", DEFAULT_NEMO_CACHE_HOME))
@@ -19,33 +17,6 @@
 NEMO_MODELS_CACHE = Path(os.getenv("NEMO_MODELS_CACHE", DEFAULT_NEMO_MODELS_CACHE))
 
 
-ModelT = TypeVar("ModelT", bound=LightningModule)
-
-
-class ModelConfig(Generic[ModelT], io.IOMixin):
-    def model_cls(self) -> Type[ModelT]:
-        raise NotImplementedError("Must be implemented by subclass")
-
-    @property
-    def model_type(self) -> Type[ModelT]:
-        return self.model_cls()
-
-    def init(self, *args, data=None, cpu: bool = False, **kwargs) -> ModelT:
-        model_cls = self.model_cls()
-        if data:
-            kwargs.update(data.model_kwargs())
-
-        signature = inspect.signature(model_cls.__init__)
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters}
-
-        model = model_cls(self, *args, **filtered_kwargs)
-
-        if not cpu:
-            model.cuda(torch.cuda.current_device())
-
-        return model
-
-
 def get_vocab_size(
     config,
     vocab_size: int,
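
The deleted ModelConfig.init built a model by filtering arbitrary kwargs down to the ones the target class's __init__ actually accepts. That signature-filtering idiom stands on its own; a self-contained sketch of it (the build_filtered helper and Example class are made up for illustration):

import inspect

def build_filtered(cls, *args, **kwargs):
    # Keep only keyword arguments that cls.__init__ declares, mirroring the
    # filtering the removed ModelConfig.init performed before instantiation.
    params = inspect.signature(cls.__init__).parameters
    accepted = {k: v for k, v in kwargs.items() if k in params}
    return cls(*args, **accepted)

class Example:
    def __init__(self, a, b=0):
        self.a, self.b = a, b

# Extra keys such as `unused` are dropped instead of raising a TypeError.
obj = build_filtered(Example, 1, b=2, unused=3)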
