Commit
[bug] fix early return (#5740)
* [bug] fix silly bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] add test for prefetch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
botbw and pre-commit-ci[bot] authored May 21, 2024
1 parent 83716e9 commit 13c06d3
Showing 9 changed files with 50 additions and 19 deletions.
5 changes: 3 additions & 2 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -361,10 +361,11 @@ def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
+        maybe_work = None
         if not self.is_gathered:
-            return self.__gather(async_op=async_access)
+            maybe_work = self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
-        return None
+        return maybe_work
 
     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
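For context, the "early return" named in the commit title is visible in this hunk: when a chunk was not yet gathered, the old code returned the handle from `__gather` immediately, so `__update_tensors_ptr()` never ran on that path and the parameter tensors were never redirected into the gathered buffer. The sketch below reproduces the control-flow difference with a simplified stand-in class; `ChunkSketch` and `FakeWork` are illustrative names, not the real ColossalAI types.

# Minimal sketch of the control-flow bug; the real class lives in
# colossalai/zero/gemini/chunk/chunk.py.
from typing import Optional


class FakeWork:
    """Stands in for the torch.distributed.Work handle an async all-gather returns."""

    def wait(self) -> None:
        pass


class ChunkSketch:
    def __init__(self, is_gathered: bool) -> None:
        self.is_gathered = is_gathered
        self.ptrs_updated = False

    def _gather(self, async_op: bool) -> Optional[FakeWork]:
        self.is_gathered = True
        return FakeWork() if async_op else None

    def _update_tensors_ptr(self) -> None:
        self.ptrs_updated = True

    def access_chunk_buggy(self, async_access: bool = False) -> Optional[FakeWork]:
        if not self.is_gathered:
            # BUG: returning here skips _update_tensors_ptr() below.
            return self._gather(async_op=async_access)
        self._update_tensors_ptr()
        return None

    def access_chunk_fixed(self, async_access: bool = False) -> Optional[FakeWork]:
        maybe_work = None
        if not self.is_gathered:
            maybe_work = self._gather(async_op=async_access)
        self._update_tensors_ptr()  # now runs on every access path
        return maybe_work


buggy = ChunkSketch(is_gathered=False)
buggy.access_chunk_buggy()
print(buggy.ptrs_updated)  # False: pointers were never redirected

fixed = ChunkSketch(is_gathered=False)
fixed.access_chunk_fixed()
print(fixed.ptrs_updated)  # True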
4 changes: 0 additions & 4 deletions colossalai/zero/gemini/gemini_hook.py
@@ -5,7 +5,6 @@
 
 import torch
 
-from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -17,9 +16,6 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
-logger = DistributedLogger("gemini_hook")
-
-
 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
8 changes: 4 additions & 4 deletions colossalai/zero/gemini/gemini_mgr.py
@@ -177,6 +177,10 @@ def cuda_margin_mem(self) -> Optional[float]:
             return self._mem_stats_collector.cuda_margin_mem
         return None
 
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def compute_list(self) -> List[Tuple[Chunk, ...]]:
         return self._compute_list
@@ -189,10 +193,6 @@ def compute_idx(self) -> int:
     def async_works(self) -> Dict[Chunk, dist.Work]:
         return self._async_works
 
-    @property
-    def placement_policy(self) -> PlacementPolicy:
-        return self._placement_policy
-
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
10 changes: 9 additions & 1 deletion tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -40,12 +40,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("max_prefetch", [0, 1, 4])
 def exam_gpt_fwd_bwd(
     placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    max_prefetch: int = 0,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -69,7 +71,13 @@ def exam_gpt_fwd_bwd(
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gather
     model = GeminiDDP(
-        model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights
+        model,
+        config_dict,
+        init_device,
+        pin_memory=True,
+        **placement_config,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
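The new `max_prefetch` values exercise Gemini's chunk-prefetch path ("[chore] add test for prefetch" in the commit message): with `max_prefetch > 0`, `access_chunk(async_access=True)` may hand back a `dist.Work` handle that is recorded (see the `async_works` property in gemini_mgr.py above) and waited on before compute touches the chunk. Below is a hedged sketch of that consume pattern, assuming only the `access_chunk` signature shown in the chunk.py hunk; the helper name `prefetch_then_wait` is hypothetical, and the real bookkeeping lives inside GeminiManager.

# Hypothetical sketch of the async-access pattern that max_prefetch enables;
# not the actual GeminiDDP wiring.
from typing import Dict


def prefetch_then_wait(chunk, async_works: Dict) -> None:
    # Kick off an asynchronous gather; a non-None handle means the
    # all-gather is still in flight.
    work = chunk.access_chunk(async_access=True)
    if work is not None:
        async_works[chunk] = work

    # ... later, immediately before compute uses the chunk ...
    pending = async_works.pop(chunk, None)
    if pending is not None:
        pending.wait()  # block until the gather has completed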
9 changes: 8 additions & 1 deletion tests/test_zero/test_gemini/test_grad_accum.py
@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("max_prefetch", [0, 1, 4])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    max_prefetch: int,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,6 +87,7 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
4 changes: 3 additions & 1 deletion tests/test_zero/test_gemini/test_grad_clip.py
@@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, max_prefetch: int):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -84,6 +85,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         **placement_config,
     )
 
12 changes: 10 additions & 2 deletions tests/test_zero/test_gemini/test_optim.py
@@ -71,7 +71,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
 @parameterize("model_name", TEST_MODELS)
 @parameterize("mixed_precision", [torch.half, torch.bfloat16])
 @parameterize("master_weights", [True, False])
-def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_model_step(
+    placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, max_prefetch: int
+):
     set_seed(42)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -94,7 +97,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = False
     model = GeminiDDP(
-        model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
+        model,
+        config_dict,
+        **placement_config,
+        mixed_precision=mixed_precision,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
     )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
12 changes: 10 additions & 2 deletions tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -28,7 +28,8 @@ def ignore_the_first_parameter(model: torch.nn.Module):
 @parameterize("keep_gathered", [True, False])
 @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
 @parameterize("master_weights", [False, True])
-def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int):
     set_seed(431)
     model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
 
@@ -44,7 +45,14 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
+    model = GeminiDDP(
+        model,
+        config_dict,
+        **placement_config,
+        pin_memory=True,
+        master_weights=master_weights,
+        max_prefetch=max_prefetch,
+    )
     model.train()
 
     zero_dict = model.state_dict(only_rank_0=False)
5 changes: 3 additions & 2 deletions tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -20,7 +20,8 @@
 
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [True, False])
-def exam_zero_optim_state_dict(placement_config, keep_gathered):
+@parameterize("max_prefetch", [0, 1, 4])
+def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch):
     set_seed(431)
     model_builder, data_gen_fn, output_transform_fn, *_ = next(
         iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
@@ -35,7 +36,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
     config_dict[world_size]["chunk_size"] = 5000
     config_dict[world_size]["keep_gathered"] = keep_gathered
 
-    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch)
 
     optimizer = HybridAdam(model.parameters())
     optim = GeminiOptimizer(optimizer, model, initial_scale=32)  # initialize the link between chunk16 and chunk32
