Add Inference test for llama (#4508)
* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix:rm some unused codes

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
3 people authored Aug 30, 2023
1 parent 35af65d commit f0aab7f
Showing 10 changed files with 442 additions and 138 deletions.
6 changes: 4 additions & 2 deletions colossalai/inference/tensor_parallel/__init__.py
@@ -1,4 +1,6 @@
+from .modeling.llama import LlamaInferenceForwards
+from .pollcies.llama import LlamaModelInferPolicy
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager

-__all__ = ['MemoryManager', 'TPInferEngine']
+__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']
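
With the updated __all__, the new Llama inference classes are importable straight from the package root. A minimal import sketch (the package path comes from the diff above; the one-line descriptions are taken from the commit message and are otherwise assumptions):

    # Names re-exported by colossalai/inference/tensor_parallel/__init__.py after this commit.
    from colossalai.inference.tensor_parallel import (
        LlamaInferenceForwards,    # inference-oriented Llama forward implementations
        LlamaModelInferPolicy,     # ShardFormer policy used for inference-only sharding
        MemoryManager,             # KV-cache memory manager
        TPInferEngine,             # tensor-parallel inference engine
    )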
22 changes: 11 additions & 11 deletions colossalai/inference/tensor_parallel/engine.py
@@ -16,7 +16,7 @@

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

-_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM']
+_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']


class TPInferEngine:
@@ -27,7 +27,7 @@ def __init__(self,
max_input_len: int,
max_output_len: int,
dtype: torch.dtype = torch.float16,
-device: torch.device = torch.cuda.current_device()) -> None:
+device: str = 'cuda') -> None:
self.model = model
self.sharded_model = None

@@ -40,7 +40,7 @@ def __init__(self,
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"

-self.device = device
+torch.device(device=device)
self.dtype = dtype

self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
@@ -88,7 +88,7 @@ def shard_model_by(self, shardformer: ShardFormer) -> None:
assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(self.model, inference_only=True)
self.sharded_model, _ = shardformer.optimize(self.model, policy)
-self.sharded_model = self.sharded_model.to(self.device)
+self.sharded_model = self.sharded_model.cuda()

@staticmethod
def _supported_models() -> List[str]:
@@ -137,7 +137,7 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te
input_tokens = dict(input_ids=input_tokens)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
-input_tokens[t] = input_tokens[t].to(self.device)
+input_tokens[t] = input_tokens[t].cuda()

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

@@ -173,8 +173,8 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
else:
batch_size = inputs.shape[0]

-seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
-seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
+seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
start_index = 0

max_len_in_batch = -1
@@ -197,10 +197,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState:

block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
dtype=torch.long,
-device=self.device)
+device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
-batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device
-batch_infer_state.start_loc = seq_start_indexes.to(self.device)
+batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
+batch_infer_state.start_loc = seq_start_indexes.to('cuda')
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
@@ -251,4 +251,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
# => put information already recorded in batchinferstate and pass it to model forward
# => clear records in engine
def add_request():
-raise NotImplementedError()
+raise NotImplementedError()
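
A minimal usage sketch of the TPInferEngine API visible in the hunks above. The constructor keyword names before max_input_len, the toy model configuration, and the generation kwargs are assumptions for illustration, not part of this commit:

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    from colossalai.inference.tensor_parallel import TPInferEngine

    # A tiny randomly initialized Llama stands in for a real checkpoint.
    model = LlamaForCausalLM(LlamaConfig(hidden_size=256, num_hidden_layers=2,
                                         num_attention_heads=4, intermediate_size=512))

    engine = TPInferEngine(model,
                           max_batch_size=4,      # asserted <= 64 in __init__
                           max_input_len=512,
                           max_output_len=256,    # input + output asserted <= 2048
                           dtype=torch.float16,
                           device='cuda')         # a plain string after this diff

    # engine.shard_model_by(shardformer)          # shardformer: a ShardFormer configured for TP
    input_ids = torch.randint(0, model.config.vocab_size, (4, 64), dtype=torch.long)
    # outputs = engine.generate_by_set_infer_state(input_ids,
    #                                              dict(max_new_tokens=64, do_sample=False))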
3 changes: 3 additions & 0 deletions colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -0,0 +1,3 @@
+from .llama import LlamaInferenceForwards
+
+__all__ = ['LlamaInferenceForwards']
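
For context, shard_model_by() in the engine above resolves its policy via get_autopolicy(model, inference_only=True). A hedged sketch of that relationship; the import path for get_autopolicy and the expected policy type are assumptions based on this commit's changes:

    from transformers import LlamaConfig, LlamaForCausalLM

    from colossalai.inference.tensor_parallel import LlamaModelInferPolicy
    from colossalai.shardformer.policies.auto_policy import get_autopolicy

    model = LlamaForCausalLM(LlamaConfig(hidden_size=256, num_hidden_layers=2,
                                         num_attention_heads=4, intermediate_size=512))

    # With inference_only=True (the flag TPInferEngine.shard_model_by passes),
    # the auto policy is expected to return the Llama inference policy added here.
    policy = get_autopolicy(model, inference_only=True)
    print(type(policy))    # expected: LlamaModelInferPolicy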