Skip to content

Commit

Permalink
Workspace for Inference Kernels (#1)
Browse files Browse the repository at this point in the history
* Add workspace capability to DSKernel

* Add to injection pipeline

* Validated
  • Loading branch information
cmikeh2 authored and JamesTheZ committed Dec 20, 2023
1 parent beed9c2 commit 1528732
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
54 changes: 54 additions & 0 deletions deepspeed/inference/v2/kernels/ds_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
# DeepSpeed Team

from abc import ABC, abstractmethod
from typing import Optional

import torch

from deepspeed.inference.v2.logging import inference_logger


class DSKernelBase(ABC):

_workspace: Optional[torch.Tensor] = None

@abstractmethod
def __init__(self, *args, **kwargs):
"""
Expand All @@ -30,3 +37,50 @@ def __call__(self, *args, **kwargs):
should be performed here.
"""
raise NotImplementedError()

def requested_workspace_size(self) -> int:
"""
Return the requested workspace size in bytes.
This should be overloaded if the kernel requires a workspace.
Returns:
int: Number of bytes necessary.
"""
return 0

def get_workspace(self, bytes: int) -> torch.Tensor:
"""
Return the data pointer to the scratchpad memory.
Args:
bytes (int): Number of bytes necessary.
Raises:
RuntimeError: If the workspace is not allocated.
ValueError: If the workspace is not large enough.
"""
if DSKernelBase._workspace is None:
raise RuntimeError("Workspace not allocated")
if DSKernelBase._workspace.numel() < bytes:
raise ValueError("Workspace too small")
return DSKernelBase._workspace

@staticmethod
def create_workspace(bytes: int) -> int:
"""
Create a workspace of the requested size.
Args:
bytes (int): Number of bytes necessary.
Raises:
RuntimeError: If the workspace is already allocated.
"""
if DSKernelBase._workspace is not None:
raise RuntimeError("Workspace already allocated")

if bytes > 0:
inference_logger().info(f"Allocating {bytes} bytes of workspace")

DSKernelBase._workspace = torch.empty((bytes, ), dtype=torch.uint8)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch

import deepspeed.comm as dist
from deepspeed.inference.v2.kernels.ds_kernel import DSKernelBase
from deepspeed.inference.v2.modules.ds_module import DSModuleBase
from ..ragged import DSStateManager, RaggedBatchWrapper
from ..ragged.manager_configs import KVCacheConfig
from ..ragged import DSSequenceDescriptor
Expand Down Expand Up @@ -266,3 +268,19 @@ def forward(wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
should not rely on the ability to use python control flow.
"""
raise NotImplementedError()

def initialize_kernel_workspace(self) -> None:
"""
Iterates over all kernels in the model and collects the requested workspace size. This
workspace is then allocated and stored in the kernel base class.
"""
max_workspace_size = 0
for module in self.modules():
if not isinstance(module, DSModuleBase):
continue

for kernel in module.kernels():
max_workspace_size = max(max_workspace_size, kernel.requested_workspace_size())

if max_workspace_size > 0:
DSKernelBase.create_workspace(max_workspace_size)
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def build_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any)
DSInferenceModelBase: An implementation of the inference model abstraction that will be
run by the engine.
"""
self.model = self.instantiate_model(engine_config, mp_group)
self.model: DSInferenceModelBase = self.instantiate_model(engine_config, mp_group)
self.model.initialize_kernel_workspace()
self.populate_model_parameters()
return self.model

Expand Down
14 changes: 13 additions & 1 deletion deepspeed/inference/v2/modules/ds_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# DeepSpeed Team

from abc import ABC, abstractstaticmethod
from typing import Any, Dict, Type
from typing import Any, Dict, Iterable, Type

import torch

from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.inference.v2.kernels import DSKernelBase


class DSModuleConfig(DeepSpeedConfigModel):
Expand Down Expand Up @@ -60,3 +61,14 @@ def __init__(self, config: DSModuleConfig, implementation_config: Dict[str, Any]
super().__init__()
self._config = config
self._implementation_config = implementation_config

def kernels(self) -> Iterable[DSKernelBase]:
"""
Return an iterable of all kernels used by this module.
This should be implemented by the children of functionality modules and should report
all kernels used by this module.
"""
for attr in dir(self):
if isinstance(getattr(self, attr), DSKernelBase):
yield getattr(self, attr)

0 comments on commit 1528732

Please sign in to comment.