Skip to content

Commit

Permalink
[shardformer] add Dropout layer support different dropout pattern (hp…
Browse files Browse the repository at this point in the history
…caitech#3856)

* add dropout layer, add dropout test

* modify seed manager as context manager

* add a copy of col_nn.layer

* add dist_crossentropy loss; separate module test

* polish the code

* fix dist crossentropy loss
  • Loading branch information
FoolPlayer authored and FrankLeeeee committed Jun 8, 2023
1 parent d0a7a15 commit d2ac32a
Show file tree
Hide file tree
Showing 14 changed files with 1,413 additions and 41 deletions.
1 change: 0 additions & 1 deletion colossalai/nn/layer/parallel_1d/_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def backward(ctx, grad_output):
total_input = input
grad_input = grad_output.matmul(weight)

grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
Expand Down
9 changes: 3 additions & 6 deletions colossalai/nn/layer/parallel_1d/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,7 @@ def __init__(self,
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')

# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)

# Parameters.
# Initialize weight.
Expand Down Expand Up @@ -613,8 +612,7 @@ def __init__(self,
raise ValueError('cannot skip bias addition if bias is None')

# Divide the weight matrix along the last dimension.
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

# Parameters.
# Initialize weight.
Expand Down Expand Up @@ -886,8 +884,7 @@ def __init__(self,

tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

Expand Down
19 changes: 19 additions & 0 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,22 @@ CustomPolicy(Policy):
- CLASS `Slicer`:

This class is used to slice tensor according to policy.


3. DistCrossEntropy Loss
- Overview

In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$

alse can be represented as:

$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$

- Step

- First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large

- Get a mask to mask the logits not in the local device

- Caculate the loss according to the second formula
Empty file.
97 changes: 97 additions & 0 deletions colossalai/shardformer/layer/_operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""

@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)

return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""

@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce

output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
105 changes: 105 additions & 0 deletions colossalai/shardformer/layer/dist_crossentropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class DistCrossEntropy(Function):
r"""
Overwrite the forward and backward function to calculate the cross entropy loss before gather
Args:
Function (:class:`torch.autograd.Function`): default
"""

@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
and can be rewrite as:
loss = log(sum(exp(x[i])) - x[class]
To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i]
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)

# minus the max to avoid the result of sum of exp is too large and the log is nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank()
world_size = dist.get_world_size()
global_vocab_size = partition_vocab_size * world_size

# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta
up_shreshold = down_shreshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold)
masked_target = target.clone() - down_shreshold
masked_target[mask] = 0

# reshape the logist and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)

# extract the x[class] and set the x[other device] to zero
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
masked_target_1d]
pred_logits_1d = pred_logits_1d.clone().contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0

# allreduce the get all x(i,y)
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)

# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())

# caculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)

return loss

@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
exp_logits, mask, masked_target_1d = ctx.saved_tensors

# use exp logits as the input grad
grad_logits = exp_logits
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

update = 1.0 - mask.view(-1).float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None


def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)
58 changes: 58 additions & 0 deletions colossalai/shardformer/layer/dropout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class SeedManager:
"""
This class is a random state manager to change random state for different random seed.
"""

def __init__(self):
original_state = torch.cuda.get_rng_state()
seed = int(f"{int(time.time())}{os.environ['RANK']}")
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)

def set_mode(self, rng_state):
torch.cuda.set_rng_state(rng_state)

def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state

@contextmanager
def dropout_mode(self):
"""
This is a context manager to change the dropout state and recover the original state.
Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
finally:
self.dropout_state = self.get_current_mode()
self.set_mode(current_mode)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):

def __init__(self, p=0.5, inplace=False):
super().__init__(p, inplace)

def forward(self, input):
with _seed_manager.dropout_mode():
input = super().forward(input)
return input
Loading

0 comments on commit d2ac32a

Please sign in to comment.