add selective activation checkpointing #97

Merged
merged 4 commits on Feb 29, 2024
Changes from 2 commits
15 changes: 10 additions & 5 deletions torchtrain/config_manager.py
@@ -94,17 +94,17 @@ def init_args_from_command_line(
help="collect profiler traces every x iterations",
)
# metrics configs
parser.add_argument(
"--metrics.enable_tensorboard",
action="store_true",
help="whether to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.enable_tensorboard",
action="store_true",
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.save_tb_folder",
type=str,
@@ -215,4 +215,9 @@ def init_args_from_command_line(
"is an empty string, checkpointing is disabled."
),
)
parser.add_argument(
"--training.enable_selective_ac",
action="store_false",
help="whether to enable selective activation checkpointing",
)
return parser.parse_args(args_list)
59 changes: 51 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -5,6 +5,7 @@
# llama model, i.e. activation checkpoint, etc.

import logging
from collections import defaultdict

import torch
from torch.distributed._tensor import (
@@ -33,7 +34,6 @@
RowwiseParallel,
)
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log

logger = logging.getLogger(__name__)
@@ -67,12 +67,52 @@ def partition_fn(name, module, device_mesh):
)


# for selective AC
no_recompute_list = {
Contributor:
Not vital for now, but longer term I think this policy should be exposed at a higher level if we expect other models to be added and/or expect people to customize this policy list, i.e. if we have parallelize_gpt or similar, then it's awkward to pull the recompute list from parallelize_llama.

torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.c10d_functional.reduce_scatter_tensor.default,
}
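
A hypothetical sketch of the reviewer's suggestion above -- letting callers supply their own op set instead of hard-coding it in parallelize_llama. The names below are illustrative only and are not part of this PR:

# Hypothetical: a shared default policy that callers (e.g. a future
# parallelize_gpt) could override with their own no-recompute op set.
_DEFAULT_NO_RECOMPUTE = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.c10d_functional.reduce_scatter_tensor.default,
}

def checkpoint_wrapper(module, enable_selective_ac=False, no_recompute_list=None):
    # fall back to the shared default policy when no custom list is given
    if no_recompute_list is None:
        no_recompute_list = _DEFAULT_NO_RECOMPUTE
    ...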

# Uses PTD FSDP AC wrapper
# TODO: why is config needed here?
def checkpoint_wrapper(module, job_config: JobConfig):
return ptd_checkpoint_wrapper(
module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
)
def checkpoint_wrapper(module, enable_selective_ac=False):
Contributor:
super nit: Prefer to not make enable_selective_ac a default arg if we always pass it explicitly.

if enable_selective_ac:
from torch.utils.checkpoint import (
_pt2_selective_checkpoint_context_fn_gen,
checkpoint,
)

def _get_custom_policy(meta):
Contributor:
For my understanding, does this policy also run in eager mode?

Contributor Author:
Actually it only works in eager mode; with compiler we got:

[rank0]:[rank0]:     assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: AssertionError: Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile

Contributor Author:
NVM. Wanchao told me we need to set torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True for it to work. Now SAC works with compile.
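
For reference, a minimal sketch of what the author describes above: the experimental dynamo flag has to be flipped before compiling so that a checkpoint context_fn is accepted under torch.compile. The flag name is taken from the discussion; being experimental, it may change across PyTorch versions.

import torch

# experimental flag named in the comment above: allows passing a context_fn
# to torch.utils.checkpoint while the model runs under torch.compile
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True

# after this, torch.compile(model) can be applied to the SAC-wrapped model as usual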

def _custom_policy(mode, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
return func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)

return _custom_policy

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
context_fn=selective_checkpointing_context_fn,
use_reentrant=False,
preserve_rng_state=False,
)
else:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
preserve_rng_state=False,
)


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
@@ -168,10 +208,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to each layer
# before wrapping with FSDP, we need to make sure the layer is on GPU
transformer_block = transformer_block.cuda()
transformer_block = checkpoint_wrapper(transformer_block, job_config)

# apply selective AC
transformer_block = checkpoint_wrapper(
transformer_block, job_config.training.enable_selective_ac
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
enable_selective_ac = false
Contributor:
Hmm, you are putting it in the training section, but the cmd arg parser does not have it under training, so I think this will not work as expected. (If it does work, then we need to figure out why the TOML parsing behaves that way.)

Contributor Author:
In our setting, the cmd arg parser is a backup way of providing args. The toml file has it in the training section, and in the code training.enable_selective_ac is used, so no problem there -- the cmd arg metrics.enable_selective_ac is parsed but not used.
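
For context, a minimal excerpt of the layout being described, with the key under the [training] table of the TOML file. Only keys visible in this diff are shown, and the section placement is assumed from the author's comment:

[training]
# ...other training options omitted...
dataset = "alpaca"
enable_selective_ac = false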
