
Commit

Update
[ghstack-poisoned]
wconstab committed May 24, 2024
1 parent 2b06122 commit a1b31fa
Showing 1 changed file with 3 additions and 3 deletions: torchtitan/parallelisms/parallelize_llama.py
@@ -161,7 +161,7 @@ def _llama_trace_input(job_config, model_config, device="meta"):


 def _mixed_precision_dtype(
-    job_config: JobConfig, default: torch.dtype = torch.float32
+    job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
     """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
@@ -213,7 +213,7 @@ def pipeline_llama_manual(
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
-    mp_dtype = _mixed_precision_dtype(job_config)
+    mp_dtype = _mixed_precision_dtype(job_config, parallel_dims)
     batch_size = job_config.training.batch_size
     local_seq_len = int(job_config.training.seq_len // parallel_dims.tp)
     layers_io_shape = (batch_size, local_seq_len, model_config.dim)
@@ -260,7 +260,7 @@ def pipeline_llama_tracer(
             "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
         )

-    if _mixed_precision_dtype(job_config) == torch.bfloat16:
+    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
         raise NotImplementedError(
             "pipeline tracer doesn't work with fsdp mixed precision currently. "
             "To work around, edit fsdp mixed precision config to use fp32."
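
Context: the diff only shows the signature change, not how `_mixed_precision_dtype` uses the new `parallel_dims` argument. Below is a minimal sketch of what the updated helper might look like, assuming (per its docstring) that the mixed precision dtype should only apply when data parallelism / FSDP is enabled; the `dp_enabled` attribute and `TORCH_DTYPE_MAP` lookup are assumptions for illustration, not necessarily the repository's actual names.

import torch

# Hypothetical mapping from config strings to torch dtypes; torchtitan may
# define an equivalent mapping under a different name.
TORCH_DTYPE_MAP = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _mixed_precision_dtype(
    job_config, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default."""
    mp_arg = job_config.training.mixed_precision_param
    # Only honor the configured mixed precision param when data parallelism
    # (FSDP) is actually enabled; `dp_enabled` is an assumed flag on parallel_dims.
    return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default

This would explain why the call sites in pipeline_llama_manual and pipeline_llama_tracer now pass parallel_dims through: the helper needs it to decide whether mixed precision is in effect at all.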
