[shardformer] add Dropout layer to support different dropout patterns #3856

Merged 6 commits on Jun 1, 2023
Empty file.
39 changes: 39 additions & 0 deletions colossalai/shardformer/layer/dropout.py
@@ -0,0 +1,39 @@
import os
import time

import torch
import torch.nn as nn


class SeedManager:
    """Tracks two CUDA RNG states: the original state shared by all ranks and a
    rank-specific state used only for dropout, so each rank draws a different
    dropout mask while every other random op stays synchronized across ranks."""

    def __init__(self):
        self.original_state = torch.cuda.get_rng_state()
        # Derive a rank-dependent seed so each process gets its own dropout RNG stream.
        seed = int(f"{int(time.time())}{os.environ['RANK']}")
        torch.cuda.manual_seed(seed)
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.original_state)

    def dropout_mode(self):
        # Stash the current RNG state and switch to the dropout-specific state.
        self.original_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.dropout_state)

    def origin_mode(self):
        # Stash the advanced dropout state and restore the original RNG state.
        self.dropout_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.original_state)


_seed_manager = SeedManager()


class Dropout1D(nn.Dropout):
    """Dropout whose mask is sampled from the rank-specific RNG state."""

    def __init__(self, p=0.5, inplace=False):
        super().__init__(p, inplace)

    def forward(self, input):
        # Swap in the dropout RNG state, apply dropout, then restore the original state.
        _seed_manager.dropout_mode()
        input = super().forward(input)
        _seed_manager.origin_mode()
        return input
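
A minimal usage sketch, assuming a multi-process job launched with torchrun (SeedManager reads the RANK environment variable at import time, so it must be set). Dropout masks should differ per rank, while draws made outside dropout_mode stay identical across ranks as long as the processes start from the same original RNG state:

import torch

from colossalai.shardformer.layer.dropout import Dropout1D

# Hypothetical check, run on every rank of the job.
dropout = Dropout1D(p=0.5).to("cuda")
x = torch.ones(4, 4, device="cuda")
print(f"rank-specific mask: {dropout(x)}")              # differs across ranks
print(f"shared draw: {torch.randn(1, device='cuda')}")  # identical across ranks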
27 changes: 15 additions & 12 deletions colossalai/shardformer/shard/slicer.py
@@ -94,10 +94,11 @@ def slice_1d(
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
-        down_idx = self.shardconfig.rank * delta
-        up_idx = down_idx + delta
-        return tensor[down_idx:up_idx].contiguous()
+        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        # down_idx = self.shardconfig.rank * delta
+        # up_idx = down_idx + delta
+        # return tensor[down_idx:up_idx].contiguous()

     def slice_col(
         self,
@@ -113,10 +114,11 @@ def slice_col(
             :class:`torch.Tensor`: The sliced tensor

         """
-        delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
-        down_idx = self.shardconfig.rank * delta
-        up_idx = down_idx + delta
-        return tensor[down_idx:up_idx, :].contiguous()
+        return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+        # delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        # down_idx = self.shardconfig.rank * delta
+        # up_idx = down_idx + delta
+        # return tensor[down_idx:up_idx, :].contiguous()

     def slice_row(
         self,
@@ -131,7 +133,8 @@ def slice_row(
         Returns:
             :class:`torch.Tensor`: The sliced tensor
         """
-        delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
-        down_idx = self.shardconfig.rank * delta
-        up_idx = down_idx + delta
-        return tensor[:, down_idx:up_idx].contiguous()
+        return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+        # delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        # down_idx = self.shardconfig.rank * delta
+        # up_idx = down_idx + delta
+        # return tensor[:, down_idx:up_idx].contiguous()
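
Both the old arithmetic and torch.chunk use ceil division, so they agree whenever every rank receives a shard; they diverge only at the margins when the dimension does not divide evenly. A standalone sketch with hypothetical sizes (not taken from the PR):

import torch

world_size = 4

# Evenly divisible: both approaches produce identical shards.
t = torch.arange(8)
delta = (t.shape[0] + world_size - 1) // world_size  # ceil(8 / 4) = 2
for rank in range(world_size):
    manual = t[rank * delta:(rank + 1) * delta]
    chunked = t.chunk(world_size, dim=0)[rank]
    assert torch.equal(manual, chunked)

# Not divisible: delta = ceil(9 / 4) = 3, so manual slicing hands rank 3 an
# empty tensor (t[9:12]), while torch.chunk returns only three chunks of
# size 3, and indexing the missing fourth chunk raises IndexError instead
# of silently producing an empty shard.
t = torch.arange(9)
delta = (t.shape[0] + world_size - 1) // world_size  # 3
print(t[3 * delta:4 * delta])            # tensor([], dtype=torch.int64)
print(len(t.chunk(world_size, dim=0)))   # 3, not 4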
49 changes: 35 additions & 14 deletions colossalai/shardformer/test/test.py
@@ -1,13 +1,15 @@
 import os
 import random

 import torch
 import torch.nn as nn
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler

 import colossalai
+from colossalai.shardformer.layer.dropout import Dropout1D
 from colossalai.shardformer.shard import ShardConfig, shard_model
 from colossalai.utils import get_current_device, print_rank_0

@@ -30,43 +32,48 @@ def load_data():
     # tokenized_datasets=tokenized_datasets.rename_column("label","labels")
     tokenized_datasets.set_format("torch")

-    train_dataset = tokenized_datasets["train"].select(range(500))
-    test_dataset = tokenized_datasets["test"].select(range(100))
+    train_dataset = tokenized_datasets["train"]
+    test_dataset = tokenized_datasets["test"]

     datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
-    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
-    eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
+    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
+    eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
     return train_dataloader, eval_dataloader


-def inference(model: nn.Module):
-    print(model)
+def inference(model: nn.Module, args):
     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
     token = "Hello, my dog is cute"
     inputs = tokenizer(token, return_tensors="pt")
     inputs.to("cuda")
     model.eval()
     model.to("cuda")
     outputs = model(**inputs)
     print(outputs)


-def train(model: nn.Module, num_epoch: int = 2):
+def train(model: nn.Module, args, num_epoch: int = 3):
     train_dataloader, eval_dataloader = load_data()
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
-    progress_bar = tqdm(range((num_epoch) * len(train_dataloader)))
     criterion = nn.CrossEntropyLoss()
+    num_training = num_epoch * len(train_dataloader)
+    progress_bar = tqdm(range(num_training))
+    lr_scheduler = get_scheduler(name="linear",
+                                 optimizer=optimizer,
+                                 num_warmup_steps=0,
+                                 num_training_steps=num_training)
+    best_test_loss = float("inf")
     model.to("cuda")
     model.train()
     for epoch in range(num_epoch):
         progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")

         for batch in train_dataloader:
             optimizer.zero_grad()
             batch = {k: v.to('cuda') for k, v in batch.items()}
             outputs = model(**batch)
             loss = outputs.loss
             loss.backward()
             optimizer.step()
+            # lr_scheduler.step()
             progress_bar.update(1)
             train_loss = loss

@@ -75,23 +82,37 @@ def train(model: nn.Module, num_epoch: int = 2):
             batch = {k: v.to('cuda') for k, v in batch.items()}
             outputs = model(**batch)
             # loss = outputs.loss
+            assert not torch.isnan(outputs.loss), f"{batch}"
+            loss += outputs.loss.item()
             # loss = criterion(outputs.logits, batch["input_ids"])
         test_loss = loss / len(eval_dataloader)
         print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
+        if test_loss < best_test_loss:
+            best_test_loss = test_loss
+            torch.save(model.state_dict(), "./checkpoints/best_model.pth")


+def dropout(model: nn.Module, args, input: torch.Tensor = torch.randn(5, 4)):
+    input = input.to("cuda")
+    m = Dropout1D(0.3).to("cuda")
+    for i in range(2):
+        print(f"Output: {m(input)}")  # mask should differ across ranks
+        print(torch.randn(1))         # non-dropout RNG stream, for comparison


 if __name__ == "__main__":
     args = get_args()
-    colossalai.launch_from_torch(config=args.config)
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+    colossalai.launch_from_torch(config=args.config)
     shard_config = ShardConfig(
         rank=int(str(get_current_device()).split(':')[-1]),
         world_size=int(os.environ['WORLD_SIZE']),
     )
     sharded_model = shard_model(model, shard_config)

     if args.mode == "train":
-        train(sharded_model)
+        train(sharded_model, args)
     elif args.mode == "inference":
-        inference(sharded_model)
+        inference(sharded_model, args)
+    elif args.mode == 'dropout':
+        dropout(sharded_model, args)
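
This diff does not show how shard_model substitutes Dropout1D into the model; that wiring lives in shardformer's policy code outside this page. As a rough, hypothetical illustration of the substitution (not the actual shardformer mechanism), a recursive swap could look like this:

import torch.nn as nn

from colossalai.shardformer.layer.dropout import Dropout1D

def replace_dropout(module: nn.Module) -> None:
    # Recursively swap every plain nn.Dropout for the seed-managed Dropout1D,
    # preserving the original probability and inplace flag.
    for name, child in module.named_children():
        if isinstance(child, nn.Dropout) and not isinstance(child, Dropout1D):
            setattr(module, name, Dropout1D(p=child.p, inplace=child.inplace))
        else:
            replace_dropout(child)

Because Dropout1D subclasses nn.Dropout, the swap is transparent to the rest of the model; only the RNG state used to sample the mask changes.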