Skip to content

Commit

Permalink
[pipeline] refactor 1f1b schedule (hpcaitech#4115)
Browse files Browse the repository at this point in the history
* [api] update optimizer wrapper to fit pipeline

* [pipeline] add base schedule

* [pipeline] add 1f1b schedule

* [test] add pipeline schedule utils test

* [pipeline] fix import
  • Loading branch information
ver217 committed Aug 15, 2023
1 parent a636ea4 commit aa690f9
Show file tree
Hide file tree
Showing 6 changed files with 451 additions and 0 deletions.
4 changes: 4 additions & 0 deletions colossalai/interface/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
Expand Down Expand Up @@ -53,6 +54,9 @@ def backward(self, loss: Tensor, *args, **kwargs):
"""
loss.backward(*args, **kwargs)

def backward_by_grad(self, tensor: Tensor, grad: Tensor):
    """Run a backward pass from ``tensor`` using ``grad`` as the upstream gradient.

    Unlike :meth:`backward`, this does not start from a scalar loss; it
    backpropagates from an intermediate tensor with an externally supplied
    gradient (NOTE(review): presumably the gradient received from the next
    pipeline stage — confirm against the schedule implementation).

    Args:
        tensor (Tensor): Output tensor to backpropagate from.
        grad (Tensor): Gradient of the loss w.r.t. ``tensor``; must be
            broadcastable to ``tensor``'s shape.
    """
    torch.autograd.backward(tensor, grad)

def state_dict(self):
"""
Returns the optimizer state.
Expand Down
7 changes: 7 additions & 0 deletions colossalai/pipeline/schedule/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .base import PipelineSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule

__all__ = [
'PipelineSchedule',
'OneForwardOneBackwardSchedule',
]
129 changes: 129 additions & 0 deletions colossalai/pipeline/schedule/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Any, List, Optional

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Return ``x`` moved to ``device`` when it is a tensor; otherwise return it unchanged.

    Args:
        x (Any): Object that may or may not be a tensor.
        device (Optional[torch.device], optional): Destination device. Defaults to None.

    Returns:
        Any: The (possibly moved) object.
    """
    return x.to(device) if isinstance(x, torch.Tensor) else x


def get_batch_size(batch: Any) -> int:
    """Get the batch size (size of dimension-0) of the first tensor found in the batch.

    Args:
        batch (Any): Batch (any pytree) to be inspected.

    Raises:
        RuntimeError: If no tensor is found in the batch.

    Returns:
        int: Batch size.
    """
    leaves, _ = tree_flatten(batch)
    # First tensor leaf wins; None signals that the batch holds no tensor at all.
    size = next((leaf.size(0) for leaf in leaves if torch.is_tensor(leaf)), None)
    if size is None:
        raise RuntimeError('No tensor found in the batch')
    return size


def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Slice a micro batch out of the original batch.

    Every tensor leaf is sliced along dimension 0; non-tensor leaves are kept as-is.

    Args:
        batch (Any): Batch (any pytree) to be sliced.
        start (int): Start index of the micro batch.
        micro_batch_size (int): Size of the micro batch.

    Returns:
        Any: Target micro batch with the same pytree structure as ``batch``.
    """
    end = start + micro_batch_size

    def _slice(leaf: Any) -> Any:
        return leaf[start:end] if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_slice, batch)


def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Call the model's forward with loader data plus inputs from the previous stage.

    Args:
        model (Module): Model to be called.
        data (Any): Data loaded from the data iterator; a dict is splatted as
            keyword arguments, a list/tuple as positional arguments, anything
            else is passed as a single positional argument.
        internal_inputs (Optional[dict]): Data from the previous stage, or None
            on the first stage.

    Returns:
        Any: Outputs of the model.
    """
    extra_kwargs = {} if internal_inputs is None else internal_inputs
    if isinstance(data, dict):
        return model(**data, **extra_kwargs)
    if isinstance(data, (list, tuple)):
        return model(*data, **extra_kwargs)
    return model(data, **extra_kwargs)


def retain_grad(x: Any) -> None:
    """Call ``retain_grad()`` on ``x`` if it is a tensor; do nothing otherwise.

    Args:
        x (Any): Object that may be a tensor whose gradient should be retained.
    """
    if not isinstance(x, torch.Tensor):
        return
    x.retain_grad()


def detach(x: Any) -> Any:
    """Return a detached copy of ``x`` if it is a tensor; otherwise return ``x``.

    Args:
        x (Any): Object that may be a tensor.

    Returns:
        Any: The detached tensor, or the original object unchanged.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach()


def merge_batch(data: List[Any]) -> Any:
    """Concatenate a list of micro batches back into a single batch.

    Tensor leaves are concatenated along dimension 0; non-tensor leaves are
    collected into lists. All micro batches are assumed to share the same
    pytree structure.

    Args:
        data (List[Any]): A list of micro batches.

    Returns:
        Any: The merged batch, or None when ``data`` is empty.
    """
    if not data:
        return None
    spec = None
    leaves_per_micro_batch = []
    for micro_batch in data:
        leaves, spec = tree_flatten(micro_batch)
        leaves_per_micro_batch.append(leaves)
    # Group the i-th leaf of every micro batch together and merge each group.
    merged_leaves = [
        torch.cat(group, dim=0) if isinstance(group[0], torch.Tensor) else list(group)
        for group in zip(*leaves_per_micro_batch)
    ]
    return tree_unflatten(merged_leaves, spec)
35 changes: 35 additions & 0 deletions colossalai/pipeline/schedule/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Callable, Iterable

from torch import Tensor
from torch.nn import Module

from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.stage_manager import PipelineStageManager


class PipelineSchedule:
    """Base class of pipeline schedules.

    A schedule defines how micro batches flow through the pipeline stages
    during one forward/backward pass. Subclasses must implement
    ``forward_backward_step``.

    Args:
        stage_manager (PipelineStageManager): Manager describing this process's pipeline stage.
    """

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        self.stage_manager = stage_manager

    def forward_backward_step(self,
                              model: Module,
                              optimizer: OptimizerWrapper,
                              data_iter: Iterable,
                              criterion: Callable[[Any, Any], Tensor],
                              return_loss: bool = False,
                              return_outputs: bool = False) -> dict:
        """Forward and backward step for pipeline training.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator yielding batches.
            criterion (Callable[[Any, Any], Tensor]): Loss function; takes model
                outputs and inputs and returns the loss tensor.
            return_loss (bool, optional): Whether to return the loss. Defaults to False.
            return_outputs (bool, optional): Whether to return the model outputs. Defaults to False.

        Returns:
            dict: A dict with keys 'loss' and 'outputs'.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override.
        """
        raise NotImplementedError
Loading

0 comments on commit aa690f9

Please sign in to comment.