Skip to content

Commit de9fd2b

Browse files
committed
some compile-related updates
ghstack-source-id: 63af8025c184fd5ad34f2f57bf78a37dda2cd33d Pull Request resolved: #443
1 parent 72a1614 commit de9fd2b

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

test_runner.py

+11
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,17 @@ def build_test_list():
168168
"1D compile",
169169
"1d_compile",
170170
),
171+
OverrideDefinitions(
172+
[
173+
[
174+
"--training.compile",
175+
"--activation_checkpoint.mode selective",
176+
"--activation_checkpoint.selective_ac_option op",
177+
],
178+
],
179+
"1D compile with selective op AC",
180+
"1d_compile_sac_op",
181+
),
171182
OverrideDefinitions(
172183
[
173184
[

torchtitan/parallelisms/parallelize_llama.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,15 @@ def apply_ac(model: nn.Module, ac_config: JobConfig):
441441

442442
def apply_compile(model: nn.Module):
443443
"""Apply torch.compile to each transformer block."""
444+
445+
# the following flag can be used to accelerate per-block compilation
446+
# TODO(bdhirsh): turning it off because it's currently not working with 2D
447+
# TODO(anijain): remove it after it's enabled in pytorch by default
448+
# torch._dynamo.config.inline_inbuilt_nn_modules = True
449+
444450
for layer_id, transformer_block in model.layers.named_children():
445-
# TODO: dynamic shape have some issues so we turn it off for now.
446-
# TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
447-
# compile time.
448-
# torch._dynamo.config.inline_inbuilt_nn_modules = True
449-
transformer_block = torch.compile(transformer_block, dynamic=False)
451+
# turn on per-transformer block compile after AC wrapping and before FSDP
452+
transformer_block = torch.compile(transformer_block, fullgraph=True)
450453
model.layers.register_module(layer_id, transformer_block)
451454

452455
logger.info("Compiled each TransformerBlock with torch.compile")

0 commit comments

Comments
 (0)