Skip to content

Commit

Permalink
[Feature] The first PR to Add TP inference engine, kv-cache manager a…
Browse files Browse the repository at this point in the history
…nd related kernels for our inference system (#4577)

* [infer] Infer/llama demo (#4503)

* add

* add infer example

* finish

* finish

* stash

* fix

* [Kernels]  add inference token attention kernel (#4505)

* add token forward

* fix tests

* fix comments

* add try import triton

* add adapted license

* add tests check

* [Kernels] add necessary kernels (llama & bloom) for attention forward and kv-cache manager  (#4485)

* added _vllm_rms_norm

* change place

* added tests

* added tests

* modify

* adding kernels

* added tests:

* adding kernels

* modify

* added

* updating kernels

* adding tests

* added tests

* kernel change

* submit

* modify

* added

* edit comments

* change name

* change commnets and fix import

* add

* added

* combine codes (#4509)

* [feature] add KV cache manager for llama & bloom inference (#4495)

* add kv cache memory manager

* add stateinfo during inference

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* file dir change

* [Bug FIx] import llama context ops fix (#4524)

* added _vllm_rms_norm

* change place

* added tests

* added tests

* modify

* adding kernels

* added tests:

* adding kernels

* modify

* added

* updating kernels

* adding tests

* added tests

* kernel change

* submit

* modify

* added

* edit comments

* change name

* change commnets and fix import

* add

* added

* fix

* add ops into init.py

* add

* [Infer] Add TPInferEngine and fix file path (#4532)

* add engine for TP inference

* move file path

* update path

* fix TPInferEngine

* remove unused file

* add engine test demo

* revise TPInferEngine

* fix TPInferEngine, add test

* fix

* Add Inference test for llama (#4508)

* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama  policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix:rm some unused codes

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [infer] Add Bloom inference policy and replaced methods (#4512)

* add bloom inference methods and policy

* enable pass BatchInferState from model forward

* revise bloom infer layers/policies

* add engine for inference (draft)

* add test for bloom infer

* fix bloom infer policy and flow

* revise bloom test

* fix bloom file path

* remove unused codes

* fix bloom modeling

* fix dir typo

* fix trivial

* fix policy

* clean pr

* trivial fix

* Revert "[infer] Add Bloom inference policy and replaced methods (#4512)" (#4552)

This reverts commit 17cfa57.

* [Doc] Add colossal inference doc (#4549)

* create readme

* add readme.md

* fix typos

* [infer] Add Bloom inference policy and replaced methods (#4553)

* add bloom inference methods and policy

* enable pass BatchInferState from model forward

* revise bloom infer layers/policies

* add engine for inference (draft)

* add test for bloom infer

* fix bloom infer policy and flow

* revise bloom test

* fix bloom file path

* remove unused codes

* fix bloom modeling

* fix dir typo

* fix trivial

* fix policy

* clean pr

* trivial fix

* trivial

* Fix Bugs In Llama Model Forward (#4550)

* add kv cache memory manager

* add stateinfo during inference

* add

* add infer example

* finish

* finish

* format

* format

* rename file

* add kv cache test

* revise on BatchInferState

* add inference test for llama

* fix conflict

* feature: add some new features for llama engine

* adapt colossalai triton interface

* Change the parent class of llama  policy

* add nvtx

* move llama inference code to tensor_parallel

* fix __init__.py

* rm tensor_parallel

* fix: fix bugs in auto_policy.py

* fix:rm some unused codes

* mv colossalai/tpinference to colossalai/inference/tensor_parallel

* change __init__.py

* save change

* fix engine

* Bug fix: Fix hang

* remove llama_infer_engine.py

* bug fix: fix bugs about infer_state.is_context_stage

* remove pollcies

* fix: delete unused code

* fix: delete unused code

* remove unused coda

* fix conflict

---------

Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [doc] add colossal inference fig (#4554)

* create readme

* add readme.md

* fix typos

* upload fig

* [NFC] fix docstring for colossal inference (#4555)

Fix docstring and comments in kv cache manager and bloom modeling

* fix docstring in llama modeling (#4557)

* [Infer] check import vllm (#4559)

* change import vllm

* import apply_rotary_pos_emb

* change import location

* [DOC] add installation req (#4561)

* add installation req

* fix

* slight change

* remove empty

* [Feature] rms-norm transfer into inference llama.py  (#4563)

* add installation req

* fix

* slight change

* remove empty

* add rmsnorm polciy

* add

* clean codes

* [infer] Fix tp inference engine (#4564)

* fix engine prepare data

* add engine test

* use bloom for testing

* revise on test

* revise on test

* reset shardformer llama (#4569)

* [infer] Fix engine - tensors on different devices (#4570)


* fix diff device in engine

* [codefactor] Feature/colossal inference (#4579)

* code factors

* remove

* change coding (#4581)

* [doc] complete README of colossal inference (#4585)

* complete fig

* Update README.md

* [doc]update readme (#4586)

* update readme

* Update README.md

* bug fix: fix bus in llama and bloom (#4588)

* [BUG FIX]Fix test engine in CI and non-vllm kernels llama forward  (#4592)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* [Kernel]Rmsnorm fix (#4598)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* add triton rmsnorm

* delete vllm kernel flag

* [Bug Fix]Fix bugs in llama (#4601)

* fix tests

* clean

* clean

* fix bugs

* add

* fix llama non-vllm kernels bug

* modify

* clean codes

* bug fix: remove rotary_positions_ids

---------

Co-authored-by: cuiqing.li <lixx3527@gmail.com>

* [kernel] Add triton layer norm & replace norm for bloom (#4609)

* add layernorm for inference

* add test for layernorm kernel

* add bloom layernorm replacement policy

* trivial: path

* [Infer] Bug fix rotary embedding in llama (#4608)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* [bench] Add bloom inference benchmark (#4621)

* add bloom benchmark

* readme - update benchmark res

* trivial - uncomment for testing (#4622)

* [Infer] add check triton and cuda version for tests (#4627)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* add check triton and cuda

* Update sharder.py (#4629)

* [Inference] Hot fix some bugs and typos (#4632)

* fix

* fix test

* fix conflicts

* [typo]Comments fix (#4633)

* fallback

* fix commnets

* bug fix: fix some bugs in test_llama and test_bloom (#4635)

* [Infer] delete benchmark in tests and fix bug for llama and bloom (#4636)

* fix rotary embedding

* delete print

* fix init seq len bug

* rename pytest

* add benchmark for llama

* refactor codes

* delete useless code

* add check triton and cuda

* delete benchmark and fix infer bugs

* delete benchmark for tests

* delete useless code

* delete bechmark function in utils

* [Fix] Revise TPInferEngine, inference tests and benchmarks (#4642)

* [Fix] revise TPInferEngine methods and inference tests

* fix llama/bloom infer benchmarks

* fix infer tests

* trivial fix: benchmakrs

* trivial

* trivial: rm print

* modify utils filename for infer ops test (#4657)

* [Infer] Fix TPInferEngine init & inference tests, benchmarks (#4670)

* fix engine funcs

* TPInferEngine: receive shard config in init

* benchmarks: revise TPInferEngine init

* benchmarks: remove pytest decorator

* trivial fix

* use small model for tests

* [NFC] use args for infer benchmarks (#4674)

* revise infer default (#4683)

* [Fix] optimize/shard model in TPInferEngine init (#4684)

* remove using orig model in engine

* revise inference tests

* trivial: rename

---------

Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: yuanheng-zhao <jonathan.zhaoyh@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
  • Loading branch information
7 people authored Sep 11, 2023
1 parent eedaa3e commit bce0f16
Show file tree
Hide file tree
Showing 49 changed files with 3,980 additions and 137 deletions.
32 changes: 32 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,35 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR VLLM TEAM ----------------

from VLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/vllm-project/vllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

---------------- LICENSE FOR LIGHTLLM TEAM ----------------

from LIGHTLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/ModelTC/lightllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
117 changes: 117 additions & 0 deletions colossalai/inference/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# 🚀 Colossal-Inference

## Table of contents

## Introduction

`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users.

## Design

Colossal Inference is composed of two main components:

1. High performance kernels and ops: which are inspired from existing libraries and modified correspondingly.
2. Efficient memory management mechanism:which includes the key-value cache manager, allowing for zero memory waste during inference.
1. `cache manager`: serves as a memory manager to help manage the key-value cache, it integrates functions such as memory allocation, indexing and release.
2. `batch_infer_info`: holds all essential elements of a batch inference, which is updated every batch.
3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods.
1. `engine.TPInferEngine`: it is a high level interface that integrates with shardformer, especially for multi-card (tensor parallel) inference:
2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference. (in this case : llama)
3. `policies.llama.LlamaModelInferPolicy` : contains the policies for `llama` models, which is used to call `shardformer` and segmentate the model forward in tensor parallelism way.

## Pipeline of inference:

In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes.

![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Colossal-inference.png)

## Roadmap of our implementation

- [x] Design cache manager and batch infer state
- [x] Design TpInference engine to integrates with `Shardformer`
- [x] Register corresponding high-performance `kernel` and `ops`
- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
- [x] policy
- [x] context forward
- [x] token forward
- [ ] Replace the kernels with `faster-transformer` in token-forward stage
- [ ] Support all models
- [x] Llama
- [x] Bloom
- [ ] Chatglm2
- [ ] Benchmarking for all models

## Get started

### Installation

```bash
pip install -e .
```

### Requirements

dependencies

```bash
pytorch= 1.13.1 (gpu)
cuda>= 11.6
transformers= 4.30.2
triton==2.0.0.dev20221202
# for install vllm, please use this branch to install https://github.com/tiandiao123/vllm/tree/setup_branch
vllm
# for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention
```

### Docker

You can use docker run to use docker container to set-up environment

```
# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
```

### Dive into fast-inference!

example files are in

```bash
cd colossalai.examples
python xx
```

## Performance

### environment:

We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `colossal-inference` and original `hugging-face torch fp16`.

For various models, experiments were conducted using multiple batch sizes under the consistent model configuration of `7 billion(7b)` parameters, `1024` input length, and 128 output length. The obtained results are as follows (due to time constraints, the evaluation has currently been performed solely on the `A100` single GPU performance; multi-GPU performance will be addressed in the future):

### Single GPU Performance:

Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to furthur optimize the performance of LLM models. Please stay tuned.

#### Llama

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
| colossal-inference | 326.4 | 582.72 | 816.64 |

![llama](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-llama7b.png)

### Bloom

| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
| colossal-inference | 323.28 | 538.52 | 611.64 |

![bloom](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/Infer-bloom7b.png)

The results of more models are coming soon!
Empty file.
4 changes: 4 additions & 0 deletions colossalai/inference/tensor_parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager

__all__ = ['MemoryManager', 'TPInferEngine']
55 changes: 55 additions & 0 deletions colossalai/inference/tensor_parallel/batch_infer_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Any

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
r"""
Information to be passed and used for a batch of inputs during
a single model forward
"""
batch_size: int
max_len_in_batch: int

cache_manager: MemoryManager = None

block_loc: torch.Tensor = None
start_loc: torch.Tensor = None
seq_len: torch.Tensor = None
past_key_values_len: int = None

is_context_stage: bool = False
context_mem_index: torch.Tensor = None
decode_is_contiguous: bool = None
decode_mem_start: int = None
decode_mem_end: int = None
decode_mem_index: torch.Tensor = None
decode_layer_id: int = None

device: torch.device = torch.device('cuda')

@property
def total_token_num(self):
# return self.batch_size * self.max_len_in_batch
assert self.seq_len is not None and self.seq_len.size(0) > 0
return int(torch.sum(self.seq_len))

def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager

@staticmethod
def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
alloc_mem_index: torch.Tensor):
""" in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0
seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy):
b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index +
cur_seq_len]
start_index += cur_seq_len
return
Loading

0 comments on commit bce0f16

Please sign in to comment.