 import torch
 import torch.nn as nn
+from transformers.tokenization_utils_base import BatchEncoding

 from colossalai.cluster import ProcessGroupMesh
 from colossalai.pipeline.schedule.generate import GenerateSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer.policies.base_policy import Policy

-from .kvcache_manager import MemoryManager
+from ..tensor_parallel.kvcache_manager import MemoryManager
 from .microbatch_manager import MicroBatchManager

@@ -38,7 +39,7 @@ class PPInferEngine:

     colossalai.launch_from_torch(config={})

-    model = LlamaForCausalLM.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
+    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
     tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
     # assume the model is inferred with 2 pipeline stages
     inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8)
@@ -103,7 +104,20 @@ def __init__(
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

     def inference(self, input_list):
-        out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
+        """
+        Args:
+            input_list (BatchEncoding or dict): the tokenized input batch to run inference on.
+
+        Returns:
+            out (list): a list of output data, each element is a list of generated tokens.
+            timestamp (float): the time cost of the inference, only returned when verbose is `True`.
+        """
+        assert isinstance(
+            input_list, (BatchEncoding, dict)
+        ), f"Only accept BatchEncoding or dict as input, but got {input_list.__class__.__name__}."
+        if isinstance(input_list, BatchEncoding):
+            input_list = input_list.data
+        out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
         if self.verbose:
             return out, timestamp
         else:
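
The behavioral change in the last hunk is the input contract of `inference()`: it now expects a single tokenized batch (a `BatchEncoding` or a plain `dict` of tensors) and wraps it in a one-element list before handing it to `GenerateSchedule.generate_step`. Below is a minimal usage sketch, assuming the engine, model, and tokenizer are set up as in the docstring example above; the prompt text is a placeholder.

# Sketch only: assumes `inferengine` and `tokenizer` were created as in the
# docstring example above; the prompt is a placeholder.
prompt = "Hello, my name is"
inputs = tokenizer(prompt, return_tensors="pt")   # a BatchEncoding of input tensors

outputs = inferengine.inference(inputs)           # verbose=False: returns the generated token lists
# With verbose=True the call also returns the time cost:
# outputs, elapsed = inferengine.inference(inputs)

# A plain dict of tensors works as well, since BatchEncoding inputs are unwrapped to .data:
outputs = inferengine.inference(dict(inputs))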