Commit
feat(accelerator): support accelerator
1 parent fbff756 · commit 4e4c34e

Showing 35 changed files with 691 additions and 127 deletions.
@@ -0,0 +1,9 @@
from .abstract_accelerator import get_accelerator

get_accelerator()
from .abstract_accelerator import internlm_accelerator

__all__ = [
    "internlm_accelerator",
    "get_accelerator",
]
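The package `__init__` deliberately calls `get_accelerator()` before re-exporting `internlm_accelerator`, so the module-level singleton is already instantiated when the name is bound here. A minimal usage sketch, assuming the package is importable as `internlm.accelerator` (the exact package path is not shown in this diff):

```python
# Hypothetical usage sketch: the `internlm.accelerator` import path is an
# assumption; only the module names appear in this diff.
from internlm.accelerator import get_accelerator

accelerator = get_accelerator()           # returns the process-wide singleton
print(accelerator.device_count())         # number of visible CUDA/NPU devices
print(accelerator.current_device_name())  # e.g. "cuda:0"
```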
@@ -0,0 +1,79 @@
"""
Universal accelerator interface implementation, inspired by DeepSpeed.
"""
import os

internlm_accelerator = None


class Accelerator:
    def __init__(self) -> None:
        pass

    # Device APIs
    def device_name(self, device_index=None):
        raise NotImplementedError

    def device(self, device_index=None):
        raise NotImplementedError

    def set_device(self, device_index):
        raise NotImplementedError

    def current_device(self):
        raise NotImplementedError

    def current_device_name(self):
        raise NotImplementedError

    def device_count(self):
        raise NotImplementedError

    def synchronize(self, device_index=None):
        raise NotImplementedError


def get_accelerator():
    global internlm_accelerator
    if internlm_accelerator is not None:
        return internlm_accelerator

    accelerator_name = None
    # 1. Detect whether the accelerator is overridden via the INTERNLM_ACCELERATOR environment variable.
    intern_accelerator_LIST = ["cuda", "npu"]
    if "INTERNLM_ACCELERATOR" in os.environ:
        accelerator_name = os.environ["INTERNLM_ACCELERATOR"]
        if accelerator_name == "npu":
            try:
                import torch_npu  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.") from e
        elif accelerator_name != "cuda":
            raise ValueError(
                f'internlm_accelerator must be one of {intern_accelerator_LIST}. Value "{accelerator_name}" is not supported'
            )

    # 2. If no override, detect which accelerator to use automatically.
    if accelerator_name is None:
        try:
            import torch_npu  # noqa: F401,F811 # type: ignore

            accelerator_name = "npu"
        except ImportError:
            pass
    if accelerator_name is None:
        accelerator_name = "cuda"

    # 3. Set internlm_accelerator accordingly.
    if accelerator_name == "cuda":
        from .cuda_accelerator import CUDA_Accelerator

        internlm_accelerator = CUDA_Accelerator()
    elif accelerator_name == "npu":
        from .npu_accelerator import ASCEND_Accelerator

        internlm_accelerator = ASCEND_Accelerator()

    return internlm_accelerator
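`get_accelerator()` is a lazy singleton: the first call picks a backend (explicit `INTERNLM_ACCELERATOR` override, otherwise `torch_npu` auto-detection, otherwise CUDA) and every later call returns the same object. A small sketch of forcing the CUDA backend, again under the assumed `internlm.accelerator` package path:

```python
import os

# Hypothetical selection sketch; the import path is an assumption.
# The override must be set before the first get_accelerator() call,
# which happens inside the package __init__ on import.
os.environ["INTERNLM_ACCELERATOR"] = "cuda"

from internlm.accelerator import get_accelerator

acc = get_accelerator()
assert acc is get_accelerator()           # same singleton on every call
print(acc.communication_backend_name())   # "nccl" for the CUDA backend
```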
@@ -0,0 +1,204 @@
from .abstract_accelerator import Accelerator

try:
    import torch.cuda
except ImportError:
    pass


class CUDA_Accelerator(Accelerator):
    def __init__(self) -> None:
        self._name = "cuda"
        self._communication_backend_name = "nccl"

    # Device APIs
    def device_name(self, device_index=None):
        if device_index is None:
            return "cuda"
        return "cuda:{}".format(device_index)

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return "cuda:{}".format(torch.cuda.current_device())

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)

    # RNG APIs
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self, seed):
        return torch.cuda.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    @property
    def Stream(self):
        return torch.cuda.Stream

    def stream(self, stream):
        return torch.cuda.stream(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    @property
    def Event(self):
        return torch.cuda.Event

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def memory_cached(self, device_index=None):
        return torch.cuda.memory_cached(device_index)

    def max_memory_cached(self, device_index=None):
        return torch.cuda.max_memory_cached(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, "memory_stats"):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, "reset_peak_memory_stats"):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, "memory_reserved"):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, "max_memory_reserved"):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    # Data types
    def is_bf16_supported(self):
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        major, _ = torch.cuda.get_device_capability()
        if major >= 7:
            return True
        else:
            return False

    # Misc
    def amp(self):
        if hasattr(torch.cuda, "amp"):
            return torch.cuda.amp
        return None

    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, "range_push"):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self):
        if hasattr(torch.cuda.nvtx, "range_pop"):
            return torch.cuda.nvtx.range_pop()

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def communication_backend_name(self):
        return self._communication_backend_name

    # Tensor operations
    @property
    def BFloat16Tensor(self):
        return torch.cuda.BFloat16Tensor

    @property
    def ByteTensor(self):
        return torch.cuda.ByteTensor

    @property
    def DoubleTensor(self):
        return torch.cuda.DoubleTensor

    @property
    def FloatTensor(self):
        return torch.cuda.FloatTensor

    @property
    def HalfTensor(self):
        return torch.cuda.HalfTensor

    @property
    def IntTensor(self):
        return torch.cuda.IntTensor

    @property
    def LongTensor(self):
        return torch.cuda.LongTensor

    def pin_memory(self, tensor):
        return tensor.pin_memory()

    def on_accelerator(self, tensor):
        device_str = str(tensor.device)
        if device_str.startswith("cuda:"):
            return True
        else:
            return False

    def set_allow_tf32(self, enable: bool):
        torch.backends.cudnn.allow_tf32 = enable
        torch.backends.cuda.matmul.allow_tf32 = enable

    def return_custom_bwd(self):
        return torch.cuda.amp.custom_bwd
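`CUDA_Accelerator` is a thin pass-through to `torch.cuda`, so training code can stay device-agnostic by going through the accelerator object instead of calling `torch.cuda` directly. A hypothetical sketch, under the same assumed package path and assuming a CUDA device is available:

```python
import torch

# Hypothetical device-agnostic sketch; `internlm.accelerator` is an assumed path
# and this only runs on a machine with at least one CUDA device.
from internlm.accelerator import get_accelerator

acc = get_accelerator()
acc.set_device(0)                                        # torch.cuda.set_device(0) on the CUDA backend
x = torch.ones(1024, device=acc.current_device_name())   # e.g. "cuda:0"
acc.synchronize()                                        # wait for queued kernels
print(acc.memory_allocated() / 1024**2, "MiB allocated")
```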