[gemini] fixes for benchmarking #5847

Merged · 8 commits · Jun 26, 2024
4 changes: 2 additions & 2 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -369,9 +369,9 @@ def __init__(
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
-        if placement_policy == "auto" and enable_async_reduce:
+        if enable_async_reduce and not pin_memory:
             logging.warning(
-                f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+                f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
             )
             pin_memory = True
         self.gemini_config = dict(
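With this change the async-reduce path no longer piggybacks on `placement_policy == "auto"`: whenever `enable_async_reduce=True` is requested without pinned memory, the plugin warns and forces `pin_memory=True` itself. A minimal usage sketch, assuming the keyword names shown in this hunk and the usual `Booster` entry point (launch details vary slightly across ColossalAI releases):

```python
# Sketch only: kwargs mirror the names in the diff above; older releases may
# require colossalai.launch_from_torch(config={}) instead.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()

plugin = GeminiPlugin(
    precision="bf16",
    placement_policy="static",
    enable_async_reduce=True,  # reduce-scatter gradient chunks on a side stream
    # pin_memory left unset: the new check logs a warning and flips it to True
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = HybridAdam(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```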
12 changes: 7 additions & 5 deletions colossalai/zero/gemini/chunk/chunk.py
@@ -403,9 +403,9 @@ def reduce(self, async_op: bool = False):
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )

-            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            self.grad_reduce_work = dist.reduce_scatter(
-                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            assert self.cuda_global_chunk.is_contiguous()
+            self.grad_reduce_work = dist.reduce_scatter_tensor(
+                self.cuda_shard, self.cuda_global_chunk, group=self.torch_pg, async_op=async_op
             )

         if self.extra_dp_group is not None:
@@ -520,8 +520,10 @@ def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
             assert self.cuda_shard is not None

             alloc_storage(self.cuda_global_chunk)
-            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
+            assert self.cuda_global_chunk.is_contiguous()
+            work = dist.all_gather_into_tensor(
+                self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op
+            )

             self.cuda_shard = None
             self.is_gathered = True
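Both hunks above swap the list-based collectives for their fused tensor counterparts, which operate directly on the contiguous chunk buffer instead of on a Python list of chunk views, hence the new `is_contiguous()` assertions. A self-contained sketch of the equivalence, with stand-in buffer names (run under `torchrun --nproc_per_node=2`):

```python
# Illustrative comparison of the old and new collective calls; the tensors here
# merely stand in for cuda_global_chunk / cuda_shard.
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

full = torch.ones(4 * world, device="cuda") * (rank + 1)  # flat "global chunk"
shard = torch.empty(4, device="cuda")                      # this rank's shard

# Old path: chunk the flat buffer into a list of views, then reduce-scatter the list.
shard_old = torch.empty_like(shard)
dist.reduce_scatter(shard_old, list(torch.chunk(full.clone(), world, dim=0)))

# New path: the flat buffer must be contiguous and is reduce-scattered directly,
# skipping the list construction inside torch.distributed.
assert full.is_contiguous()
dist.reduce_scatter_tensor(shard, full)
assert torch.equal(shard, shard_old)

# The gather side mirrors this: write straight into the contiguous global buffer.
gathered = torch.empty_like(full)
dist.all_gather_into_tensor(gathered, shard)

dist.destroy_process_group()
```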
4 changes: 2 additions & 2 deletions colossalai/zero/gemini/chunk/manager.py
@@ -133,12 +133,12 @@ def release_chunk(self, chunk: Chunk) -> None:
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)

-    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
+    def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False, async_move=False) -> None:
         """Move the shard of the chunk to the target device."""
         if not chunk.can_move or chunk.device_type == device.type:
             return
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.shard_move(device, force_copy)
+        chunk.shard_move(device, force_copy, non_blocking=async_move)
         self.__add_memory_usage(chunk.memory_usage)

     def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
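The new `async_move` flag is forwarded to `Chunk.shard_move(..., non_blocking=...)`, and a non-blocking device-to-host move only truly overlaps with compute when the destination buffer is pinned, which is why the plugin now forces `pin_memory=True` under `enable_async_reduce`. An illustrative sketch of that mechanism (buffer and stream names are made up, not the chunk internals):

```python
# Minimal pinned-memory / non_blocking copy pattern assumed by the async move.
import torch

device = torch.device("cuda")
grad_shard = torch.randn(1 << 20, device=device)  # stand-in for a reduced shard

# Pinned (page-locked) CPU buffer: required for a truly asynchronous D2H copy;
# without it, non_blocking=True silently degrades to a synchronous copy.
cpu_shard = torch.empty(grad_shard.shape, dtype=grad_shard.dtype, device="cpu", pin_memory=True)

copy_stream = torch.cuda.Stream()
copy_stream.wait_stream(torch.cuda.current_stream())  # grad_shard must be fully written
with torch.cuda.stream(copy_stream):
    # Returns immediately; the copy runs on copy_stream while the default
    # stream keeps computing.
    cpu_shard.copy_(grad_shard, non_blocking=True)

copy_stream.synchronize()  # ensure cpu_shard is ready before the host reads it
```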
11 changes: 8 additions & 3 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -387,6 +387,7 @@ def grad_handle(
         p: nn.Parameter,
         async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
+        async_reduce_scatter = async_reduce_stream is not None
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -426,7 +427,7 @@ def grad_handle(
             async_reduce_stream.wait_stream(torch.cuda.current_stream())

         with torch.cuda.stream(async_reduce_stream):
-            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter)
             if reduced:
                 grad_chunk.wait_async_reduce()
                 if not chunk_manager.reuse_fp16_chunk:
@@ -447,9 +448,13 @@ def grad_handle(
                 # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
-                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                chunk_manager.move_chunk(
+                    grad_chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                )
                 if not (master_weights) or (enable_gradient_accumulation):
-                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    chunk_manager.move_chunk(
+                        chunk, grads_device[p], force_copy=True, async_move=async_reduce_scatter
+                    )
         return empty_grad

     def zero_grad(self, set_to_none: bool = False) -> None:
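Taken together, `grad_handle` now computes `async_reduce_scatter` once and uses it both for the collective (`async_op=...`) and for the chunk moves (`async_move=...`), so everything enqueued under `torch.cuda.stream(async_reduce_stream)` stays asynchronous with respect to the compute stream. A simplified, renamed sketch of that stream discipline (not the class's actual code; launch with `torchrun --nproc_per_node=<N>`):

```python
import torch
import torch.distributed as dist


def async_reduce_then_offload(grad_buffer, shard_out, cpu_shard, reduce_stream):
    async_reduce_scatter = reduce_stream is not None
    # The side stream must not read the gradient before backward has written it.
    reduce_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reduce_stream):
        work = dist.reduce_scatter_tensor(shard_out, grad_buffer, async_op=async_reduce_scatter)
        if work is not None:
            work.wait()  # stream-level wait; the host thread is not blocked
        # Enqueued on the same stream after the collective, so the offload is
        # also asynchronous w.r.t. the default (compute) stream.
        cpu_shard.copy_(shard_out, non_blocking=True)
    return work


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    world = dist.get_world_size()
    grad = torch.randn(8 * world, device="cuda")
    shard = torch.empty(8, device="cuda")
    cpu_shard = torch.empty(8, pin_memory=True)
    stream = torch.cuda.Stream()
    async_reduce_then_offload(grad, shard, cpu_shard, stream)
    stream.synchronize()  # the pinned buffer is valid only after this point
    dist.destroy_process_group()
```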
11 changes: 8 additions & 3 deletions examples/language/llama/benchmark.py
@@ -229,8 +229,13 @@ def empty_init():
         init_kwargs["empty_init"] = False

     with init_ctx:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, **init_kwargs)
-
+        model = AutoModelForCausalLM.from_config(
+            config,
+            trust_remote_code=True,
+            **init_kwargs,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.float16,
+        )
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if config.model_type == "chatglm":
@@ -261,7 +266,7 @@ def empty_init():

     with get_profile_context(
         args.profile,
-        1,
+        args.ignore_steps,
         len(dataloader) - 1,
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
     ) as prof:
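The benchmark now builds the model with FlashAttention-2 and half-precision weights (the `flash_attention_2` implementation in `transformers` only accepts fp16/bf16 dtypes), and the profiler warm-up is driven by the existing `--ignore_steps` argument instead of a hard-coded `1`, so the profiled window matches the steps excluded from throughput numbers. A hedged sketch of the equivalent plain `torch.profiler` schedule (the example's `get_profile_context` helper is assumed to wrap something similar):

```python
# Illustrative only; step counts stand in for args.ignore_steps and len(dataloader) - 1.
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

ignore_steps = 2
active_steps = 8
activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if torch.cuda.is_available() else [])

prof = profile(
    activities=activities,
    schedule=schedule(wait=ignore_steps, warmup=1, active=active_steps),
    on_trace_ready=tensorboard_trace_handler("profile/llama-demo"),
)

with prof:
    for _ in range(ignore_steps + 1 + active_steps):
        torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # stand-in for one training step
        prof.step()  # advance the schedule; steps in the "wait" phase are not recorded
```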