From f6f14c8db5143055c83db5a437e12b4dfe2bb4f7 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 26 May 2023 15:09:49 +0800 Subject: [PATCH 1/6] add dropout layer, add dropout test --- colossalai/shardformer/layer/__init__.py | 0 colossalai/shardformer/layer/dropout.py | 39 +++++++++++++++++++ colossalai/shardformer/shard/slicer.py | 27 +++++++------ colossalai/shardformer/test/test.py | 49 +++++++++++++++++------- 4 files changed, 89 insertions(+), 26 deletions(-) create mode 100644 colossalai/shardformer/layer/__init__.py create mode 100644 colossalai/shardformer/layer/dropout.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py new file mode 100644 index 000000000000..c0c9d27ea6da --- /dev/null +++ b/colossalai/shardformer/layer/dropout.py @@ -0,0 +1,39 @@ +import os +import time + +import torch +import torch.nn as nn + + +class SeedManager: + + def __init__(self): + self.original_state = torch.cuda.get_rng_state() + seed = int(f"{int(time.time())}{os.environ['RANK']}") + print(seed) + torch.cuda.manual_seed(int(seed)) + self.dropout_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.original_state) + + def dropout_mode(self): + self.original_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.dropout_state) + + def origin_mode(self): + self.dropout_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(self.original_state) + + +_seed_manager = SeedManager() + + +class Dropout1D(nn.Dropout): + + def __init__(self, p=0.5, inplace=False): + super().__init__(p, inplace) + + def forward(self, input): + _seed_manager.dropout_mode() + input = super().forward(input) + _seed_manager.origin_mode() + return input diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 957ce1f85814..45218f24b423 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -94,10 +94,11 @@ def slice_1d( Returns: :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[down_idx:up_idx].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + # down_idx = self.shardconfig.rank * delta + # up_idx = down_idx + delta + # return tensor[down_idx:up_idx].contiguous() def slice_col( self, @@ -113,10 +114,11 @@ def slice_col( :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[down_idx:up_idx, :].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + # down_idx = self.shardconfig.rank * delta + # up_idx = down_idx + delta + # return tensor[down_idx:up_idx, :].contiguous() def slice_row( self, @@ -131,7 +133,8 @@ def slice_row( Returns: :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[:, down_idx:up_idx].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() + # delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + # down_idx = self.shardconfig.rank * delta + # up_idx = down_idx + delta + # return tensor[:, down_idx:up_idx].contiguous() diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index 202208123ced..b51d89e7bf59 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,13 +1,15 @@ import os +import random import torch import torch.nn as nn from datasets import load_dataset from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler import colossalai +from colossalai.shardformer.layer.dropout import Dropout1D from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.utils import get_current_device, print_rank_0 @@ -30,36 +32,40 @@ def load_data(): # tokenized_datasets=tokenized_datasets.rename_column("label","labels") tokenized_datasets.set_format("torch") - train_dataset = tokenized_datasets["train"].select(range(500)) - test_dataset = tokenized_datasets["test"].select(range(100)) + train_dataset = tokenized_datasets["train"] + test_dataset = tokenized_datasets["test"] datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") - train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) - eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) + eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) return train_dataloader, eval_dataloader -def inference(model: nn.Module): - print(model) +def inference(model: nn.Module, args): tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") inputs.to("cuda") + model.eval() model.to("cuda") outputs = model(**inputs) print(outputs) -def train(model: nn.Module, num_epoch: int = 2): +def train(model: nn.Module, args, num_epoch: int = 3): train_dataloader, eval_dataloader = load_data() optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - progress_bar = tqdm(range((num_epoch) * len(train_dataloader))) - criterion = nn.CrossEntropyLoss() + num_training = num_epoch * len(train_dataloader) + progress_bar = tqdm(range(num_training)) + lr_scheduler = get_scheduler(name="linear", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_training) + best_test_loss = float("inf") model.to("cuda") model.train() for epoch in range(num_epoch): progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") - for batch in train_dataloader: optimizer.zero_grad() batch = {k: v.to('cuda') for k, v in batch.items()} @@ -67,6 +73,7 @@ def train(model: nn.Module, num_epoch: int = 2): loss = outputs.loss loss.backward() optimizer.step() + # lr_scheduler.step() progress_bar.update(1) train_loss = loss @@ -75,16 +82,28 @@ def train(model: nn.Module, num_epoch: int = 2): batch = {k: v.to('cuda') for k, v in batch.items()} outputs = model(**batch) # loss = outputs.loss + assert not torch.isnan(outputs.loss), f"{batch}" loss += outputs.loss.item() # loss = criterion(outputs.logits, batch["input_ids"]) test_loss = loss / len(eval_dataloader) print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") + if test_loss < best_test_loss: + best_test_loss = test_loss + torch.save(model.state_dict(), "./checkpoints/best_model.pth") + + +def dropout(model: nn.Module, args, input: torch.Tensor() = torch.randn(5, 4)): + input = input.to("cuda") + m = Dropout1D(0.3).to("cuda") + for i in range(2): + print(f"Output: {m(input)}") + print(torch.randn(1)) if __name__ == "__main__": args = get_args() - colossalai.launch_from_torch(config=args.config) model = BertForMaskedLM.from_pretrained("bert-base-uncased") + colossalai.launch_from_torch(config=args.config) shard_config = ShardConfig( rank=int(str(get_current_device()).split(':')[-1]), world_size=int(os.environ['WORLD_SIZE']), @@ -92,6 +111,8 @@ def train(model: nn.Module, num_epoch: int = 2): sharded_model = shard_model(model, shard_config) if args.mode == "train": - train(sharded_model) + train(sharded_model, args) elif args.mode == "inference": - inference(sharded_model) + inference(sharded_model, args) + elif args.mode == 'dropout': + dropout(sharded_model, args) From 1848b853105b82b1838c66b9e279f45bc9ac3710 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 29 May 2023 10:26:34 +0800 Subject: [PATCH 2/6] modify seed manager as context manager --- colossalai/shardformer/layer/dropout.py | 43 ++++++++++++++++++------- colossalai/shardformer/shard/slicer.py | 12 ------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index c0c9d27ea6da..acc114029ac1 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,27 +1,47 @@ import os import time +from contextlib import contextmanager import torch import torch.nn as nn class SeedManager: + """ + This class is a random state manager to change random state for different random seed. + + """ def __init__(self): - self.original_state = torch.cuda.get_rng_state() + original_state = torch.cuda.get_rng_state() seed = int(f"{int(time.time())}{os.environ['RANK']}") - print(seed) torch.cuda.manual_seed(int(seed)) self.dropout_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.original_state) + torch.cuda.set_rng_state(original_state) - def dropout_mode(self): - self.original_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.dropout_state) + def set_mode(self, rng_state): + torch.cuda.set_rng_state(rng_state) - def origin_mode(self): - self.dropout_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.original_state) + def get_current_mode(self): + current_state = torch.cuda.get_rng_state() + return current_state + + @contextmanager + def dropout_mode(self): + """ + This is a context manager to change the dropout state and recover the original state. + + Usage: + :: + >>> with _seed_manager.dropout_mode(): + >>> input = super().forward(input) + """ + try: + current_mode = self.get_current_mode() + yield self.set_mode(self.dropout_state) + finally: + self.dropout_state = self.get_current_mode() + self.set_mode(current_mode) _seed_manager = SeedManager() @@ -33,7 +53,6 @@ def __init__(self, p=0.5, inplace=False): super().__init__(p, inplace) def forward(self, input): - _seed_manager.dropout_mode() - input = super().forward(input) - _seed_manager.origin_mode() + with _seed_manager.dropout_mode(): + input = super().forward(input) return input diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 45218f24b423..26053b9f7408 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -95,10 +95,6 @@ def slice_1d( :class:`torch.Tensor`: The sliced tensor """ return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - # down_idx = self.shardconfig.rank * delta - # up_idx = down_idx + delta - # return tensor[down_idx:up_idx].contiguous() def slice_col( self, @@ -115,10 +111,6 @@ def slice_col( """ return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - # down_idx = self.shardconfig.rank * delta - # up_idx = down_idx + delta - # return tensor[down_idx:up_idx, :].contiguous() def slice_row( self, @@ -134,7 +126,3 @@ def slice_row( :class:`torch.Tensor`: The sliced tensor """ return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() - # delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - # down_idx = self.shardconfig.rank * delta - # up_idx = down_idx + delta - # return tensor[:, down_idx:up_idx].contiguous() From f01bb253e81986e881494ead804cafd808c562d9 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 29 May 2023 14:42:53 +0800 Subject: [PATCH 3/6] add a copy of col_nn.layer --- colossalai/nn/layer/parallel_1d/_operation.py | 1 - colossalai/nn/layer/parallel_1d/layers.py | 9 +- colossalai/shardformer/layer/_operation.py | 97 ++ colossalai/shardformer/layer/layers.py | 1043 +++++++++++++++++ colossalai/shardformer/policies/basepolicy.py | 2 - colossalai/shardformer/policies/bert.py | 2 +- 6 files changed, 1144 insertions(+), 10 deletions(-) create mode 100644 colossalai/shardformer/layer/_operation.py create mode 100644 colossalai/shardformer/layer/layers.py diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index c5e33fd497cd..300baf9c12ba 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -73,7 +73,6 @@ def backward(ctx, grad_output): total_input = input grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 0ee3b4fcb502..406173a18c60 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -469,8 +469,7 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) - self.out_features_per_partition = out_features + self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. @@ -613,8 +612,7 @@ def __init__(self, raise ValueError('cannot skip bias addition if bias is None') # Divide the weight matrix along the last dimension. - # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) - self.input_size_per_partition = in_features + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. @@ -886,8 +884,7 @@ def __init__(self, tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings_per_partition = num_embeddings + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py new file mode 100644 index 000000000000..e817ea3ebbee --- /dev/null +++ b/colossalai/shardformer/layer/_operation.py @@ -0,0 +1,97 @@ +import torch +import torch.distributed as dist + +from colossalai.core import global_context as gpc + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, + bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py new file mode 100644 index 000000000000..f5123885bbe4 --- /dev/null +++ b/colossalai/shardformer/layer/layers.py @@ -0,0 +1,1043 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Callable, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule +from colossalai.nn.layer.parallel_1d._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) +from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition +from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ._operation import linear_with_async_comm + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +# @LAYERS.register_module +class Linear1D(ColossalaiModule): + r"""Linear layer for 1D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): Whether to call all-gather on output, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + parallel_input = get_parallel_input() + if not parallel_input and not gather_output: + layer = Linear1D_Col(in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + else: + layer = Linear1D_Row(in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + super().__init__(layer) + + +# @LAYERS.register_module +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) + + +# @LAYERS.register_module +class Classifier1D(ParallelLayer): + r"""RowLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = False + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +# @LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + r"""ColLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.gather_output = gather_output + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +# @LAYERS.register_module +class Linear1D_Col(ParallelLayer): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) + self.out_features_per_partition = out_features + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +# @LAYERS.register_module +class Linear1D_Row(ParallelLayer): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) + self.input_size_per_partition = in_features + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +# @LAYERS.register_module +class Embedding1D(ParallelLayer): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +# @LAYERS.register_module +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings_per_partition = num_embeddings + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +# @LAYERS.register_module +class Dropout1D(ParallelLayer): + """Dropout layer of 1D parallelism. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output + + +# @LAYERS.register_module +class PatchEmbedding1D(ColossalaiModule): + """ + 2D Image to Patch Embedding + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param in_chans: number of channels of input image + :type in_chans: int + :param embed_size: size of embedding + :type embed_size: int + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + :param position_embed_initializer: The initializer of position embedding, defaults to zero + :type position_embed_initializer: typing.Callable, optional + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + embed = VanillaPatchEmbedding(img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer) + super().__init__(embed) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + for key in param_keys: + param = state_dict.pop(key, None) + if param is not None: + local_state[key] = param + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index a5cc0bc68df6..2eb7eb29e1a4 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -7,8 +7,6 @@ import torch.nn as nn from transformers import AutoConfig -import colossalai.nn as col_nn - @dataclass class Argument: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5d91d8ddc766..e48b36ba69fc 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead -import colossalai.nn as col_nn +import colossalai.shardformer.layer.layers as col_nn from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer From d26625b6d5871487c3c43d5d8c79e2866308e901 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 29 May 2023 17:35:48 +0800 Subject: [PATCH 4/6] add dist_crossentropy loss; separate module test --- .../shardformer/layer/dist_crossentropy.py | 92 +++++++++++++++++++ colossalai/shardformer/test/module_test.py | 50 ++++++++++ colossalai/shardformer/test/test.py | 13 +-- 3 files changed, 144 insertions(+), 11 deletions(-) create mode 100644 colossalai/shardformer/layer/dist_crossentropy.py create mode 100644 colossalai/shardformer/test/module_test.py diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py new file mode 100644 index 000000000000..3c70c969e7a5 --- /dev/null +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -0,0 +1,92 @@ +from typing import Any + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function + + +class DistCrossEntropy(Function): + """ + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward(ctx, vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + r""" + Calculate the cross entropy loss before gather + + Args: + vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is + [batch_size, seq_len, vocab_size] + labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is + [batch_size, seq_len] + + Returns: + :class:`torch.Tensor`: The cross entropy loss + """ + partion_vocab_size = vocab_logits.shape[-1] + ctx.vocab_size = partion_vocab_size * dist.get_world_size() + + #get the mask to filter the labels not in local device + rank = dist.get_rank() + world_size = dist.get_world_size() + delta = (partion_vocab_size * dist.get_world_size() + world_size - 1) // world_size + down_shreshold = rank * delta + up_shreshold = down_shreshold + delta + # [down, up) => false, other => true + mask = (labels < down_shreshold) | (labels >= up_shreshold) + mask_labels = labels.clone() - down_shreshold + # the default ignore index is -100 + mask_labels[mask] = -100 + + # reshape the vocab_logits to [bath_size * seq_len, vocab_size] + # reshape the labels to [bath_size * seq_len] + vocab_logits_2d = vocab_logits.view(-1, partion_vocab_size) + labels_1d = mask_labels.view(-1) + + exp_vocab_logits_2d = torch.exp(vocab_logits_2d) + sum_exp_vocab_logits_2d = torch.sum(exp_vocab_logits_2d, dim=-1) + dist.all_reduce(sum_exp_vocab_logits_2d, op=dist.ReduceOp.SUM) + + log_softmax_vocab_logits_2d = torch.log(exp_vocab_logits_2d / sum_exp_vocab_logits_2d.unsqueeze(-1)) + loss = F.nll_loss(log_softmax_vocab_logits_2d, labels_1d, reduction="none") # the ignore index is -100 + loss_list = [torch.empty_like(loss) for _ in range(world_size)] + loss_list[rank] = loss + dist.all_gather(loss_list, loss) + loss = torch.cat(loss_list, dim=0) + non_zero_count = torch.sum(loss != 0) + loss = loss.sum() / non_zero_count + + log_softmax_vocab_logits = log_softmax_vocab_logits_2d.view(*vocab_logits.shape) + ctx.save_for_backward(log_softmax_vocab_logits, mask, labels_1d) + return loss + + @staticmethod + def backward(ctx: Any, grad_outputs: Any) -> Any: + # retrieve the saved tensors and set the ignore to 0 to avoid out out index + log_softmax_vocab_logits, mask, labels_1d = ctx.saved_tensors + labels_1d[labels_1d == -100] = 0 + + # logsoftmax as the grad_input + grad_input = log_softmax_vocab_logits + partion_vocab_size = log_softmax_vocab_logits.shape[-1] + grad_2d = grad_input.view(-1, partion_vocab_size) + + # set a mask to update the gradient of the labels in local device + arange_1d = torch.arange(start=0, end=grad_2d.shape[0], device=grad_2d.device) + logsoftmax_update = 1.0 - mask.view(-1).float() + grad_2d[arange_1d, labels_1d] -= logsoftmax_update + + # calculate the grad_input + grad_input.mul_(grad_outputs.unsqueeze(-1)) + + return grad_input, None, None + + +def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels) diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py new file mode 100644 index 000000000000..8c23e8faac89 --- /dev/null +++ b/colossalai/shardformer/test/module_test.py @@ -0,0 +1,50 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import colossalai +from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.shardformer.layer.dropout import Dropout1D + + +def get_args(): + parser = colossalai.get_default_parser() + parser.add_argument("--module", type=str, default='distloss') + return parser.parse_args() + + +def test_dist_crossentropy(): + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (1, 4)).repeat(2, 1) + + pred_ = pred.view(-1, 8) + labels_ = labels.view(-1) + loss = F.cross_entropy(pred_, labels_) + loss.backward() + print(f"normal loss:{loss}") + + pred = pred.chunk(2, -1)[int(os.environ['RANK'])] + loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) + loss.backward() + print(f"dist loss:{loss:.4f}") + + +def test_dropout(): + input = torch.randn(5, 4).to("cuda") + m = Dropout1D(p=0.2).to("cuda") + for i in range(2): + print(f"Output: {m(input)}") + print(torch.randn(1)) + + +if __name__ == '__main__': + args = get_args() + colossalai.launch_from_torch(config={}) + if args.module == 'distloss': + test_dist_crossentropy() + elif args.module == 'dropout': + test_dropout() + else: + print("not implemented yet") diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index b51d89e7bf59..9e859a699a7d 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -9,7 +9,6 @@ from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler import colossalai -from colossalai.shardformer.layer.dropout import Dropout1D from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.utils import get_current_device, print_rank_0 @@ -92,14 +91,6 @@ def train(model: nn.Module, args, num_epoch: int = 3): torch.save(model.state_dict(), "./checkpoints/best_model.pth") -def dropout(model: nn.Module, args, input: torch.Tensor() = torch.randn(5, 4)): - input = input.to("cuda") - m = Dropout1D(0.3).to("cuda") - for i in range(2): - print(f"Output: {m(input)}") - print(torch.randn(1)) - - if __name__ == "__main__": args = get_args() model = BertForMaskedLM.from_pretrained("bert-base-uncased") @@ -114,5 +105,5 @@ def dropout(model: nn.Module, args, input: torch.Tensor() = torch.randn(5, 4)): train(sharded_model, args) elif args.mode == "inference": inference(sharded_model, args) - elif args.mode == 'dropout': - dropout(sharded_model, args) + else: + raise NotImplementedError From efcf93a38d4ffc8baafa0f7c9b36256864dbd266 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 29 May 2023 17:39:22 +0800 Subject: [PATCH 5/6] polish the code --- colossalai/shardformer/test/module_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py index 8c23e8faac89..f26e73aea69a 100644 --- a/colossalai/shardformer/test/module_test.py +++ b/colossalai/shardformer/test/module_test.py @@ -23,9 +23,9 @@ def test_dist_crossentropy(): labels_ = labels.view(-1) loss = F.cross_entropy(pred_, labels_) loss.backward() - print(f"normal loss:{loss}") + print(f"normal loss:{loss:.4f}") - pred = pred.chunk(2, -1)[int(os.environ['RANK'])] + pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) loss.backward() print(f"dist loss:{loss:.4f}") From 5e4892356663edcc6a0095be38011297c2d36686 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Tue, 30 May 2023 18:04:41 +0800 Subject: [PATCH 6/6] fix dist crossentropy loss --- colossalai/shardformer/README.md | 19 +++ .../shardformer/layer/dist_crossentropy.py | 109 ++++++++++-------- colossalai/shardformer/model/modeling_bert.py | 10 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/test/module_test.py | 4 +- colossalai/shardformer/test/test.py | 5 +- 6 files changed, 92 insertions(+), 57 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 55b6aa75ef84..3394e9457da3 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -257,3 +257,22 @@ CustomPolicy(Policy): - CLASS `Slicer`: This class is used to slice tensor according to policy. + + + 3. DistCrossEntropy Loss + - Overview + + In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is: + $$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$ + + alse can be represented as: + + $$ loss = \log(\sum_i\exp(x[i])) - x[class]$$ + + - Step + + - First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large + + - Get a mask to mask the logits not in the local device + + - Caculate the loss according to the second formula diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 3c70c969e7a5..1869594670ce 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -1,5 +1,3 @@ -from typing import Any - import torch import torch.distributed as dist import torch.nn as nn @@ -8,7 +6,7 @@ class DistCrossEntropy(Function): - """ + r""" Overwrite the forward and backward function to calculate the cross entropy loss before gather Args: @@ -16,9 +14,14 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): r""" - Calculate the cross entropy loss before gather + Calculate the cross entropy loss before gather, the origin loss function is as follows: + loss = -log(exp(x[class])/sum(exp(x[i])) + and can be rewrite as: + loss = log(sum(exp(x[i])) - x[class] + + To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i] Args: vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is @@ -29,63 +32,73 @@ def forward(ctx, vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tens Returns: :class:`torch.Tensor`: The cross entropy loss """ - partion_vocab_size = vocab_logits.shape[-1] - ctx.vocab_size = partion_vocab_size * dist.get_world_size() + # get the max + logits_max = torch.max(vocab_logits, dim=-1)[0] + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX) + + # minus the max to avoid the result of sum of exp is too large and the log is nan + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) - #get the mask to filter the labels not in local device + # mask the target in the local device + partition_vocab_size = vocab_logits.size()[-1] rank = dist.get_rank() world_size = dist.get_world_size() - delta = (partion_vocab_size * dist.get_world_size() + world_size - 1) // world_size + global_vocab_size = partition_vocab_size * world_size + + # [down, up) => false, other device and -100 => true + delta = (global_vocab_size + world_size - 1) // world_size down_shreshold = rank * delta up_shreshold = down_shreshold + delta - # [down, up) => false, other => true - mask = (labels < down_shreshold) | (labels >= up_shreshold) - mask_labels = labels.clone() - down_shreshold - # the default ignore index is -100 - mask_labels[mask] = -100 + mask = (target < down_shreshold) | (target >= up_shreshold) + masked_target = target.clone() - down_shreshold + masked_target[mask] = 0 + # reshape the logist and target # reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the labels to [bath_size * seq_len] - vocab_logits_2d = vocab_logits.view(-1, partion_vocab_size) - labels_1d = mask_labels.view(-1) - - exp_vocab_logits_2d = torch.exp(vocab_logits_2d) - sum_exp_vocab_logits_2d = torch.sum(exp_vocab_logits_2d, dim=-1) - dist.all_reduce(sum_exp_vocab_logits_2d, op=dist.ReduceOp.SUM) - - log_softmax_vocab_logits_2d = torch.log(exp_vocab_logits_2d / sum_exp_vocab_logits_2d.unsqueeze(-1)) - loss = F.nll_loss(log_softmax_vocab_logits_2d, labels_1d, reduction="none") # the ignore index is -100 - loss_list = [torch.empty_like(loss) for _ in range(world_size)] - loss_list[rank] = loss - dist.all_gather(loss_list, loss) - loss = torch.cat(loss_list, dim=0) - non_zero_count = torch.sum(loss != 0) - loss = loss.sum() / non_zero_count - - log_softmax_vocab_logits = log_softmax_vocab_logits_2d.view(*vocab_logits.shape) - ctx.save_for_backward(log_softmax_vocab_logits, mask, labels_1d) + logits_2d = vocab_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + + # extract the x[class] and set the x[other device] to zero + pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), + masked_target_1d] + pred_logits_1d = pred_logits_1d.clone().contiguous() + pred_logits = pred_logits_1d.view_as(target) + pred_logits[mask] = 0.0 + + # allreduce the get all x(i,y) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM) + exp_logits = vocab_logits + torch.exp(vocab_logits, out=exp_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM) + + # calculate the loss + # loss = log(sum(exp(x[i]))) - x[class] + loss = torch.log(sum_exp_logits) - pred_logits + loss = torch.sum(loss).div_(loss.numel()) + + # caculate the softmax + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, mask, masked_target_1d) + return loss @staticmethod - def backward(ctx: Any, grad_outputs: Any) -> Any: - # retrieve the saved tensors and set the ignore to 0 to avoid out out index - log_softmax_vocab_logits, mask, labels_1d = ctx.saved_tensors - labels_1d[labels_1d == -100] = 0 - - # logsoftmax as the grad_input - grad_input = log_softmax_vocab_logits - partion_vocab_size = log_softmax_vocab_logits.shape[-1] - grad_2d = grad_input.view(-1, partion_vocab_size) + def backward(ctx, grad_output): + # retrieve the saved tensors + exp_logits, mask, masked_target_1d = ctx.saved_tensors - # set a mask to update the gradient of the labels in local device - arange_1d = torch.arange(start=0, end=grad_2d.shape[0], device=grad_2d.device) - logsoftmax_update = 1.0 - mask.view(-1).float() - grad_2d[arange_1d, labels_1d] -= logsoftmax_update + # use exp logits as the input grad + grad_logits = exp_logits + partion_vocab_size = grad_logits.shape[-1] + grad_logits_2d = grad_logits.view(-1, partion_vocab_size) - # calculate the grad_input - grad_input.mul_(grad_outputs.unsqueeze(-1)) + update = 1.0 - mask.view(-1).float() + grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update - return grad_input, None, None + grad_logits.mul_(grad_output.unsqueeze(dim=-1)) + return grad_logits, None, None def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py index 6741ae866991..bd07ab80c00d 100644 --- a/colossalai/shardformer/model/modeling_bert.py +++ b/colossalai/shardformer/model/modeling_bert.py @@ -6,6 +6,8 @@ from transformers import BertForMaskedLM from transformers.models.bert.modeling_bert import MaskedLMOutput +from ..layer.dist_crossentropy import applyDistCrossEntropy + class BertForMaskedLM_(BertForMaskedLM): @@ -47,11 +49,11 @@ def forward( masked_lm_loss = None - # if input_ids is not None: - # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels) + # if labels is not None: + # loss_fct = CrossEntropyLoss() # -100 index = padding token + # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (prediction_scores,) + outputs[2:] diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index e48b36ba69fc..ab77b29f71f4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -142,7 +142,7 @@ def unembedding() -> List: weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - gather_output=True, + # gather_output=True, ) ] diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py index f26e73aea69a..83dc7ec6cf4a 100644 --- a/colossalai/shardformer/test/module_test.py +++ b/colossalai/shardformer/test/module_test.py @@ -23,12 +23,12 @@ def test_dist_crossentropy(): labels_ = labels.view(-1) loss = F.cross_entropy(pred_, labels_) loss.backward() - print(f"normal loss:{loss:.4f}") + print(f"normal loss:{loss}") pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) loss.backward() - print(f"dist loss:{loss:.4f}") + print(f"dist loss:{loss}") def test_dropout(): diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index 9e859a699a7d..b896fd4a4020 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -19,6 +19,7 @@ def get_args(): parser = colossalai.get_default_parser() parser.add_argument("--mode", type=str, default='inference') + parser.add_argument("--save_model", action='store_true') return parser.parse_args() @@ -72,7 +73,7 @@ def train(model: nn.Module, args, num_epoch: int = 3): loss = outputs.loss loss.backward() optimizer.step() - # lr_scheduler.step() + lr_scheduler.step() progress_bar.update(1) train_loss = loss @@ -86,7 +87,7 @@ def train(model: nn.Module, args, num_epoch: int = 3): # loss = criterion(outputs.logits, batch["input_ids"]) test_loss = loss / len(eval_dataloader) print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") - if test_loss < best_test_loss: + if args.save_model and test_loss < best_test_loss: best_test_loss = test_loss torch.save(model.state_dict(), "./checkpoints/best_model.pth")