use microbatch calculator from mcore
Signed-off-by: ashors1 <ashors@nvidia.com>
ashors1 committed Jul 8, 2024
1 parent f915988 commit 12c1cf6
Showing 6 changed files with 21 additions and 394 deletions.
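This commit swaps NeMo's apex-based microbatch bookkeeping for the calculator that ships with Megatron Core. As a hedged sketch of the mcore API the diffs below migrate to (the batch sizes are illustrative, not taken from this commit):

    from megatron.core.num_microbatches_calculator import (
        get_num_microbatches,
        init_num_microbatches_calculator,
    )

    # Build the global calculator once per process; rampup_batch_size=None
    # installs a constant (non-ramping) calculator.
    init_num_microbatches_calculator(
        rank=0,                  # mainly gates logging to rank 0
        rampup_batch_size=None,  # or [start, increment, ramp-up samples]
        global_batch_size=256,
        micro_batch_size=4,
        data_parallel_size=8,
    )

    print(get_num_microbatches())  # 256 // (4 * 8) == 8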
6 changes: 2 additions & 4 deletions nemo/collections/llm/gpt/data/pre_training.py
@@ -173,10 +173,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             state_dict: the datamodule state returned by ``state_dict``.
         """
-        try:
-            from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
-        except ModuleNotFoundError:
-            from nemo.lightning.apex_utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+        from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
 
         consumed_samples = state_dict['consumed_samples']
         self.data_sampler.init_consumed_samples = consumed_samples
         self.data_sampler.prev_consumed_samples = consumed_samples
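The unchanged code below this hunk re-syncs the global calculator with the checkpointed sample count. A minimal sketch of that pattern, assuming the calculator was already initialized and with resync_microbatches as an illustrative name:

    from megatron.core.num_microbatches_calculator import update_num_microbatches

    def resync_microbatches(consumed_samples: int) -> None:
        # Recompute the microbatch schedule from the restored sample count;
        # the consistency check is skipped while state is mid-restore.
        update_num_microbatches(consumed_samples=consumed_samples, consistency_check=False)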
13 changes: 7 additions & 6 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
@@ -16,18 +16,19 @@
 
 import numpy as np
 import torch
+from megatron.core.num_microbatches_calculator import (
+    ConstantNumMicroBatchesCalculator,
+    init_num_microbatches_calculator,
+)
 
 from nemo.utils import AppState, logging
 
 try:
     from apex.transformer.log_util import set_logging_level
-    from apex.transformer.microbatches import ConstantNumMicroBatches
-    from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator
 
     HAVE_APEX = True
-except (ImportError, ModuleNotFoundError):
-
-    from nemo.lightning.apex_utils import setup_microbatch_calculator
 except (ImportError, ModuleNotFoundError):
 
     HAVE_APEX = False
@@ -138,10 +139,10 @@ def initialize_model_parallel_for_nemo(
 
     if global_batch_size and micro_batch_size is not None:
         # TODO: add rampup_batch_size here when we have it implemented
-        from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+        from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
 
         if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
-            setup_microbatch_calculator(
+            init_num_microbatches_calculator(
                 rank=global_rank,
                 global_batch_size=global_batch_size,
                 micro_batch_size=micro_batch_size,
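ConstantNumMicroBatchesCalculator is imported above, but its use falls outside the visible hunks. A plausible guard for the branch where a calculator already exists, offered as an assumption rather than the verbatim continuation of this diff:

    import megatron.core.num_microbatches_calculator as mb_calc
    from megatron.core.num_microbatches_calculator import ConstantNumMicroBatchesCalculator

    def check_existing_calculator(
        global_batch_size: int, micro_batch_size: int, data_parallel_size: int
    ) -> None:
        calc = mb_calc._GLOBAL_NUM_MICROBATCHES_CALCULATOR
        if isinstance(calc, ConstantNumMicroBatchesCalculator):
            # A constant calculator never changes its counts, so the requested
            # sizes must match what it was originally built with.
            assert calc.get() == global_batch_size // (micro_batch_size * data_parallel_size)
            assert calc.get_current_global_batch_size() == global_batch_size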