forked from hpcaitech/ColossalAI
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[shardformer] add Dropout layer support different dropout pattern (hpcaitech#3856)

* add dropout layer, add dropout test
* modify seed manager as context manager
* add a copy of col_nn.layer
* add dist_crossentropy loss; separate module test
* polish the code
* fix dist crossentropy loss
1 parent d0a7a15 · commit d2ac32a
Showing 14 changed files with 1,413 additions and 41 deletions.
Empty file.
@@ -0,0 +1,97 @@
import torch
import torch.distributed as dist

from colossalai.core import global_context as gpc

try:
    import fused_mix_prec_layer_norm_cuda
except ImportError:
    fused_mix_prec_layer_norm_cuda = None


class FusedLayerNormAffineFunction1D(torch.autograd.Function):
    r"""Layernorm

    Args:
        input: input matrix.
        weight: weight matrix.
        bias: bias matrix.
        normalized_shape: input shape from an expected input of size
            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`.
            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability
    """

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
                                                                             bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fused_mix_prec_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.parallel_mode = parallel_mode
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        total_input = input
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # the all-reduce scheduled first and GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
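
For context, a rough usage sketch of linear_with_async_comm (illustrative, not part of the diff): it assumes colossalai.launch has already set up a 1D tensor-parallel group so that gpc.get_group(ParallelMode.PARALLEL_1D) is valid; the ParallelMode import and all shapes below are assumptions.

# Illustrative only: requires an initialized ColossalAI 1D parallel context.
from colossalai.context import ParallelMode

x = torch.randn(4, 16, 512, device='cuda', requires_grad=True)   # [batch, seq, hidden]
w = torch.randn(1024, 512, device='cuda', requires_grad=True)    # [out_features, in_features]
b = torch.zeros(1024, device='cuda', requires_grad=True)

y = linear_with_async_comm(x, w, b, ParallelMode.PARALLEL_1D, True)
y.sum().backward()   # grad w.r.t. x is all-reduced across the parallel group,
                     # overlapped with the local weight/bias gradient computation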
@@ -0,0 +1,105 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


class DistCrossEntropy(Function):
    r"""
    Overwrite the forward and backward functions of :class:`torch.autograd.Function`
    to calculate the cross entropy loss before gathering the vocabulary logits.
    """

    @staticmethod
    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
        r"""
        Calculate the cross entropy loss before gather. The original loss function is:
            loss = -log(exp(x[class]) / sum(exp(x[i])))
        and can be rewritten as:
            loss = log(sum(exp(x[i]))) - x[class]
        To avoid `nan` in log(sum(exp(x[i]))), the max of x[i] is subtracted first.

        Args:
            vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
                [batch_size, seq_len, vocab_size]
            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
                [batch_size, seq_len]

        Returns:
            :class:`torch.Tensor`: The cross entropy loss
        """
        # get the max
        logits_max = torch.max(vocab_logits, dim=-1)[0]
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX)

        # subtract the max so that the sum of exps does not overflow and the log is not nan
        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)

        # mask out targets that do not fall into this device's vocabulary partition
        partition_vocab_size = vocab_logits.size()[-1]
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        global_vocab_size = partition_vocab_size * world_size

        # targets in [down, up) => False; targets on other devices (and -100) => True
        delta = (global_vocab_size + world_size - 1) // world_size
        down_threshold = rank * delta
        up_threshold = down_threshold + delta
        mask = (target < down_threshold) | (target >= up_threshold)
        masked_target = target.clone() - down_threshold
        masked_target[mask] = 0

        # reshape the logits and target
        # reshape the vocab_logits to [batch_size * seq_len, vocab_size]
        # reshape the labels to [batch_size * seq_len]
        logits_2d = vocab_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)

        # extract x[class] and zero out the entries owned by other devices
        pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device),
                                   masked_target_1d]
        pred_logits_1d = pred_logits_1d.clone().contiguous()
        pred_logits = pred_logits_1d.view_as(target)
        pred_logits[mask] = 0.0

        # all-reduce to collect the x[class] values from all devices
        dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM)
        exp_logits = vocab_logits
        torch.exp(vocab_logits, out=exp_logits)
        sum_exp_logits = torch.sum(exp_logits, dim=-1)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM)

        # calculate the loss
        # loss = log(sum(exp(x[i]))) - x[class]
        loss = torch.log(sum_exp_logits) - pred_logits
        loss = torch.sum(loss).div_(loss.numel())

        # calculate the softmax
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # retrieve the saved tensors
        exp_logits, mask, masked_target_1d = ctx.saved_tensors

        # use the exp logits (softmax) as the input grad
        grad_logits = exp_logits
        partition_vocab_size = grad_logits.shape[-1]
        grad_logits_2d = grad_logits.view(-1, partition_vocab_size)

        update = 1.0 - mask.view(-1).float()
        grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

        grad_logits.mul_(grad_output.unsqueeze(dim=-1))
        return grad_logits, None, None


def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return DistCrossEntropy.apply(vocab_logits, labels)
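
A rough usage sketch of applyDistCrossEntropy (illustrative, not part of the diff): each rank holds its own shard of the vocabulary logits, labels are global vocabulary indices, and torch.distributed must already be initialized; the shapes below are assumptions.

# Illustrative only: run on every rank of an initialized process group.
batch_size, seq_len, vocab_per_rank = 2, 8, 4096
vocab_logits = torch.randn(batch_size, seq_len, vocab_per_rank, device='cuda', requires_grad=True)
labels = torch.randint(0, vocab_per_rank * dist.get_world_size(), (batch_size, seq_len), device='cuda')

loss = applyDistCrossEntropy(vocab_logits, labels)   # scalar, averaged over batch_size * seq_len tokens
loss.backward()                                       # gradients only touch the local vocabulary shard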
@@ -0,0 +1,58 @@
import os
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class SeedManager:
    """
    A random-state manager that keeps a separate CUDA RNG state, so dropout can
    run under a different random seed from the default one.
    """

    def __init__(self):
        original_state = torch.cuda.get_rng_state()
        seed = int(f"{int(time.time())}{os.environ['RANK']}")
        torch.cuda.manual_seed(int(seed))
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(original_state)

    def set_mode(self, rng_state):
        torch.cuda.set_rng_state(rng_state)

    def get_current_mode(self):
        current_state = torch.cuda.get_rng_state()
        return current_state

    @contextmanager
    def dropout_mode(self):
        """
        A context manager that switches to the dropout RNG state and restores
        the original state on exit.

        Usage:
        ::
            >>> with _seed_manager.dropout_mode():
            >>>     input = super().forward(input)
        """
        try:
            current_mode = self.get_current_mode()
            yield self.set_mode(self.dropout_state)
        finally:
            self.dropout_state = self.get_current_mode()
            self.set_mode(current_mode)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):
    """Dropout that draws its mask under the dedicated dropout RNG state."""

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, input):
        with _seed_manager.dropout_mode():
            input = super().forward(input)
        return input
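
A rough usage sketch of Dropout1D (illustrative, not part of the diff): SeedManager reads the RANK environment variable at import time and seeds the CUDA RNG from it, so this assumes a distributed launch (e.g. torchrun) and an available GPU; the shapes are assumptions.

# Illustrative only: RANK must be set and CUDA available.
dropout = Dropout1D(p=0.1)
hidden = torch.randn(4, 16, 512, device='cuda')
out = dropout(hidden)    # the mask is drawn from the per-rank dropout RNG state;
                         # the surrounding CUDA RNG state is restored afterwards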