
Commit

remove all to all
GuangyaoZhang committed Jul 17, 2024
1 parent 5a310b9 commit 6a20f07
Showing 5 changed files with 36 additions and 176 deletions.
4 changes: 2 additions & 2 deletions colossalai/quantization/fp8.py
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
return ret.to(ret_type)


def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
del inp["fp8_scale"]


def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
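The two hunks above only switch the default fp8_format between e4m3 and e5m2. For readers unfamiliar with these helpers, here is a minimal, hypothetical sketch of the compress-communicate-decompress pattern they rely on, modeled on the fp8 branch of _gather later in this diff. It assumes an already-initialized default process group and that cast_to_fp8/cast_from_fp8 are importable from colossalai.quantization.fp8; it is not the actual all_reduce_fp8 or reduce_scatter_fp8 implementation.

import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

def all_gather_fp8_sketch(tensor: torch.Tensor, group=None, fp8_format="e5m2"):
    # Compress: cast to fp8 and keep the per-tensor scale needed to decompress later.
    fp8_tensor, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
    byte_view = fp8_tensor.view(torch.uint8).contiguous()
    scale = torch.as_tensor(scale, dtype=torch.float32, device=tensor.device).reshape(1)

    # Communicate the compressed payload and the scales as separate collectives.
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(byte_view) for _ in range(world_size)]
    scales = [torch.ones(1, dtype=torch.float32, device=tensor.device) for _ in range(world_size)]
    dist.all_gather(gathered, byte_view, group=group)
    dist.all_gather(scales, scale, group=group)

    # Decompress each received shard back to the original dtype.
    return [cast_from_fp8(t.view(fp8_tensor.dtype), s, tensor.dtype) for t, s in zip(gathered, scales)]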
153 changes: 21 additions & 132 deletions colossalai/shardformer/layer/_operation.py
@@ -170,7 +170,7 @@ def backward(ctx, grad_output):
if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None


def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
@@ -261,7 +261,7 @@ def backward(ctx, grad_output):

dist.reduce_scatter(output, grad_list, group=process_group)

return output, None, None, None
return output, None, None


class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@@ -729,7 +729,7 @@ def backward(ctx, grad_output):
grad_output = grad_output * ctx.grad_scale

# to_cast.append(grad_output.cpu().detach().numpy())
return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, "e4m3"), None, None, None, None
return _gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication), None, None, None, None


class _ReduceForward(torch.autograd.Function):
@@ -786,7 +786,7 @@ def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=
ctx.dim = dim
ctx.grad_scale = grad_scale

return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
return _gather(input_, dim, process_group, fp8_communication=fp8_communication)

@staticmethod
def backward(ctx, grad_output):
@@ -806,67 +806,26 @@ class _AllToAll(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication):
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = input_.shape

# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim)
else:
return _all_to_all(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, *grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz, _, _ = grad_output.shape

if bsz == 1:
return_grad = _all_to_all_single(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
else:
return_grad = _all_to_all(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)

return (return_grad, None, None, None, None)
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
return (return_grad, None, None, None)
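The rewritten backward above simply re-applies _AllToAll with scatter_dim and gather_dim swapped, since redistributing with the dims exchanged undoes the forward redistribution. A single-process sketch of that inverse relationship, using torch.tensor_split/torch.cat to mimic dist.all_to_all on hypothetical shapes (this only illustrates the data movement, not the distributed code path):

import torch

def all_to_all_local_sim(chunks, scatter_dim, gather_dim):
    # chunks[r] plays the role of rank r's local tensor. "Rank" d receives slice d
    # (along scatter_dim) of every rank and concatenates the slices along gather_dim,
    # mirroring what _all_to_all does with dist.all_to_all.
    world_size = len(chunks)
    splits = [torch.tensor_split(c, world_size, dim=scatter_dim) for c in chunks]
    return [
        torch.cat([splits[src][dst] for src in range(world_size)], dim=gather_dim)
        for dst in range(world_size)
    ]

# Hypothetical sizes: 2 "ranks", [batch=1, seq/2=2, hidden=4] per rank.
x = [torch.randn(1, 2, 4) for _ in range(2)]
y = all_to_all_local_sim(x, scatter_dim=2, gather_dim=1)       # -> [1, 4, 2] per rank
x_back = all_to_all_local_sim(y, scatter_dim=1, gather_dim=2)  # dims swapped -> inverse
assert all(torch.equal(a, b) for a, b in zip(x, x_back))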


class HookParameter(torch.autograd.Function):
@@ -924,41 +883,20 @@ def _split(input_, dim=-1, process_group=None):
return output


def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e4m3"):
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_

# all gather
import torch.distributed as dista

from colossalai.zero.low_level._utils import has_inf_or_nan

if fp8_communication:
# if False:
if has_inf_or_nan(input_):
print("input has nan")
exit(0)
input_type = input_.dtype
ret, scale = cast_to_fp8(input_, fp8_format="e5m2")
if has_inf_or_nan(ret):
import pdb

pdb.set_trace()
print("cast has nan")
# exit(0)
dista.barrier()
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
fp8_type = ret.dtype
input_ = ret.view(torch.uint8)
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
scale = torch.tensor(scale, dtype=torch.float32).to(input_.device)
# import torch.distributed as dista
# if dista.get_rank()==0:
# import pdb
# pdb.set_trace()
# dista.barrier()
scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)]

scale = torch.tensor(scale).to(input_.device)
@@ -969,24 +907,10 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for
for output, scale in zip(tensor_list, scale_list):
output = output.view(fp8_type)
output = cast_from_fp8(output, scale, input_type)
if has_inf_or_nan(output) and dista.get_rank() == 0:
print("casted_output has nan")
import pdb

pdb.set_trace()
dista.barrier()

cast_tensor_list.append(output)

output = torch.cat(cast_tensor_list, dim=dim).contiguous()

if has_inf_or_nan(output):
print("output has nan")
exit(0)
# import pdb
# pdb.set_trace()
dista.barrier()

else:
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -1020,33 +944,14 @@ def _reduce_scatter(input_, dim=1, process_group=None):
return output


def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
if fp8_communication:
input_type = input_.dtype
ret, scale = cast_to_fp8(input_, fp8_format=fp8_format)
fp8_type = ret.dtype
input_ = ret.view(torch.uint8)
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
dist.all_gather(scale_list, scale, group=group)
cast_tensor_list = []
for output, scale in zip(output_list, scale_list):
output = output.view(fp8_type)
output = cast_from_fp8(output, scale, input_type)
cast_tensor_list.append(output)
output_list = cast_tensor_list
else:
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()


def _all_to_all_single(
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
):
def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
@@ -1058,24 +963,8 @@ def _all_to_all_single(
.contiguous()
)

if fp8_communication:
input_type = input_t.dtype
ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format)
fp8_type = ret.dtype
input_t = ret.view(torch.uint8)
output = torch.empty_like(input_t)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)]
dist.all_to_all_single(output, input_t, group=group)
dist.all_gather(scale_list, scale, group=group)
cast_tensor_list = []
for output_part, scale in zip(output, scale_list):
output_part = output_part.view(fp8_type)
output_part = cast_from_fp8(output_part, scale, input_type)
cast_tensor_list.append(output_part)
output = torch.stack(cast_tensor_list, dim=0)
else:
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)
output = torch.empty_like(input_t)
dist.all_to_all_single(output, input_t, group=group)

if scatter_dim < 2:
output = output.transpose(0, 1).contiguous()
@@ -1143,5 +1032,5 @@ def reduce_backward(input_, process_group, fp8_communication=False):
return _ReduceBackward.apply(input_, process_group, fp8_communication)


def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
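A hedged usage sketch of the trimmed-down all_to_all_comm, e.g. for sequence-parallel attention: each rank starts with a shard of the sequence and ends with the full sequence but a shard of the hidden dimension. The script layout, shapes, and launch command are illustrative assumptions, not part of this commit; it expects to be launched with torchrun on a single node.

import torch
import torch.distributed as dist
from colossalai.shardformer.layer._operation import all_to_all_comm

# Assumed launch: torchrun --nproc_per_node=<world_size> this_script.py
dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

batch, seq_len, hidden = 2, 1024, 4096  # hypothetical sizes, divisible by world_size
x = torch.randn(batch, seq_len // world_size, hidden, device="cuda")

# Default group; scatter the hidden dim (2), gather the sequence dim (1).
y = all_to_all_comm(x, process_group=None, scatter_dim=2, gather_dim=1)
assert y.shape == (batch, seq_len, hidden // world_size)

dist.destroy_process_group()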
18 changes: 5 additions & 13 deletions colossalai/shardformer/layer/linear.py
@@ -84,7 +84,6 @@ def __init__(
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
fp8_communication: bool = False,
**kwargs,
):
super().__init__(weight=weight, bias_=bias_, **kwargs)
@@ -99,7 +98,6 @@ def __init__(
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.fp8_communication = fp8_communication

if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -203,12 +201,10 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
bias = self.bias if not self.skip_bias_add else None

if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(
input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
elif self.seq_parallel_mode == "split_gather":
input_parallel = gather_forward_reducescatter_backward(
input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
input_parallel, self.process_group, self.seq_parallel_dim
)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
elif self.seq_parallel_mode == "ring":
@@ -268,7 +264,6 @@ def __init__(
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1,
fp8_communication: bool = False,
):
super().__init__()

@@ -283,7 +278,6 @@ def __init__(
self.seq_parallel_mode = seq_parallel_mode
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
self.fp8_communication = fp8_communication

if skip_bias_add and not bias:
raise ValueError("cannot skip bias addition if bias is None")
@@ -404,9 +398,7 @@ def forward(self, input_: Tensor) -> Tensor:
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
)
input_ = split_forward_gather_backward(
input_, dim=-1, process_group=self.process_group, fp8_comm=self.fp8_communication
)
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)

if self.stream_chunk_num > 1:
if self.training:
@@ -426,11 +418,11 @@ def forward(self, input_: Tensor) -> Tensor:
else:
if self.seq_parallel_mode is None:
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
output = reduce_forward(output_parallel, self.process_group)
elif self.seq_parallel_mode == "split_gather":
output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
output = reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
output_parallel, self.process_group, self.seq_parallel_dim
)
elif self.seq_parallel_mode == "ring":
output = linear_reducescatter_forward_gather_backward(
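To make the data flow of the split_gather branches above concrete, here is a single-process, hypothetical emulation of a sequence-parallel Linear1D_Col -> Linear1D_Row pair: an all-gather along the sequence dimension before the column-parallel matmul, and a partial-sum reduce-scatter along the sequence dimension after the row-parallel one. It is an arithmetic sanity check with made-up sizes, not the sharded modules themselves.

import torch

torch.manual_seed(0)
tp_size, bsz, seq, hidden, inter = 2, 1, 8, 16, 32  # hypothetical sizes

w_col = torch.randn(inter, hidden)   # Linear1D_Col weight, sharded along dim 0
w_row = torch.randn(hidden, inter)   # Linear1D_Row weight, sharded along dim 1
x = torch.randn(bsz, seq, hidden)    # full activation

ref = (x @ w_col.t()) @ w_row.t()    # reference: no parallelism

x_shards = torch.tensor_split(x, tp_size, dim=1)        # sequence-sharded input
wc_shards = torch.tensor_split(w_col, tp_size, dim=0)   # column shards
wr_shards = torch.tensor_split(w_row, tp_size, dim=1)   # row shards

partials = []
for r in range(tp_size):
    gathered = torch.cat(x_shards, dim=1)                # gather forward: all-gather over seq
    col_out = gathered @ wc_shards[r].t()                # column-parallel matmul
    partials.append(col_out @ wr_shards[r].t())          # row-parallel matmul (partial sum)

# reduce-scatter forward: sum the partials, then each rank keeps its sequence shard.
reduced = torch.stack(partials).sum(dim=0)
out_shards = torch.tensor_split(reduced, tp_size, dim=1)
assert torch.allclose(torch.cat(out_shards, dim=1), ref, atol=1e-4)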
(The diffs for the remaining two changed files were not loaded on this page.)
