From fa93a0564a6d63c13a61e21a403dfb2eea790f74 Mon Sep 17 00:00:00 2001 From: Helen Ngo Date: Thu, 9 Jan 2025 14:50:03 -0800 Subject: [PATCH] ADLR/megatron-lm!2429 - Inference CUDA graphs (MCore version) Co-authored-by: Jimmy Zhang --- .../core/extensions/transformer_engine.py | 4 ++ megatron/core/inference_params.py | 42 ++++++++++++++++++- megatron/core/models/gpt/gpt_model.py | 20 +++++++-- megatron/core/tensor_parallel/random.py | 42 +++++++++++++++---- megatron/core/transformer/attention.py | 23 ++++++---- megatron/core/transformer/cuda_graphs.py | 18 +++++--- .../core/transformer/transformer_block.py | 6 +++ .../core/transformer/transformer_config.py | 12 ++++++ .../core/transformer/transformer_layer.py | 20 ++++++--- .../inference/text_generation/generation.py | 20 +++++++-- megatron/training/arguments.py | 7 ++++ megatron/training/initialize.py | 6 +-- 12 files changed, 184 insertions(+), 36 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 5884109cae..e74f47803e 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1135,6 +1135,10 @@ class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker): """Wraps TransformerEngine's CudaRNGStatesTracker so that it is interchangeable with Megatron's RNG tracker""" + def __init__(self): + super().__init__() + self.reset() + def is_initialized(self): """Checks if the internal RNG state has been set wirth set_states().""" return self._is_initialized diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py index 0db49e3115..c8b05d1b57 100644 --- a/megatron/core/inference_params.py +++ b/megatron/core/inference_params.py @@ -6,6 +6,7 @@ class InferenceParams: def __init__(self, max_batch_size, max_sequence_length): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size + self.current_batch_size = max_batch_size # Required for bookkeeping variable-sized batches self.sequence_len_offset = 0 self.batch_size_offset = 0 self.key_value_memory_dict = {} @@ -28,4 +29,43 @@ def swap_key_value_dict(self, batch_idx): ) def __str__(self): - return f"InferenceParams(max_seq_len = {self.max_sequence_length}, max_batch_size = {self.max_batch_size}, sequence_len_offset = {self.sequence_len_offset}, batch_size_offset = {self.batch_size_offset}, key_value_memory_dict = {self.key_value_memory_dict.keys()})" + return ( + f"InferenceParams(max_seq_len = {self.max_sequence_length}, " + f"max_batch_size = {self.max_batch_size}, " + f"current_batch_size = {self.current_batch_size}, " + f"sequence_len_offset = {self.sequence_len_offset}, " + f"batch_size_offset = {self.batch_size_offset}, " + f"key_value_memory_dict = {self.key_value_memory_dict.keys()})" + ) + + def __eq__(self, other): + + if not isinstance(other, InferenceParams): + return False + + # Check all attributes match + basic_attrs = [ + 'max_sequence_length', + 'max_batch_size', + 'current_batch_size', + 'sequence_len_offset', + 'batch_size_offset', + ] + + if not all(hasattr(other, attr) for attr in basic_attrs): + return False + + # Check dictionary keys match; i.e. 
the same number of layers are cached + if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys(): + return False + + # Check each tensor tuple in the dictionary + for key in self.key_value_memory_dict: + self_tensors = self.key_value_memory_dict[key] + other_tensors = other.key_value_memory_dict[key] + + # Compare each key, value tensor in the tuple + for self_tensor, other_tensor in zip(self_tensors, other_tensors): + if not self_tensor.shape == other_tensor.shape: + return False + return True diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index be8cdce111..62a897f3ad 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -3,6 +3,7 @@ from collections import OrderedDict from typing import Dict, Literal, Optional +import torch from torch import Tensor from megatron.core import InferenceParams, tensor_parallel @@ -121,6 +122,9 @@ def __init__( use_cpu_initialization=self.config.use_cpu_initialization, ) + # Cache for RoPE tensors which do not change between iterations. + self.rotary_pos_emb_cache = {} + # Transformer. self.decoder = TransformerBlock( config=self.config, @@ -224,10 +228,11 @@ def forward( rotary_pos_cos = None rotary_pos_sin = None if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention: - if not self.training and self.config.flash_decode: + if not self.training and self.config.flash_decode and inference_params: # Flash decoding uses precomputed cos and sin for RoPE - rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin( - inference_params.max_sequence_length + rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( + inference_params.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), ) else: rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( @@ -238,6 +243,14 @@ def forward( packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + if (self.config.enable_cuda_graph or self.config.flash_decode) and inference_params: + sequence_len_offset = torch.tensor( + [inference_params.sequence_len_offset] * inference_params.current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None # Run decoder. hidden_states = self.decoder( @@ -248,6 +261,7 @@ def forward( rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, **(extra_block_kwargs or {}), ) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index f3d4ab772f..31bec68aa7 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -164,35 +164,59 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): _CUDA_RNG_STATE_TRACKER_INITIALIZED = False -def initialize_rng_tracker(use_te_rng_tracker: bool = False): +def initialize_rng_tracker(use_te_rng_tracker: bool = False, inference_rng_tracker: bool = False): """Create the RNG tracker. 'use_te_rng_tracker' determines whether to use Megatron or TransformerEngine's implementation. In particular, TransformerEngine's implementation is cudagraphable and supports FP8. 
""" - global _CUDA_RNG_STATE_TRACKER global _CUDA_RNG_STATE_TRACKER_INITIALIZED if _CUDA_RNG_STATE_TRACKER_INITIALIZED: return + # Get the base tracker class + base_tracker = None if use_te_rng_tracker: if not is_te_min_version("1.5.0"): raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5") from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker - _CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker() + base_tracker = TECudaRNGStatesTracker else: - _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + base_tracker = CudaRNGStatesTracker + + if inference_rng_tracker: + + class InferenceCudaRNGStatesTracker(base_tracker): + """RNG tracker for inference.""" + + def add(self, name, seed): + """Mirrors the interface from the training RNG tracker.""" + pass + + def set_states(self, states): + """Mirrors the interface from the training RNG tracker.""" + pass + + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Mirrors the interface from the training RNG tracker.""" + return contextlib.nullcontext() + + tracker_class = InferenceCudaRNGStatesTracker + else: + tracker_class = base_tracker + + _CUDA_RNG_STATE_TRACKER = tracker_class() _CUDA_RNG_STATE_TRACKER_INITIALIZED = True -def get_cuda_rng_tracker(use_te_rng_tracker=False): +def get_cuda_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=False): """Get cuda rng tracker.""" - initialize_rng_tracker(use_te_rng_tracker) + initialize_rng_tracker(use_te_rng_tracker, inference_rng_tracker) return _CUDA_RNG_STATE_TRACKER -def model_parallel_cuda_manual_seed(seed): +def model_parallel_cuda_manual_seed(seed, te_rng_tracker=False, inference_rng_tracker=False): """Initialize model parallel cuda seed. This function should be called after the model parallel is @@ -216,7 +240,7 @@ def model_parallel_cuda_manual_seed(seed): # Data parallel gets the original seed. data_parallel_seed = seed - initialize_rng_tracker() + initialize_rng_tracker(te_rng_tracker, inference_rng_tracker) _CUDA_RNG_STATE_TRACKER.reset() # Set the default state. 
torch.cuda.manual_seed(data_parallel_seed) @@ -241,6 +265,7 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, distribute_saved_activations, *args): + """Forward pass.""" ctx.run_function = run_function ctx.distribute_saved_activations = distribute_saved_activations @@ -267,6 +292,7 @@ def forward(ctx, run_function, distribute_saved_activations, *args): @staticmethod def backward(ctx, *args): + """Backward pass.""" if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "Checkpointing is not compatible with .grad(), " diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 713a6887d9..f62bf0d4e8 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -169,14 +169,14 @@ def custom_forward(*inputs): return hidden_states - def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): + def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype): """Allocate memory to store kv cache during inference.""" return torch.empty( inference_max_sequence_length, batch_size, self.num_query_groups_per_partition, - dim, + self.hidden_size_per_attention_head, dtype=dtype, device=torch.cuda.current_device(), ) @@ -190,6 +190,7 @@ def _adjust_key_value_for_inference( rotary_pos_emb: Tensor, rotary_pos_cos: Tensor = None, rotary_pos_sin: Tensor = None, + sequence_len_offset=None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ Saves the generated key and value tensors to the end of the buffers in inference_params. @@ -210,10 +211,10 @@ def _adjust_key_value_for_inference( inf_max_seq_length = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype + inf_max_seq_length, inf_max_batch_size, key.dtype ) inference_value_memory = self._allocate_memory( - inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype + inf_max_seq_length, inf_max_batch_size, value.dtype ) inference_params.key_value_memory_dict[self.layer_number] = ( inference_key_memory, @@ -246,7 +247,7 @@ def _adjust_key_value_for_inference( rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end] rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end] rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end] - else: + else: # Prefill rotary_pos_cos_q = rotary_pos_cos[:sequence_end] rotary_pos_sin_q = rotary_pos_sin[:sequence_end] rotary_pos_cos_k = rotary_pos_cos[:sequence_end] @@ -341,6 +342,7 @@ def forward( rotary_pos_sin=None, attention_bias=None, packed_seq_params=None, + sequence_len_offset=None, ): """ Perform a forward pass through the attention module. 
@@ -380,7 +382,7 @@ def forward( self.layer_number ] output = self.flash_decoding( - sequence_len_offset=inference_params.sequence_len_offset, + sequence_len_offset=sequence_len_offset, query_layer=query, key_layer=key, value_layer=value, @@ -395,7 +397,14 @@ def forward( return output, bias query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin + inference_params, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, ) if packed_seq_params is not None: diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 20257abc28..2b477d2cd2 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -31,6 +31,8 @@ _IS_GRAPH_CAPTURING = False +logger = logging.getLogger(__name__) + def is_graph_capturing(): """Query if currently capturing.""" @@ -53,7 +55,9 @@ def _set_capture_end(): def _check_supported_type(arg): """Check if arg is a supported type for cudagraph input/outputs.""" - _SUPPORTED_TYPES = {torch.Tensor, type(None), bool, int, str, float} + from megatron.core import InferenceParams # guard against circular import + + _SUPPORTED_TYPES = {torch.Tensor, type(None), bool, int, str, float, InferenceParams} assert type(arg) in _SUPPORTED_TYPES or is_dataclass( arg ), f"Cudagraphs recieved an arg of type {type(arg)} which is not supported." @@ -153,7 +157,6 @@ def create_cudagraphs(cls): vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() if vpp_size is None or vpp_size == 1: bwd_mempool = fwd_mempool - if optimize_transformer_layer_graph_buffers: if graph_type == 'fwd': args, kwargs = g[3:] @@ -550,6 +553,7 @@ def record_graph_capture(self, args, kwargs): 'create_cudagraphs()` is called. Subsequent fwd passes will replay the cudagraph. """ if not self.fwd_graph_recorded: + logger.debug(f"Recording forward graph creation...") _CudagraphGlobalRecord.record_fwd_graph(self, args, kwargs) self.fwd_graph_recorded = True @@ -598,10 +602,11 @@ def forward(self, is_first_microbatch, args, kwargs): return out def matches_graph_inputs(self, args, kwargs): - """Check the the passed args, kwargs match with the arg, kwargs + """Check that the passed args, kwargs match with the arg, kwargs the graph was created with.""" def check(val, ref): + _check_supported_type(val) _check_supported_type(ref) @@ -758,9 +763,12 @@ def __call__(self, megatron_module, args, kwargs): if self.training and torch.is_grad_enabled(): runner = _CudaGraphRunner(megatron_module, len(self.cudagraph_runners)) self.cudagraph_runners.append(runner) + elif 'inference_params' in kwargs.keys() and kwargs['inference_params']: + # Instantiate the cudagraphed version of the module in inference mode + runner = _CudaGraphRunner(megatron_module, len(self.cudagraph_runners)) + runner.eval() + self.cudagraph_runners.append(runner) else: - # No cudagraphs were found in inference mode, so fallback to eager since - # tensor.requires_grad is needed to correctly trace the backward graph. 
return super(MegatronModule, megatron_module).__call__(*args, **kwargs) # Trigger Mcore DDP pre-forward hooks diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index d40476d27b..3114b859f3 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -404,6 +404,7 @@ def forward( attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + sequence_len_offset: Tensor = None, ): """ Perform the forward pass through the transformer block. @@ -436,6 +437,10 @@ def forward( # See set_input_tensor() hidden_states = self.input_tensor + # Update the inference parameters with the current batch size in case it is variable + if inference_params and not self.training: + inference_params.current_batch_size = hidden_states.size(1) + # Viewless tensor. # - We only need to create a viewless tensor in the case of micro batch # size (mbs) == 1, since in this case, 'hidden_states.transpose()' @@ -512,6 +517,7 @@ def forward( attention_bias=attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, ) else: # CUDA graph replay for layer `l_no` and microbatch diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 777cc1e993..e48022ee5f 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -385,6 +385,12 @@ class TransformerConfig(ModelParallelConfig): flash_decode: bool = False """ Use the optimized flash decoding kernel during inference. """ + use_te_rng_tracker: bool = False + """ Whether to use the TE or MCore version of the RNG tracker. """ + + inference_rng_tracker: bool = False + """ Whether we should instantiate a separate RNG tracker for inference. """ + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -605,6 +611,12 @@ def __post_init__(self): if self.flash_decode and self.fp8: raise ValueError("FP8 inference is currently not support with flash decoding.") + if self.enable_cuda_graph: + if self.cpu_offloading: + raise ValueError("CUDA graphs not supported with CPU offloading.") + if self.recompute_granularity: + raise ValueError("CUDA graphs not supported with activation recomputation.") + if self.moe_token_dispatcher_type in ['allgather', 'alltoall_seq']: if self.variable_seq_lengths is True: raise ValueError( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 0e7eabbff5..5b96f91cdb 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -93,10 +93,12 @@ def __init__( ): super().__init__(config=config) - if config.enable_cuda_graph and self.training: - assert ( - not config.cpu_offloading and config.recompute_granularity is None - ), "Cudagraphs not supported" + if config.enable_cuda_graph: + if not self.training: + # Cudagraphs for inference are only enabled with the flash decoding kernel + assert ( + self.config.flash_decode + ), "--flash-decode is required to use CUDA graphs during inference" self.cudagraph_manager = CudaGraphManager() self.submodules_config = submodules @@ -263,6 +265,7 @@ def forward( attention_bias=None, inference_params=None, packed_seq_params=None, + sequence_len_offset=None, ): """ Perform a forward pass through the transformer layer. @@ -304,6 +307,7 @@ def forward( rotary_pos_sin=rotary_pos_sin, attention_bias=attention_bias, packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, ) # TODO: could we move `bias_dropout_add_exec_handler` itself @@ -392,6 +396,12 @@ def sharded_state_dict( return sharded_state_dict def __call__(self, *args, **kwargs): - if hasattr(self, 'cudagraph_manager'): + # Training and validation mode CUDA graphs + if hasattr(self, 'cudagraph_manager') and kwargs.get('inference_params') is None: + return self.cudagraph_manager(self, args, kwargs) + # Inference mode. CUDA graphs are used in the decode phase only, when attn mask is None + elif not self.training and ( + hasattr(self, 'cudagraph_manager') and kwargs['attention_mask'] is None + ): return self.cudagraph_manager(self, args, kwargs) return super(MegatronModule, self).__call__(*args, **kwargs) diff --git a/megatron/inference/text_generation/generation.py b/megatron/inference/text_generation/generation.py index 13e53b3c6a..9007402899 100644 --- a/megatron/inference/text_generation/generation.py +++ b/megatron/inference/text_generation/generation.py @@ -8,6 +8,7 @@ from megatron.training import get_args, get_tokenizer from megatron.core import mpu from megatron.training.utils import get_ltor_masks_and_position_ids +from megatron.core.transformer.cuda_graphs import create_cudagraphs from .communication import ( copy_from_last_to_first_pipeline_stage, broadcast_from_last_pipeline_stage, @@ -202,7 +203,7 @@ def generate_tokens_probs_and_return_on_first_stage( device=torch.cuda.current_device()) # ============= - # Run infernece + # Run inference # ============= with torch.no_grad(): @@ -211,15 +212,22 @@ def generate_tokens_probs_and_return_on_first_stage( prev_context_length = 0 for context_length in range(min_prompt_length, max_sequence_length): + prefill = context_length == min_prompt_length + # Pick the slice that we need to pass through the network. 
             tokens2use = tokens[:, prev_context_length:context_length]
             positions2use = position_ids[:, prev_context_length:context_length]
+
+            # Do not pass a variable-shape attention mask in the decode phase.
             attention_mask2use = attention_mask[
-                ..., prev_context_length:context_length, :context_length]
+                ..., prev_context_length:context_length, :context_length] if prefill else None

             # logits will be meanigful only in the last pipeline stage.
             logits = forward_step(tokens2use, positions2use, attention_mask2use)

+            if args.enable_cuda_graph:
+                create_cudagraphs()
+
             if mpu.is_pipeline_last_stage():
                 if prevent_newline_after_colon:
                     logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10 # disable "\n" after ":"
@@ -343,7 +351,7 @@ def beam_search_and_return_on_first_stage(model, forward_step, tokens, lengths,
                                               device=torch.cuda.current_device()).unsqueeze(1)
     scores_size_tensor, tokens_size_tensor = None, None
     # =============
-    # Run infernece
+    # Run inference
     # =============
     with torch.no_grad():
         tokens = tokens.repeat(beam_size, 1)
@@ -351,11 +359,15 @@
         prev_context_length = 0
         for context_length in range(prompt_length, final_sequence_length):
+            prefill = context_length == prompt_length
+
             # Pick the slice that we need to pass through the network.
             tokens2use = tokens[:, prev_context_length:context_length]
             positions2use = position_ids[:, prev_context_length:context_length]
+
+            # Do not pass a variable-shape attention mask in the decode phase.
             attention_mask2use = attention_mask[
-                ..., prev_context_length:context_length, :context_length]
+                ..., prev_context_length:context_length, :context_length] if prefill else None

             # logits will be meanigful only in the last pipeline stage.
             logits = forward_step(tokens2use, positions2use, attention_mask2use)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index ff430957d1..d7a3e52a32 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -855,6 +855,11 @@ def _add_transformer_engine_args(parser):
     group.add_argument('--fp8-param-gather', action='store_true',
                        help='Keep the compute param in fp8 (do not use any other intermediate '
                             'dtype) and perform the param all-gather in fp8.')
+    group.add_argument('--te-rng-tracker', action='store_true', default=False,
+                       help='Use the Transformer Engine version of the random number generator. '
+                            'Required for CUDA graphs support.')
+    group.add_argument('--inference-rng-tracker', action='store_true', default=False,
+                       help='Use a random number generator configured for inference.')
     return parser

 def _add_inference_args(parser):
@@ -881,6 +886,8 @@ def _add_inference_args(parser):
                             'Bert embedder.')
     group.add_argument('--flash-decode', default=False, action="store_true",
                        help='Whether to use the flash decoding kernel.')
+    group.add_argument('--enable-cuda-graph', default=False, action="store_true",
+                       help='Use CUDA graph capture and replay.')
     group.add_argument('--inference-max-seq-length', type=int, default=2560,
                        help='Maximum sequence length allocated for prefill during inference.',
                        dest='inference_max_seq_length')
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index cb05731977..f432ceac23 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -106,7 +106,7 @@ def finish_mpu_init():

     # Random seeds for reproducibility.
if args.rank == 0: print("> setting random seeds to {} ...".format(args.seed)) - _set_random_seed(args.seed, args.data_parallel_random_init) + _set_random_seed(args.seed, args.data_parallel_random_init, args.te_rng_tracker, args.inference_rng_tracker) if skip_mpu_initialization: return None @@ -336,7 +336,7 @@ def _init_autoresume(): torch.distributed.barrier() -def _set_random_seed(seed_, data_parallel_random_init=False): +def _set_random_seed(seed_, data_parallel_random_init=False, te_rng_tracker=False, inference_rng_tracker=False): """Set random seed for reproducability.""" if seed_ is not None and seed_ > 0: # Ensure that different pipeline MP stages get different seeds. @@ -348,7 +348,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False): np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.device_count() > 0: - tensor_parallel.model_parallel_cuda_manual_seed(seed) + tensor_parallel.model_parallel_cuda_manual_seed(seed, te_rng_tracker, inference_rng_tracker) else: raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
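
Usage note (illustrative, not prescriptive): inference CUDA graphs are gated on the flash decoding path, so the flags introduced above are expected to be passed together, e.g. --flash-decode --enable-cuda-graph --te-rng-tracker. The new seeding path can also be driven directly from Python; a minimal sketch, assuming megatron.core parallel state has already been initialized (the function and keyword names come from this patch, the seed value is arbitrary):

    from megatron.core import tensor_parallel

    # te_rng_tracker=True selects TransformerEngine's cudagraphable tracker;
    # inference_rng_tracker=True wraps it in the inference subclass whose
    # add()/set_states() are no-ops and whose fork() returns a nullcontext,
    # so decode steps carry no RNG bookkeeping that would break graph replay.
    tensor_parallel.model_parallel_cuda_manual_seed(
        1234, te_rng_tracker=True, inference_rng_tracker=True
    )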