Commit e7ca2f8
add annotation for topology, backend, optimizers and utility
cccvs committed Apr 12, 2024
1 parent c05b24f commit e7ca2f8
Showing 6 changed files with 167 additions and 137 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -71,6 +71,7 @@ jobs:
- name: Install BlueFogLite
run: |
conda install pip
python --version
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install .
@@ -100,6 +101,7 @@ jobs:
- name: Install BlueFogLite
run: |
conda install pip
python --version
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
python -m pip install .
190 changes: 97 additions & 93 deletions bluefoglite/common/optimizers.py
@@ -1,10 +1,25 @@
from enum import Enum
from collections import Counter
import itertools
import warnings
from enum import Enum
from collections import Counter
from contextlib import contextmanager
from typing import (
Any,
Iterator,
Optional,
Tuple,
Union,
List,
Callable,
Iterable,
Dict,
Set,
)

import torch
from torch.nn import Module, Parameter
from torch.optim import Optimizer

import bluefoglite.torch_api as bfl


@@ -14,26 +29,9 @@ class CommunicationType(Enum):
empty = "empty"


_warning_message_num_step_per_communication = (
"Unexpected behavior:\n"
" After num_steps_per_communication times of forward computation `y=model(x)` are called,\n"
" an optimizer step() function must be called.\n"
" It does not matter how many step() functions are called in between.\n"
" Please adjust num_step_per_communication to update model parameters locally.\n"
" More information can be found in the FAQ page.\n"
)
_warning_message_backward_pass_per_step = (
"Unexpected behavior:\n"
" After num_steps_per_communication times of backward"
" computation `loss.backward()` are called,\n"
" an optimizer step() function must be called.\n"
" It does not matter how many step() functions are called in between.\n"
" Please adjust num_steps_per_communication to accumulate gradients locally.\n"
" More information can be found in the FAQ page.\n"
)
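
A minimal sketch of the training-loop contract these warning messages describe, assuming an optimizer constructed with num_steps_per_communication=2; the model, criterion, and data loader below are hypothetical placeholders, not part of this commit.

    # Two forward passes per optimizer.step(), matching num_steps_per_communication=2.
    for x1, x2, y1, y2 in loader:               # hypothetical paired mini-batches
        loss = criterion(model(x1), y1)         # forward pass 1
        loss = loss + criterion(model(x2), y2)  # forward pass 2
        loss.backward()
        optimizer.step()                        # exactly one step() after the two forwards
        optimizer.zero_grad()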


def _named_leaf_module(module, parent_name=None):
def _named_leaf_module(
module: Module, parent_name: Optional[str] = None
) -> Iterator[Tuple[Optional[str], Module]]:
"""Yield an iterator over all leaf modules."""
if not list(module.named_children()):
yield (parent_name, module)
@@ -42,13 +40,15 @@ def _named_leaf_module(module, parent_name=None):
yield from _named_leaf_module(ch_module, full_name)
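
A hypothetical illustration of _named_leaf_module: only leaf modules (those with no children) are yielded, each paired with a qualified name whose exact format depends on the full_name construction elided above.

    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    # The Sequential container itself is not yielded, only its three leaves.
    leaves = [type(m).__name__ for _, m in _named_leaf_module(net)]
    print(leaves)  # expected: ['Linear', 'ReLU', 'Linear']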


def _check_named_parameters(optimizer, model):
def _check_named_parameters(
optimizer: Optimizer, model: Union[Module, List[Module]]
) -> Tuple[list, list]:
_models = None
if isinstance(model, torch.nn.Module):
if isinstance(model, Module):
_models = [model]
if isinstance(model, list):
for m in model:
assert isinstance(m, torch.nn.Module)
assert isinstance(m, Module)
_models = model
assert _models is not None
named_parameters = list(itertools.chain(*[m.named_parameters() for m in _models]))
@@ -74,7 +74,7 @@ def _check_named_parameters(optimizer, model):
all_param_ids = {
id(v) for param_group in optimizer.param_groups for v in param_group["params"]
}
named_param_ids = {id(v) for k, v in named_parameters}
named_param_ids = {id(v) for _, v in named_parameters}
unnamed_param_ids = all_param_ids - named_param_ids
if unnamed_param_ids:
raise ValueError(
@@ -88,46 +88,53 @@
# pylint: disable=too-many-instance-attributes
class _DistributedReduceOptimizer(torch.optim.Optimizer):
def __init__(
self, params, model, communication_type, num_steps_per_communication=1
):
self,
params: Iterable[Parameter],
model: Union[Module, List[Module]],
communication_type: CommunicationType,
num_steps_per_communication: int = 1,
) -> None:
# pylint: disable=bad-super-call, no-value-for-parameter
super(self.__class__, self).__init__(params)
super(self.__class__, self).__init__(params) # type: ignore

named_parameters, models = _check_named_parameters(self, model)
# knobs for neighbor communication behavior
self.self_weight = None
self.src_weights = None
self.dst_weights = None
self.src_machine_weights = None
self.dst_machine_weights = None
self.enable_topo_check = False

self._models = models
self._parameter_names = {v: k for k, v in sorted(named_parameters)}
self._name_parameters = dict(sorted(named_parameters))
self._async_works = {}
self._requires_update = set()
self._synchronized = False
self._should_synchronize = True
self._error_encountered = False
self._num_steps_per_communication = num_steps_per_communication
self.self_weight: Optional[float] = None
self.src_weights: Optional[Dict[int, float]] = None
self.dst_weights: Optional[Dict[int, float]] = None
self.src_machine_weights: Optional[Dict[int, float]] = None
self.dst_machine_weights: Optional[Dict[int, float]] = None
self.enable_topo_check: bool = False

self._models: List[Module] = models
self._parameter_names: Dict[Parameter, str] = {
v: k for k, v in sorted(named_parameters)
}
self._name_parameters: Dict[str, Parameter] = dict(sorted(named_parameters))
self._async_works: Dict[Parameter, bfl.AsyncWork] = {}
self._requires_update: Set[Parameter] = set()
self._synchronized: bool = False
self._should_synchronize: bool = True
self._error_encountered: bool = False
self._num_steps_per_communication: int = num_steps_per_communication
assert isinstance(communication_type, CommunicationType)
self._communication_type = communication_type
self._communication_type: CommunicationType = communication_type

if bfl.size() > 1:
self._register_hooks()

def _register_hooks(self):
def _register_hooks(self) -> None:
for model in self._models:
# The hook is added at model level instead of layer level, as it avoids triggering
# the hook function of the same layer multiple times in case the layer is called
# several times during the forward computation of the model.
model.register_forward_hook(self._make_hook())
self._requires_update.update(dict(model.named_parameters()).values())
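
A small self-contained sketch of why the hook is registered at model level, as the comment above explains: a layer-level forward hook fires once per call of that layer, while a model-level hook fires once per forward pass even when a layer is reused. Only standard PyTorch hooks are used here; the module is hypothetical.

    import torch
    import torch.nn as nn

    class Reuse(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)

        def forward(self, x):
            return self.fc(self.fc(x))  # the same layer runs twice per forward

    calls = []
    m = Reuse()
    m.fc.register_forward_hook(lambda mod, inp, out: calls.append("layer"))
    m.register_forward_hook(lambda mod, inp, out: calls.append("model"))
    m(torch.randn(1, 4))
    print(calls)  # ['layer', 'layer', 'model']: the model-level hook fires only once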

def _make_hook(self):
def hook(model, *unused):
def _make_hook(self) -> Any:
def hook(model: Module, *unused: Tuple[Any, ...]):
for parent_name, layer in _named_leaf_module(model):
assert parent_name is not None
for name, p in layer.named_parameters():
if not layer.training:
continue
@@ -153,7 +160,7 @@ def hook(model, *unused):

return hook

def _neighbor_allreduce_data_async(self, p):
def _neighbor_allreduce_data_async(self, p: Parameter) -> bfl.AsyncWork:
async_work = bfl.neighbor_allreduce_nonblocking(
p.data,
self_weight=self.self_weight,
@@ -163,20 +170,20 @@ def hook(model, *unused):
)
return async_work

def _allreduce_data_async(self, p):
def _allreduce_data_async(self, p: Parameter) -> bfl.AsyncWork:
async_work = bfl.allreduce_nonblocking(p.data, inplace=True)
return async_work
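
A hedged sketch of the non-blocking pattern both helpers above rely on; it assumes only what is visible in this diff, namely that bfl.allreduce_nonblocking(..., inplace=True) returns a bfl.AsyncWork whose wait() blocks until the result is ready.

    t = torch.ones(3)
    work = bfl.allreduce_nonblocking(t, inplace=True)  # returns immediately with an AsyncWork
    # ...other computation can overlap with the communication here...
    work.wait()                                        # t now holds the allreduced values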

@property
def communication_type(self):
def communication_type(self) -> CommunicationType:
return self._communication_type

@communication_type.setter
def communication_type(self, value):
def communication_type(self, value: CommunicationType) -> None:
assert isinstance(value, CommunicationType)
self._communication_type = value

def synchronize(self):
def synchronize(self) -> None:
with torch.no_grad():
for _, async_work in self._async_works.items():
if async_work is not None:
@@ -185,7 +192,7 @@ def synchronize(self):
self._synchronized = True

@contextmanager
def skip_synchronize(self):
def skip_synchronize(self) -> Iterator[None]:
"""
A context manager used to specify that optimizer.step() should
not perform synchronization.
@@ -224,25 +231,32 @@ def step(self, closure=None):

# pylint: disable=too-many-instance-attributes
class _DistributedOptimizer(torch.optim.Optimizer):
def __init__(self, params, model, backward_passes_per_step=1):
def __init__(
self,
params: Iterable[Parameter],
model: Union[Module, List[Module]],
backward_passes_per_step: int = 1,
) -> None:
# pylint: disable=bad-super-call, no-value-for-parameter
super(self.__class__, self).__init__(params)
super(self.__class__, self).__init__(params) # type: ignore

named_parameters, models = _check_named_parameters(self, model)
self._models = models
self._parameter_names = {v: k for k, v in sorted(named_parameters)}
self._async_works = {}
self._grad_accs = []
self._requires_update = set()
self._synchronized = False
self._should_synchronize = True
self._backward_passes_per_step = backward_passes_per_step
self._error_encountered = False
self._models: List[Module] = models
self._parameter_names: Dict[Parameter, str] = {
v: k for k, v in sorted(named_parameters)
}
self._async_works: Dict[Parameter, bfl.AsyncWork] = {}
self._grad_accs: List[Any] = []
self._requires_update: Set[Parameter] = set()
self._synchronized: bool = False
self._should_synchronize: bool = True
self._backward_passes_per_step: int = backward_passes_per_step
self._error_encountered: bool = False

if bfl.size() > 1:
self._register_hooks()

def _register_hooks(self):
def _register_hooks(self) -> None:
for param_group in self.param_groups:
for p in param_group["params"]:
if p.requires_grad:
@@ -253,39 +267,29 @@ def _register_hooks(self):
grad_acc.register_hook(self._make_hook(p))
self._grad_accs.append(grad_acc)

def _make_hook(self, p):
def hook(*ignore):
def _make_hook(self, p: Parameter) -> Callable[..., None]:
def hook(*ignore: Any) -> None:
assert p.grad is not None
assert not p.grad.requires_grad
async_work = self._allreduce_grad_async(p)
self._async_works[p] = async_work

return hook

def _allreduce_grad_async(self, p):
def _allreduce_grad_async(self, p: Parameter) -> bfl.AsyncWork:
assert p.grad is not None
async_work = bfl.allreduce_nonblocking(p.grad, inplace=True)
return async_work

def synchronize(self):
def synchronize(self) -> None:
with torch.no_grad():
for _, async_work in self._async_works.items():
async_work.wait()
self._async_works.clear()
self._synchronized = True

@contextmanager
def skip_synchronize(self):
"""
A context manager used to specify that optimizer.step() should
not perform synchronization.
It's typically used in a following pattern:
.. code-block:: python
optimizer.synchronize()
with optimizer.skip_synchronize():
optimizer.step()
"""
def skip_synchronize(self) -> Iterator[None]:
self._should_synchronize = False
try:
yield
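
A hedged usage sketch of the synchronize()/skip_synchronize() pattern spelled out in the docstring above; optimizer is assumed to be one of the distributed optimizers defined in this file, and model, criterion, x, and target are placeholders.

    loss = criterion(model(x), target)
    loss.backward()
    optimizer.synchronize()             # wait for the pending non-blocking communication
    with optimizer.skip_synchronize():
        optimizer.step()                # apply the update without synchronizing a second time
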
Expand All @@ -310,11 +314,11 @@ def step(self, closure=None):


def DistributedAdaptWithCombineOptimizer(
optimizer,
model,
communication_type=CommunicationType.neighbor_allreduce,
num_steps_per_communication=1,
):
optimizer: Optimizer,
model: Union[Module, List[Module]],
communication_type: CommunicationType = CommunicationType.neighbor_allreduce,
num_steps_per_communication: int = 1,
) -> Any:
cls = type(
optimizer.__class__.__name__,
(optimizer.__class__,),
@@ -326,10 +330,10 @@ def DistributedAdaptWithCombineOptimizer(


def DistributedGradientAllreduceOptimizer(
optimizer,
model,
num_steps_per_communication=1,
):
optimizer: Optimizer,
model: Union[Module, List[Module]],
num_steps_per_communication: int = 1,
) -> Any:
cls = type(
optimizer.__class__.__name__,
(optimizer.__class__,),
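
A hypothetical usage sketch for the two factory functions above; the argument names follow the signatures shown in this diff, while the model and learning rate are placeholders.

    base = torch.optim.SGD(model.parameters(), lr=0.01)
    opt = DistributedAdaptWithCombineOptimizer(
        base, model, communication_type=CommunicationType.neighbor_allreduce
    )
    # or, to allreduce gradients instead of combining parameters with neighbors:
    opt = DistributedGradientAllreduceOptimizer(
        torch.optim.SGD(model.parameters(), lr=0.01), model
    )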
(Diffs for the remaining 4 changed files are not shown.)
