
Commit

feat(accelerator): support accelerator
SolenoidWGT committed Jan 30, 2024
1 parent fbff756 commit 4e4c34e
Showing 35 changed files with 691 additions and 127 deletions.
1 change: 1 addition & 0 deletions internlm/__init__.py
@@ -1,3 +1,4 @@
from .accelerator import get_accelerator # Trigger accelerator initialization
from .initialize.initialize_trainer import initialize_trainer
from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch

9 changes: 9 additions & 0 deletions internlm/accelerator/__init__.py
@@ -0,0 +1,9 @@
from .abstract_accelerator import get_accelerator

# Instantiate the global accelerator before re-exporting it, so that
# ``internlm_accelerator`` below is already populated at import time.
get_accelerator()
from .abstract_accelerator import internlm_accelerator  # noqa: E402

__all__ = [
    "internlm_accelerator",
    "get_accelerator",
]
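For orientation, a minimal usage sketch of the package-level API added here (illustrative, not part of the commit): callers import get_accelerator from internlm.accelerator and receive the process-wide singleton.

from internlm.accelerator import get_accelerator

accelerator = get_accelerator()   # returns the cached global accelerator
print(accelerator.device_name())  # "cuda" when the CUDA backend is selected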
79 changes: 79 additions & 0 deletions internlm/accelerator/abstract_accelerator.py
@@ -0,0 +1,79 @@
"""
Universal accelerator interface implementation, inspired by DeepSpeed.
"""
import os

internlm_accelerator = None


class Accelerator:
    def __init__(self) -> None:
        pass

    # Device APIs
    def device_name(self, device_index=None):
        raise NotImplementedError

    def device(self, device_index=None):
        raise NotImplementedError

    def set_device(self, device_index):
        raise NotImplementedError

    def current_device(self):
        raise NotImplementedError

    def current_device_name(self):
        raise NotImplementedError

    def device_count(self):
        raise NotImplementedError

    def synchronize(self, device_index=None):
        raise NotImplementedError


def get_accelerator():
    global internlm_accelerator
    if internlm_accelerator is not None:
        return internlm_accelerator

    accelerator_name = None
    # 1. Detect whether the accelerator is overridden via the INTERNLM_ACCELERATOR environment variable.
    supported_accelerators = ["cuda", "npu"]
    if "INTERNLM_ACCELERATOR" in os.environ:
        accelerator_name = os.environ["INTERNLM_ACCELERATOR"]
        if accelerator_name == "npu":
            try:
                import torch_npu  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.") from e
        elif accelerator_name not in supported_accelerators:
            raise ValueError(
                f'internlm_accelerator must be one of {supported_accelerators}. Value "{accelerator_name}" is not supported'
            )

    # 2. If no override, detect which accelerator to use automatically.
    if accelerator_name is None:
        try:
            import torch_npu  # noqa: F401,F811 # type: ignore

            accelerator_name = "npu"
        except ImportError:
            pass
    if accelerator_name is None:
        accelerator_name = "cuda"

    # 3. Set internlm_accelerator accordingly.
    if accelerator_name == "cuda":
        from .cuda_accelerator import CUDA_Accelerator

        internlm_accelerator = CUDA_Accelerator()
    elif accelerator_name == "npu":
        from .npu_accelerator import ASCEND_Accelerator

        internlm_accelerator = ASCEND_Accelerator()

    return internlm_accelerator
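A hedged sketch of how the selection logic above is exercised (assumed usage, not code from this diff): the backend can be forced through the INTERNLM_ACCELERATOR environment variable before the first import, otherwise the presence of torch_npu decides and CUDA is the fallback.

import os

# Must be set before internlm is first imported, since the singleton is cached.
os.environ["INTERNLM_ACCELERATOR"] = "cuda"  # or "npu" when torch_npu is installed

from internlm.accelerator import get_accelerator

acc = get_accelerator()
acc.set_device(0)                 # assumes a visible device 0
print(acc.current_device_name())  # e.g. "cuda:0"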
204 changes: 204 additions & 0 deletions internlm/accelerator/cuda_accelerator.py
@@ -0,0 +1,204 @@
from .abstract_accelerator import Accelerator

try:
    import torch.cuda
except ImportError:
    pass


class CUDA_Accelerator(Accelerator):
    def __init__(self) -> None:
        self._name = "cuda"
        self._communication_backend_name = "nccl"

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return "cuda"
        return "cuda:{}".format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return "cuda:{}".format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self):
        # torch.cuda.initial_seed() takes no arguments; it returns the initial
        # seed of the current device's default generator.
        return torch.cuda.initial_seed()

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, "memory_stats"):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, "reset_peak_memory_stats"):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, "memory_reserved"):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, "max_memory_reserved"):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    # Data types
    def is_bf16_supported(self):
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        major, _ = torch.cuda.get_device_capability()
        return major >= 7

    # Misc
    def amp(self):
        if hasattr(torch.cuda, "amp"):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, "range_push"):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, "range_pop"):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.cuda.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.cuda.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.cuda.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.cuda.FloatTensor

    @property
    def HalfTensor(self):
        return torch.cuda.HalfTensor

    @property
    def IntTensor(self):
        return torch.cuda.IntTensor

    @property
    def LongTensor(self):
        return torch.cuda.LongTensor

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def on_accelerator(self, tensor):
        return str(tensor.device).startswith("cuda:")

    def set_allow_tf32(self, enable: bool):
        torch.backends.cudnn.allow_tf32 = enable
        torch.backends.cuda.matmul.allow_tf32 = enable

    def return_custom_bwd(self):
        return torch.cuda.amp.custom_bwd
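To illustrate the pattern this interface enables (a sketch under the assumption that call sites swap direct torch.cuda calls for the accelerator object), device-specific operations go through the accelerator so the same code path can later serve the NPU backend.

import torch

from internlm.accelerator import get_accelerator

acc = get_accelerator()
if acc.is_available():
    acc.set_device(0)
    x = torch.ones(1024, device=acc.current_device_name())  # backend-specific device string
    acc.synchronize()
    print(acc.memory_allocated() // 1024, "KiB allocated on", acc.device_name(0))
    acc.empty_cache()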