revise inference tests
yuanheng-zhao committed Sep 11, 2023
1 parent: 45e5d91 · commit: bf9d026
Showing 3 changed files with 20 additions and 66 deletions.
2 changes: 0 additions & 2 deletions tests/test_infer/test_bloom_infer.py
@@ -2,9 +2,7 @@

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoTokenizer, BloomForCausalLM

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
82 changes: 20 additions & 62 deletions tests/test_infer/test_infer_engine.py
@@ -11,7 +11,7 @@
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
@@ -22,31 +22,25 @@
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_prepare_data():
# dummy module used for testing
class DummyModule(nn.Module):

def __init__(self, config):
super(DummyModule, self).__init__()
self.config = config

def forward(self, x):
return x

# dummy config used for testing
class DummyModelConfig:

def __init__(self):
self.hidden_size = 4096
self.num_attention_heads = 32
self.num_hidden_layers = 8
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

dummy_config = DummyModelConfig()
model = DummyModule(dummy_config)
shard_config = ShardConfig(enable_tensor_parallelism=False)
# 1. check TPInferEngine init and model optimization
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

# 2. check data preparation
input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
[80540, 15473, 3331, 11970], [80540, 15473]]
batch_size = len(input_ids_list)
@@ -74,48 +74,14 @@ def __init__(self):
# assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
# assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_generate():
# 3. check optimized model generate
input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))

model_config = LlamaConfig()
model = LlamaForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=False)

# init TPInferEngine
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

# original model generate
generate_kwargs = dict(do_sample=False)
infer_engine.generate(input_ids, **generate_kwargs)

torch.cuda.empty_cache()


@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig()
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

torch.cuda.empty_cache()


def check_engine(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -126,11 +126,9 @@ def check_engine(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine_tp():
def test_engine():
spawn(check_engine, TP_SIZE)


if __name__ == '__main__':
test_prepare_data()
test_generate()
test_engine_tp()
test_engine()
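
For quick reference, the consolidated test in test_infer_engine.py now covers engine construction, the sanity asserts, and the distributed launch in a single run / check_engine / test_engine path. The sketch below restates that flow as a standalone script; it is a minimal, assumption-laden example (TP_SIZE = 1 on a single GPU, illustrative max batch/length constants) rather than the test itself, and it relies only on APIs visible in this diff: TPInferEngine, ShardConfig(enable_tensor_parallelism=..., inference_only=True), colossalai.launch, and colossalai.testing.spawn.

# Minimal sketch of the revised test flow; assumptions: 1 GPU, TP_SIZE = 1,
# and illustrative max batch/length constants (the real test uses TP_SIZE = 2).
import torch
from transformers import BloomConfig, BloomForCausalLM

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import spawn

TP_SIZE = 1
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8


def run_engine():
    # Tiny Bloom model in fp16 on the current CUDA device, mirroring the revised `run`.
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128,
                               intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config).half().to(torch.cuda.current_device())

    # Tensor parallelism only when TP_SIZE > 1; inference_only matches the diff.
    shard_config = ShardConfig(enable_tensor_parallelism=TP_SIZE > 1, inference_only=True)
    engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    # Same sanity checks as the test: the KV-cache manager exists and the
    # attention heads are split evenly across tensor-parallel ranks.
    assert engine.cache_manager is not None
    assert engine.head_num == model_config.num_attention_heads // TP_SIZE


def check_engine(rank, world_size, port):
    # Same launch pattern as the test's check_engine.
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    run_engine()


if __name__ == '__main__':
    spawn(check_engine, TP_SIZE)

With TP_SIZE = 1 this collapses to a single spawned process; setting TP_SIZE = 2 reproduces the distributed case the test actually exercises.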
2 changes: 0 additions & 2 deletions tests/test_infer/test_llama_infer.py
@@ -3,9 +3,7 @@

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
