Commit 2f23216

some compile-related improvements
ghstack-source-id: 7c4a65c26a8f573222f0a14448ba8258ed893028
Pull Request resolved: #443
1 parent 3fca883 commit 2f23216

File tree

2 files changed: +13 -12 lines


test_runner.py (+9)

@@ -153,6 +153,15 @@ def build_test_list():
             "1D compile",
             "1d_compile",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.compile --model.norm_type=rmsnorm --selective_ac_option=op",
+                ],
+            ],
+            "1D compile with selective op AC",
+            "1d_compile_sac_op",
+        ),
         OverrideDefinitions(
             [
                 [
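
For reference, a minimal sketch of how an OverrideDefinitions entry like the one added above might be expanded into a training run by the integration-test runner. The base command, config path, and loop below are illustrative assumptions, not the actual test_runner.py implementation:

    # Hypothetical expansion of the new test entry into shell commands (assumed paths).
    base_cmd = "CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh"
    override_steps = [
        ["--training.compile --model.norm_type=rmsnorm --selective_ac_option=op"],
    ]
    for step_args in override_steps:
        cmd = f"{base_cmd} {' '.join(step_args)}"
        print(cmd)  # the runner would execute each step as a separate run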

torchtitan/parallelisms/parallelize_llama.py (+4 -12)

@@ -432,22 +432,14 @@ def apply_compile(model, job_config: JobConfig):
             "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
         )
 
+    # NOTE(anijain): enable the following flag to accelerate compilation
+    torch._dynamo.config.inline_inbuilt_nn_modules = True
+
     for layer_id, transformer_block in model.layers.named_children():
         # turn on per-transformer block compile after AC wrapping and before FSDP
-        # TODO: dynamic shape have some issues so we turn it off for now.
-        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
-        # compile time.
-        # torch._dynamo.config.inline_inbuilt_nn_modules = True
-        transformer_block = torch.compile(transformer_block, dynamic=False)
+        transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    ac_config = job_config.activation_checkpoint
-    if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
-        # some temp flags for torch.compile enablement + SAC
-        torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
-            True
-        )
-
     logger.info("Compiled each TransformerBlock with torch.compile")
     return model
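
Putting the hunk together, here is a minimal sketch of the apply_compile flow after this change. The fused_rmsnorm check, JobConfig handling, and logging are elided, so this is an illustrative condensation rather than the full torchtitan function:

    import torch
    import torch._dynamo

    def apply_compile(model):
        # Inlining built-in nn.Module calls lets Dynamo trace through module
        # boundaries instead of treating them as opaque, which shortens compile
        # time for a stack of structurally identical TransformerBlocks.
        torch._dynamo.config.inline_inbuilt_nn_modules = True

        for layer_id, transformer_block in model.layers.named_children():
            # Compile each block individually, after AC wrapping and before FSDP;
            # fullgraph=True raises if a block introduces a graph break.
            transformer_block = torch.compile(transformer_block, fullgraph=True)
            model.layers.register_module(layer_id, transformer_block)

        return model

Compiling per TransformerBlock keeps each compiled graph small, and fullgraph=True doubles as a regression check that no graph breaks creep into a block.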
