Commit 64c1f4f

refactor the code
1 parent bd00085 commit 64c1f4f

9 files changed: +103 -277 lines changed


colossalai/inference/pipeline/README.md (+1 -1)

@@ -47,7 +47,7 @@ inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInfer
 
 input = ["Introduce a landmark in China ","Introduce a landmark in China "]
 data = tokenizer(input, return_tensors='pt')
-output = inferengine.inference([data.to('cuda').data])
+output = inferengine.inference(data.to('cuda'))
 
 
 ```
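
With this change, the README example passes the tokenizer output to `inference` directly instead of wrapping its `.data` dict in a list. A minimal usage sketch under that assumption, with an already constructed `PPInferEngine` named `inferengine` and a placeholder model path:

    from transformers import LlamaTokenizer

    # Placeholder path; substitute your own checkpoint directory.
    tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
    data = tokenizer(["Introduce a landmark in China "], return_tensors="pt")

    # Before this commit: output = inferengine.inference([data.to("cuda").data])
    output = inferengine.inference(data.to("cuda"))  # a BatchEncoding is now accepted directly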

colossalai/inference/pipeline/batch_infer_state.py (-120)

This file was deleted.

colossalai/inference/pipeline/engine.py (+17 -3)

@@ -1,13 +1,14 @@
 import torch
 import torch.nn as nn
+from transformers.tokenization_utils_base import BatchEncoding
 
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.schedule.generate import GenerateSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy
 
-from .kvcache_manager import MemoryManager
+from ..tensor_parallel.kvcache_manager import MemoryManager
 from .microbatch_manager import MicroBatchManager
 
 
@@ -38,7 +39,7 @@ class PPInferEngine:
 
         colossalai.launch_from_torch(config={})
 
-        model = LlamaForCausalLM.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
+        model = LlamaForCausalLM.from_pretrained("your_path_to_model")
         tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
         # assume the model is infered with 2 pipeline stages
         inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
@@ -103,7 +104,20 @@ def __init__(
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
 
     def inference(self, input_list):
-        out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
+        """
+        Args:
+            input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
+
+        Returns:
+            out (list): a list of output data, each element is a list of token.
+            timestamp (float): the time cost of the inference, only return when verbose is `True`.
+        """
+        assert isinstance(
+            input_list, (BatchEncoding, dict)
+        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+        if isinstance(input_list, BatchEncoding):
+            input_list = input_list.data
+        out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
         if self.verbose:
             return out, timestamp
         else:
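
The new `inference` body validates and unwraps its input before handing a single-element iterator to `GenerateSchedule.generate_step`. A standalone sketch of that normalization step; the helper name `_normalize_input` is illustrative and not part of the commit:

    from transformers.tokenization_utils_base import BatchEncoding

    def _normalize_input(input_list):
        # Mirrors the checks added to PPInferEngine.inference(): only BatchEncoding or dict is accepted.
        assert isinstance(
            input_list, (BatchEncoding, dict)
        ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
        # A BatchEncoding is unwrapped to its underlying dict of tensors.
        if isinstance(input_list, BatchEncoding):
            input_list = input_list.data
        # generate_step consumes an iterator of micro-batches, so the single batch is wrapped in a list.
        return iter([input_list])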

colossalai/inference/pipeline/kvcache_manager.py (-104)

This file was deleted.

colossalai/inference/pipeline/microbatch_manager.py (+2 -2)

@@ -3,8 +3,8 @@
 
 import torch
 
-from .batch_infer_state import BatchInferState
-from .kvcache_manager import MemoryManager
+from ..tensor_parallel.batch_infer_state import BatchInferState
+from ..tensor_parallel.kvcache_manager import MemoryManager
 
 __all__ = "MicroBatchManager"
 
colossalai/inference/pipeline/modeling/llama.py (+11 -3)

@@ -12,7 +12,7 @@
 from transformers.utils import logging
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import llama_context_attn_fwd, rotary_embedding_fwd, token_attention_fwd
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
 from ._utils import copy_kv_to_mem_cache
@@ -31,6 +31,14 @@
 )
 HAS_VLLM_KERNERL = False
 
+try:
+    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False
+
 
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -363,8 +371,8 @@ def llama_flash_attn_kvcache_forward(
 
     cos, sin = infer_state.position_cos, infer_state.position_sin
 
-    rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-    rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+    llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+    llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
 
     query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
     key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
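
The rotary embedding kernel now comes from lightllm, imported optionally at module load and aliased to `llama_rotary_embedding_fwd`, with `HAS_LIGHTLLM_KERNEL` recording whether the import succeeded. A hedged sketch of how a call site could guard on that flag; the guard itself is an assumption, since the committed forward pass calls the kernel unconditionally:

    def apply_rotary_embedding(query_states, key_states, cos, sin, num_heads, head_dim):
        # Illustrative guard only: raise early if the optional lightllm kernel is missing.
        if not HAS_LIGHTLLM_KERNEL:
            raise RuntimeError(
                "lightllm is required for rotary embedding here; "
                "install it from https://github.com/ModelTC/lightllm"
            )
        # In-place rotary embedding on flattened (tokens, heads, head_dim) views, as in the diff.
        llama_rotary_embedding_fwd(query_states.view(-1, num_heads, head_dim), cos, sin)
        llama_rotary_embedding_fwd(key_states.view(-1, num_heads, head_dim), cos, sin)
        return query_states, key_states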

colossalai/inference/pipeline/utils.py (-35)

This file was deleted.
