forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pipeline] refactor 1f1b schedule (hpcaitech#4115)
* [api] update optimizer wrapper to fit pipeline * [pipeline] add base schedule * [pipeline] add 1f1b schedule * [test] add pipeline schedule utils test * [pipeline] fix import
- Loading branch information
Showing
6 changed files
with
451 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .base import PipelineSchedule | ||
from .one_f_one_b import OneForwardOneBackwardSchedule | ||
|
||
__all__ = [ | ||
'PipelineSchedule', | ||
'OneForwardOneBackwardSchedule', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from typing import Any, List, Optional | ||
|
||
import torch | ||
import torch.cuda | ||
from torch.nn import Module | ||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten | ||
|
||
|
||
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
    """Move ``x`` to ``device`` if it is a tensor; otherwise return it unchanged.

    Args:
        x (Any): Object to be moved.
        device (Optional[torch.device], optional): Target device. Defaults to None.

    Returns:
        Any: The (possibly moved) object.
    """
    # Non-tensor leaves (ints, strings, None, ...) pass through untouched.
    return x.to(device) if isinstance(x, torch.Tensor) else x
|
||
|
||
def get_batch_size(batch: Any) -> int:
    """Get the batch size (size of dimension-0) of the first tensor in the batch.

    Args:
        batch (Any): Batch to be inspected.

    Raises:
        RuntimeError: If no tensor is found in the batch.

    Returns:
        int: Batch size.
    """
    leaves, _ = tree_flatten(batch)
    # Take the first tensor leaf in flattening order; None signals "no tensor".
    first_tensor = next((leaf for leaf in leaves if isinstance(leaf, torch.Tensor)), None)
    if first_tensor is None:
        raise RuntimeError('No tensor found in the batch')
    return first_tensor.size(0)
|
||
|
||
def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any:
    """Get a micro batch of the original batch.

    Tensor leaves are sliced along dimension 0; all other leaves are kept as-is.

    Args:
        batch (Any): Batch to be sliced.
        start (int): Start index of the micro batch.
        micro_batch_size (int): Size of the micro batch.

    Returns:
        Any: Target micro batch.
    """
    end = start + micro_batch_size

    def _slice_leaf(leaf: Any) -> Any:
        # Only tensors are sliced; scalars/strings are shared across micro batches.
        return leaf[start:end] if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_slice_leaf, batch)
|
||
|
||
def model_forward(model: Module, data: Any, internal_inputs: Optional[dict]) -> Any:
    """Call model forward function with data and internal inputs.

    Args:
        model (Module): Model to be called.
        data (Any): Data loaded from data iterator.
        internal_inputs (Optional[dict]): Data from previous stage. It must be a dict or None if it's the first stage.

    Returns:
        Any: Outputs of the model.
    """
    # First stage passes None; normalize to an empty kwargs dict.
    extra_kwargs = internal_inputs or {}
    if isinstance(data, (list, tuple)):
        # Sequence batches are unpacked as positional arguments.
        return model(*data, **extra_kwargs)
    if isinstance(data, dict):
        # Dict batches are unpacked as keyword arguments.
        return model(**data, **extra_kwargs)
    # Anything else (e.g. a single tensor) is passed as one positional argument.
    return model(data, **extra_kwargs)
|
||
|
||
def retain_grad(x: Any) -> None:
    """Call retain_grad() on a tensor; silently ignore non-tensors.

    Args:
        x (Any): Object to be called.
    """
    if not isinstance(x, torch.Tensor):
        return
    x.retain_grad()
|
||
|
||
def detach(x: Any) -> Any:
    """Call detach() on a tensor; return non-tensors unchanged.

    Args:
        x (Any): Object to be called.

    Returns:
        Any: The detached object.
    """
    return x.detach() if isinstance(x, torch.Tensor) else x
|
||
|
||
def merge_batch(data: List[Any]) -> Any:
    """Merge micro batches back into one batch.

    Tensor leaves are concatenated along dimension 0; non-tensor leaves are
    collected into lists. All micro batches are assumed to share the same
    pytree structure (the spec of the last micro batch is used to unflatten).

    Args:
        data (List[Any]): A list of micro batches.

    Returns:
        Any: The merged batch, or None if ``data`` is empty.
    """
    # Idiomatic emptiness check; return None explicitly so the contract is clear.
    if not data:
        return None
    flattened_data = []
    tree_spec = None
    for micro_batch in data:
        # NOTE: tree_spec is overwritten each iteration; correctness relies on
        # every micro batch having the same structure.
        elems, tree_spec = tree_flatten(micro_batch)
        flattened_data.append(elems)
    merged_data = []
    for elem_batch in zip(*flattened_data):
        if isinstance(elem_batch[0], torch.Tensor):
            merged_data.append(torch.cat(elem_batch, dim=0))
        else:
            merged_data.append(list(elem_batch))
    return tree_unflatten(merged_data, tree_spec)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import Any, Callable, Iterable | ||
|
||
from torch import Tensor | ||
from torch.nn import Module | ||
|
||
from colossalai.interface import OptimizerWrapper | ||
from colossalai.pipeline.stage_manager import PipelineStageManager | ||
|
||
|
||
class PipelineSchedule:
    """Base class of pipeline-parallel execution schedules.

    Subclasses implement ``forward_backward_step`` to define how micro batches
    flow through the pipeline stages.
    """

    def __init__(self, stage_manager: PipelineStageManager) -> None:
        # The stage manager tells the schedule which pipeline stage it runs on.
        self.stage_manager = stage_manager

    def forward_backward_step(self,
                              model: Module,
                              optimizer: OptimizerWrapper,
                              data_iter: Iterable,
                              criterion: Callable[[Any, Any], Tensor],
                              return_loss: bool = False,
                              return_outputs: bool = False) -> dict:
        """Forward and backward step for pipeline training.

        Args:
            model (Module): Model to be trained.
            optimizer (OptimizerWrapper): Optimizer to be used.
            data_iter (Iterable): Data iterator yielding batches.
            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
            return_loss (bool, optional): Whether to return loss. Defaults to False.
            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.

        Returns:
            dict: A dict with keys 'loss' and 'outputs'.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
Oops, something went wrong.