address TODOs as 2D recompiles is fixed #508

Merged · 2 commits · Aug 7, 2024
16 changes: 5 additions & 11 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -185,9 +185,6 @@ def apply_tp(
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        # TODO: remove cache_size_limit adjustment after 2D compile is fixed
-        torch._dynamo.config.cache_size_limit = 10000
-
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)

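(Not part of this diff) A minimal sketch of reinstating the removed workaround locally, should per-TransformerBlock compilation ever hit Dynamo's recompile limit again; torch._dynamo.config.cache_size_limit is the same knob deleted above, and 10000 matches the removed value:

import torch._dynamo

# Raise Dynamo's recompile budget; 10000 is the value used by the line removed above.
torch._dynamo.config.cache_size_limit = 10000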
@@ -280,18 +277,15 @@ def apply_ac(model: nn.Module, ac_config):


 def apply_compile(model: nn.Module):
-    """Apply torch.compile to each transformer block."""
-
-    # the following flag can be used to to accelarate per-TransformerBlock compilation
-    # TODO(bdhirsh): turning it off because it's currently not working with 2D
-    # TODO(anijain): remove it after it's enabled in pytorch by default
-    # torch._dynamo.config.inline_inbuilt_nn_modules = True
-
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiled each TransformerBlock with torch.compile")
+    logger.info("Compiling each TransformerBlock with torch.compile")
     return model


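(Not part of this diff) A minimal runnable sketch of the alternative mentioned in the new docstring, compiling the whole model once instead of per TransformerBlock; the toy nn.Sequential is a stand-in assumption, not torchtitan code:

import torch
import torch.nn as nn

# Hypothetical toy model standing in for the Transformer; in torchtitan this
# would be the full model after data-parallel wrapping has been applied.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Whole-model compile, as opposed to the per-TransformerBlock loop above.
model = torch.compile(model)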