Commit

fix
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Aug 13, 2024
1 parent 9140e4b commit 1d92ddb
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions nemo/lightning/pytorch/plugins/precision.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from dataclasses import dataclass
+from contextlib import contextmanager
 from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union

 import pytorch_lightning as pl
 import torch
-from pytorch_lightning.plugins.precision import MixedPrecision
+from pytorch_lightning.plugins.precision import Precision
 from torch.nn import Module
 from torch.optim import Optimizer

@@ -27,7 +28,7 @@


 @dataclass
-class FP8Config:
+class FP8Config():
     fp8 = False
     fp8_margin = None
     fp8_interval = None
@@ -39,14 +40,15 @@ class FP8Config:


 @dataclass
-class DtypeConfig:
-    fp32 = False
-    fp16 = False
-    bf16 = False
-    params_dtype = None
-    pipeline_dtype = None
-    autocast_dtype = None
-    autocast_enabled = False
+class DtypeConfig():
+    fp32: bool = False
+    fp16: bool = False
+    bf16: bool = False
+    params_dtype: torch.dtype = None
+    pipeline_dtype: torch.dtype = None
+    autocast_dtype: torch.dtype = None
+    autocast_enabled: bool = False
+    grad_reduce_in_fp32: bool = False


 def make_default_dtype_config_from_precision(precision: str):
@@ -66,12 +68,13 @@ def make_default_dtype_config_from_precision(precision: str):
         fp32=is_fp32(precision),
         fp16=is_fp16(precision),
         bf16=is_bf16(precision),
-        params_dtype=torch.float32,
+        params_dtype=torch.bfloat16,
         pipeline_dtype=dtype,
         autocast_dtype=dtype,
-        autocast_enabled=False,  # for now
+        autocast_enabled=False, # for now
+        grad_reduce_in_fp32=True,
     )
-    assert not (
+    assert (
         config.fp32 or config.fp16 or config.bf16
     ), f"Error: Expected precision to be FP32, BF16 or FP16, {precision}"
     return config
@@ -85,7 +88,7 @@ def generate_optim_config(self):
     def generate_model_config(self):
         return {}

-class MegatronCustomPrecision(MixedPrecision, McoreConfigsFromPlugin):
+class MegatronCustomPrecision(Precision, McoreConfigsFromPlugin):
     def __init__(
         self,
         dtype_config: DtypeConfig,
@@ -99,7 +102,7 @@ def __init__(
         scaler = None
         if self.dtype_config.fp16:
             scaler = GradScaler(init_scale=2**32, growth_interval=1000, hysteresis=2)
-        super().__init__(precision, device, scaler)
+        super().__init__()#precision, device, scaler)
         self.amp_O2 = amp_O2

     def connect(
@@ -226,7 +229,7 @@ def __init__(
         amp_O2: bool = False,
         device: str = "cuda",
     ) -> None:
-        super().__init__(make_default_config_from_precision(precision), None, device, amp2_O2)
+        super().__init__(make_default_dtype_config_from_precision(precision), None, device, amp_O2)


 __all__ = ["MegatronMixedPrecision", "MegatronCustomPrecision"]
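
Note: the hunks above show only fragments of the reworked DtypeConfig plumbing. Below is a rough, self-contained sketch of how the visible pieces appear to fit together. The is_fp32/is_fp16/is_bf16 helpers and the dtype resolution are not part of this diff, so their bodies here are assumptions for illustration, not the repository's implementation.

from dataclasses import dataclass

import torch


@dataclass
class DtypeConfig:
    fp32: bool = False
    fp16: bool = False
    bf16: bool = False
    params_dtype: torch.dtype = None
    pipeline_dtype: torch.dtype = None
    autocast_dtype: torch.dtype = None
    autocast_enabled: bool = False
    grad_reduce_in_fp32: bool = False


# Assumed helpers: referenced in the diff but not shown there.
def is_fp32(precision: str) -> bool:
    return str(precision) in ("32", "32-true", "fp32")


def is_fp16(precision: str) -> bool:
    return str(precision) in ("16", "16-mixed", "fp16")


def is_bf16(precision: str) -> bool:
    return str(precision) in ("bf16", "bf16-mixed")


def make_default_dtype_config_from_precision(precision: str) -> DtypeConfig:
    # dtype resolution is assumed; the diff only shows the resolved value being passed through.
    if is_fp16(precision):
        dtype = torch.float16
    elif is_bf16(precision):
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    config = DtypeConfig(
        fp32=is_fp32(precision),
        fp16=is_fp16(precision),
        bf16=is_bf16(precision),
        params_dtype=torch.bfloat16,  # changed from torch.float32 in this commit
        pipeline_dtype=dtype,
        autocast_dtype=dtype,
        autocast_enabled=False,  # for now
        grad_reduce_in_fp32=True,  # new field introduced by this commit
    )
    assert (
        config.fp32 or config.fp16 or config.bf16
    ), f"Error: Expected precision to be FP32, BF16 or FP16, {precision}"
    return config


# Example: build the default dtype config for bf16 training.
print(make_default_dtype_config_from_precision("bf16-mixed"))

With the assertion flipped from "assert not (...)" to "assert (...)", the function now rejects any precision string that does not map to FP32, FP16, or BF16 instead of rejecting the valid ones.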
