[shardformer] add Dropout layer support different dropout pattern #3856

Merged · 6 commits · Jun 1, 2023
1 change: 0 additions & 1 deletion colossalai/nn/layer/parallel_1d/_operation.py
@@ -73,7 +73,6 @@ def backward(ctx, grad_output):
total_input = input
grad_input = grad_output.matmul(weight)

grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
9 changes: 3 additions & 6 deletions colossalai/nn/layer/parallel_1d/layers.py
@@ -469,8 +469,7 @@ def __init__(self,
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')

# self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
self.out_features_per_partition = out_features
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)

# Parameters.
# Initialize weight.
@@ -613,8 +612,7 @@ def __init__(self,
raise ValueError('cannot skip bias addition if bias is None')

# Divide the weight matrix along the last dimension.
# self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
self.input_size_per_partition = in_features
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)

# Parameters.
# Initialize weight.
@@ -886,8 +884,7 @@ def __init__(self,

tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
# self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings_per_partition = num_embeddings
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

19 changes: 19 additions & 0 deletions colossalai/shardformer/README.md
@@ -257,3 +257,22 @@ CustomPolicy(Policy):
- CLASS `Slicer`:

This class is used to slice tensors according to the policy.


3. DistCrossEntropy Loss
- Overview

In order to reduce the communication size, the cross-entropy loss is calculated before the all-gather, following [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). This reduces the communicated tensor from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The original loss function is:
$$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$

which can also be written as:

$$ loss = \log(\sum_i\exp(x[i])) - x[class]$$

- Steps

- First, compute the maximum logit across all devices and subtract it from the logits, so every value is at most zero and the exponentials cannot overflow

- Build a mask that marks the target classes whose logits are not stored on the local device

- Calculate the loss according to the second formula (see the sketch below)
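
A minimal forward-only sketch of these steps (gradients omitted). It assumes an initialized `torch.distributed` process group and an even vocabulary split across ranks; the function name is illustrative and not part of this PR:

```python
import torch
import torch.distributed as dist


def dist_cross_entropy_sketch(vocab_logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # vocab_logits: [batch_size, seq_len, vocab_size // world_size] (local shard)
    # target:       [batch_size, seq_len] with global vocabulary indices
    rank, world_size = dist.get_rank(), dist.get_world_size()
    partition_size = vocab_logits.size(-1)

    # Step 1: subtract the global max so exp() stays bounded
    logits_max = vocab_logits.max(dim=-1)[0]
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)
    vocab_logits = vocab_logits - logits_max.unsqueeze(-1)

    # Step 2: mask targets whose logits live on other ranks
    lower, upper = rank * partition_size, (rank + 1) * partition_size
    mask = (target < lower) | (target >= upper)
    local_target = (target - lower).masked_fill(mask, 0)

    # x[class]: pick the local predicted logit, zero it when the class is remote,
    # then sum across ranks so every rank ends up with the true value
    logits_2d = vocab_logits.view(-1, partition_size)
    pred = logits_2d[torch.arange(logits_2d.size(0), device=logits_2d.device),
                     local_target.view(-1)].view_as(target).masked_fill(mask, 0.0)
    dist.all_reduce(pred, op=dist.ReduceOp.SUM)

    # log(sum_i exp(x[i])) needs the sum over the full (global) vocabulary
    sum_exp = vocab_logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM)

    # Step 3: loss = log(sum_i exp(x[i])) - x[class], averaged over all tokens
    return (torch.log(sum_exp) - pred).mean()
```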
Empty file.
97 changes: 97 additions & 0 deletions colossalai/shardformer/layer/_operation.py
@@ -0,0 +1,97 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm

Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""

@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output

@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)

return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""

@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce

output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias

total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
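
As a usage illustration (not part of this diff), a column-parallel linear layer could wrap `linear_with_async_comm` roughly as sketched below. The wrapper class, its initialization, and the `ParallelMode` import path are assumptions, and a launched ColossalAI context is required so that `gpc` can resolve the process group in backward:

```python
import torch
import torch.nn as nn

from colossalai.context import ParallelMode    # assumed import path
from colossalai.shardformer.layer._operation import linear_with_async_comm


class ColumnParallelLinearSketch(nn.Module):
    """Hypothetical wrapper: each rank holds a slice of the output features."""

    def __init__(self, in_features: int, out_features_per_partition: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features_per_partition, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features_per_partition))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # async_grad_allreduce=True lets the all-reduce of the input gradient
        # overlap with the weight/bias gradient matmuls in backward()
        return linear_with_async_comm(x, self.weight, self.bias,
                                      ParallelMode.PARALLEL_1D, True)
```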
105 changes: 105 additions & 0 deletions colossalai/shardformer/layer/dist_crossentropy.py
@@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class DistCrossEntropy(Function):
r"""
Override the forward and backward functions to calculate the cross-entropy loss before the gather

Args:
Function (:class:`torch.autograd.Function`): default
"""

@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
r"""
Calculate the cross-entropy loss before the gather; the original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
and can be rewritten as:
loss = log(sum(exp(x[i]))) - x[class]

To avoid `nan` in log(sum(exp(x[i]))), we subtract the max of x[i] first

Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
target (:class:`torch.Tensor`): The labels (target class indices), shape is
[batch_size, seq_len]

Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
# get the max
logits_max = torch.max(vocab_logits, dim=-1)[0]
dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)

# subtract the max so the sum of exp does not overflow and the log does not become nan
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

# mask out the targets that are not in the local device's vocabulary range
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank()
world_size = dist.get_world_size()
global_vocab_size = partition_vocab_size * world_size

# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta
up_shreshold = down_shreshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold)
masked_target = target.clone() - down_shreshold
masked_target[mask] = 0

# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)

# extract the x[class] and set the x[other device] to zero
pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
masked_target_1d]
pred_logits_1d = pred_logits_1d.clone().contiguous()
pred_logits = pred_logits_1d.view_as(target)
pred_logits[mask] = 0.0

# all-reduce so that every device gets the full x[class]
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)

# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())

# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)

return loss

@staticmethod
def backward(ctx, grad_output):
# retrieve the saved tensors
exp_logits, mask, masked_target_1d = ctx.saved_tensors

# use exp logits as the input grad
grad_logits = exp_logits
partition_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partition_vocab_size)

update = 1.0 - mask.view(-1).float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None


def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)
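
A minimal usage sketch (not part of this diff), assuming `torch.distributed` is already initialized (e.g. with the gloo backend for this CPU example), the logits are sharded along the vocabulary dimension, and the shapes and vocabulary size are purely illustrative:

```python
import torch
import torch.distributed as dist

from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy

world_size = dist.get_world_size()
vocab_size = 32000                      # illustrative global vocabulary size
vocab_per_rank = vocab_size // world_size

# Each rank holds its own [batch, seq_len, vocab_per_rank] logit shard;
# the labels carry global vocabulary indices and must match on every rank,
# so seed the generator identically before sampling them.
torch.manual_seed(0)
labels = torch.randint(0, vocab_size, (2, 16))
vocab_logits = torch.randn(2, 16, vocab_per_rank, requires_grad=True)

loss = applyDistCrossEntropy(vocab_logits, labels)
loss.backward()                         # grad is the sharded softmax minus the one-hot target
```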
58 changes: 58 additions & 0 deletions colossalai/shardformer/layer/dropout.py
@@ -0,0 +1,58 @@
import os
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class SeedManager:
"""
This class is a random-state manager that maintains a separate CUDA RNG state (seeded per rank) for dropout.

"""

def __init__(self):
original_state = torch.cuda.get_rng_state()
seed = int(f"{int(time.time())}{os.environ['RANK']}")
torch.cuda.manual_seed(int(seed))
self.dropout_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(original_state)

def set_mode(self, rng_state):
torch.cuda.set_rng_state(rng_state)

def get_current_mode(self):
current_state = torch.cuda.get_rng_state()
return current_state

@contextmanager
def dropout_mode(self):
"""
This is a context manager that switches to the dropout RNG state and restores the original state afterwards.

Usage:
::
>>> with _seed_manager.dropout_mode():
>>> input = super().forward(input)
"""
try:
current_mode = self.get_current_mode()
yield self.set_mode(self.dropout_state)
finally:
self.dropout_state = self.get_current_mode()
self.set_mode(current_mode)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):

def __init__(self, p=0.5, inplace=False):
super().__init__(p, inplace)

def forward(self, input):
with _seed_manager.dropout_mode():
input = super().forward(input)
return input
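
A short usage sketch of the per-rank dropout pattern (not part of this diff). It assumes CUDA is available and the `RANK` environment variable is set before the import, e.g. by `torchrun`, since the seed manager reads it at import time:

```python
import torch

from colossalai.shardformer.layer.dropout import Dropout1D

torch.manual_seed(42)                   # the global seed is identical on every rank

dropout = Dropout1D(p=0.1)
x = torch.ones(4, 8, device='cuda')

y = dropout(x)                          # the dropout mask differs across ranks,
                                        # because Dropout1D draws from a per-rank RNG state
z = torch.rand(4, 8, device='cuda')     # ordinary CUDA randomness is untouched,
                                        # so z is still identical on every rank
```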