[Gemini] Implement auto policy prefetch and minor modifications to the original code #5733

Merged: 12 commits, May 20, 2024
colossalai/zero/gemini/gemini_hook.py (8 changes: 7 additions & 1 deletion)

@@ -41,6 +41,7 @@ def pre_op(self, params):
self._gemini_manager.sample_overall_data()

# evict chunks, aware of async fetched
# TODO: check if prefetched chunks will be evicted
self._gemini_manager.adjust_layout(
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
)
@@ -50,7 +51,12 @@ def pre_op(self, params):
self._chunk_manager.access_chunk(chunk)

# get possible chunks to prefetch
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(
is_warmup=self._gemini_manager.is_warmup(),
compute_list=self._gemini_manager.compute_list,
compute_idx=self._gemini_manager.compute_idx,
async_works=self._gemini_manager.async_works,
)

# prefetch
for chunk in chunks_fetch_async:
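Note on the new interface (a minimal sketch follows, using toy stand-in classes rather than the real Chunk/GeminiManager types): the hook now passes the manager's scheduling state (is_warmup, compute_list, compute_idx, async_works) into the placement policy explicitly, instead of the policy reaching back into the manager; the dedup and budget details of the real policies are omitted here.

# Toy sketch of the new get_prefetch_chunks call contract (stand-in classes,
# not the actual ColossalAI types): the policy consumes manager state that is
# passed in, rather than holding a reference back to the Gemini manager.
from typing import Dict, List, Tuple


class ToyChunk:  # stand-in for colossalai's Chunk
    def __init__(self, name: str) -> None:
        self.name = name


class ToyPolicy:
    def __init__(self, max_prefetch: int = 2) -> None:
        self.max_prefetch = max_prefetch

    def get_prefetch_chunks(
        self,
        is_warmup: bool,
        compute_list: List[Tuple[ToyChunk, ...]],
        compute_idx: int,
        async_works: Dict[ToyChunk, object],
    ) -> List[ToyChunk]:
        if is_warmup:  # compute_list is still being recorded during warmup
            return []
        # in-flight async fetches count against the prefetch budget
        budget = max(self.max_prefetch - len(async_works), 0)
        upcoming = [c for group in compute_list[compute_idx + 1 :] for c in group]
        return upcoming[:budget]


# Hook-side usage: the state lives in the manager; the policy only reads it.
policy = ToyPolicy()
chunks = policy.get_prefetch_chunks(
    is_warmup=False,
    compute_list=[(ToyChunk("a"),), (ToyChunk("b"), ToyChunk("c"))],
    compute_idx=0,
    async_works={},
)
print([c.name for c in chunks])  # ['b', 'c']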
colossalai/zero/gemini/gemini_mgr.py (6 changes: 5 additions & 1 deletion)

@@ -45,7 +45,7 @@ def __init__(
self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
self._async_works: Dict[Chunk, dist.work] = {}
self._async_works: Dict[Chunk, dist.Work] = {}

self._h2d_volume = 0
self._d2h_volume = 0
@@ -183,6 +183,10 @@ def compute_list(self) -> List[Tuple[Chunk, ...]]:
def compute_idx(self) -> int:
return self._compute_idx

@property
def async_works(self) -> Dict[Chunk, dist.Work]:
return self._async_works

@property
def placement_policy(self) -> PlacementPolicy:
return self._placement_policy
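Side note on the dist.work to dist.Work fix and the new async_works property: collectives launched with async_op=True return torch.distributed.Work handles (capital W; dist.work is not a valid name), and exposing the dict as a read-only property lets the hook hand it to the policy without touching the private attribute. A standalone sketch of that handle pattern, assuming a single-process gloo group purely for illustration:

# Standalone sketch of the async-handle pattern behind `_async_works`
# (single-process gloo group, string keys instead of Chunk objects).
from typing import Dict

import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

async_works: Dict[str, dist.Work] = {}  # keyed by chunk in the real code

tensor = torch.ones(4)
# async_op=True returns a Work handle instead of blocking
async_works["chunk_0"] = dist.all_reduce(tensor, async_op=True)

# ... overlap other computation here ...

for name, work in list(async_works.items()):
    work.wait()            # block until this chunk's communication finishes
    del async_works[name]  # the chunk is now safe to compute with

dist.destroy_process_group()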
colossalai/zero/gemini/placement_policy.py (57 changes: 27 additions & 30 deletions)

@@ -5,6 +5,7 @@
from typing import Dict, List, Optional, Tuple, Type

import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -19,13 +20,11 @@ class PlacementPolicy(ABC):

def __init__(
self,
gemini_manager: "GeminiManager", # TODO @botbw: solve circular import
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
**kwargs,
) -> None:
self.gemini_manager = gemini_manager
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
self.max_prefetch = max_prefetch
@@ -40,14 +39,15 @@ def setup_grads_device(
) -> None:
raise NotImplementedError

def get_prefetch_chunks(self) -> List[Chunk]:
def get_prefetch_chunks(
self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
return [] # no prefetch by default


class StaticPlacementPolicy(PlacementPolicy):
def __init__(
self,
gemini_manager: "GeminiManager",
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
@@ -56,9 +56,7 @@ def __init__(
offload_param_frac: float = 0.0,
**kwargs,
) -> None:
super().__init__(
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
)
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
offload_param_frac = 0.0
@@ -109,21 +107,22 @@ def setup_grads_device(
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

def get_prefetch_chunks(self) -> List[Chunk]:
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
break_flag = False
for chunk in self.gemini_manager.compute_list[i]:
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch:
break_flag = True
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch.append(chunk)
if break_flag:
break
else:
continue
break
return prefetch


@@ -132,17 +131,14 @@ class AutoPlacementPolicy(PlacementPolicy):

def __init__(
self,
gemini_manager: "GeminiManager",
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
max_prefetch: int = 0,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs,
) -> None:
super().__init__(
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
)
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -233,8 +229,10 @@ def setup_grads_device(
else:
grads_device_map[p] = torch.device("cpu")

def get_prefetch_chunks(self) -> List[Chunk]:
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
def get_prefetch_chunks(
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
) -> List[Chunk]:
if is_warmup: # no prefetch during warmup since we need compute_list
return []
# modified from self.evict_tensors
cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
@@ -246,19 +244,18 @@ def get_prefetch_chunks(self) -> List[Chunk]:
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

prefetch_chunk_memory = 0
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
can_prefetch = self.max_prefetch - len(async_works)
prefetch = []
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
break_flag = False
for chunk in self.gemini_manager.compute_list[i]:
chunk: Chunk
for i in range(compute_idx + 1, len(compute_list)):
for chunk in compute_list[i]:
if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
break_flag = True
break
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
prefetch_chunk_memory += chunk.chunk_mem
prefetch.append(chunk)
if break_flag:
break
else:
continue
break
return prefetch


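About the loop restructuring in both get_prefetch_chunks implementations above: the break_flag bookkeeping is replaced by Python's for/else idiom, where the inner loop's else: continue runs only when that loop finishes without breaking, so the trailing break exits the outer loop exactly when the inner loop stopped early (budget exhausted, or in the Auto policy, the CUDA memory cap reached). A self-contained toy illustration:

# Toy illustration of the for/else idiom used in the new get_prefetch_chunks:
# `else: continue` runs only if the inner loop did NOT break, so the outer
# `break` fires exactly when the inner loop stopped early.
compute_list = [("a", "b"), ("c", "d", "e"), ("f",)]
already_fetched = {"a"}   # stands in for chunk_manager.accessed_chunks
can_prefetch = 3          # stands in for max_prefetch - len(async_works)

prefetch = []
for group in compute_list:
    for chunk in group:
        if len(prefetch) >= can_prefetch:
            break  # budget exhausted -> leave the inner loop
        if chunk not in prefetch and chunk not in already_fetched:
            prefetch.append(chunk)
    else:
        continue  # inner loop ran to completion: move on to the next group
    break  # inner loop broke early: stop scanning the compute list

print(prefetch)  # ['b', 'c', 'd']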
examples/language/gpt/gemini/run_gemini.sh (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
export GPUNUM=${GPUNUM:-1}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10}
export TRAIN_STEP=${TRAIN_STEP:-2}
# export PYTHONPATH=$PWD:$PYTHONPATH


examples/language/gpt/gemini/train_gpt_demo.py (5 changes: 3 additions & 2 deletions)

@@ -66,18 +66,19 @@ def forward(self, logits, labels):


def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
return psutil.Process().memory_info().rss / 1024**2 # MB unit


def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
return torch.cuda.memory_allocated() / 1024**2 # MB unit


def get_mem_info(prefix=""):
return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"


def get_model_size(model: nn.Module):
# get the number of parameters of the model
total_numel = 0
for module in model.modules():
for p in module.parameters(recurse=False):
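For context, a small usage sketch of the MB-unit helpers above (run on a CUDA machine; the printed numbers are illustrative only):

# Quick sanity check of the helpers defined above (toy model; the actual
# numbers depend on the environment).
import torch.nn as nn

model = nn.Linear(1024, 1024)
print(f"param count: {get_model_size(model)}")  # 1024*1024 + 1024 = 1049600
print(get_mem_info(prefix="[toy model] "))      # e.g. "[toy model] GPU memory usage: 0.00 MB, CPU memory usage: 150.00 MB"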