From b34e44eea3ecdf1374d3666ddb19c933f18c99aa Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Tue, 5 Sep 2023 11:38:50 +0800
Subject: [PATCH] [kernel] Add triton layer norm & replace norm for bloom
 (#4609)

* add layernorm for inference

* add test for layernorm kernel

* add bloom layernorm replacement policy

* trivial: path
---
 .../tensor_parallel/policies/bloom.py         | 62 +++++++++-----
 colossalai/kernel/triton/__init__.py          |  5 +-
 colossalai/kernel/triton/fused_layernorm.py   | 83 +++++++++++++++++++
 tests/test_infer/test_bloom_infer.py          | 10 ++-
 tests/test_infer_ops/triton/test_layernorm.py | 64 ++++++++++++++
 5 files changed, 198 insertions(+), 26 deletions(-)
 create mode 100644 colossalai/kernel/triton/fused_layernorm.py
 create mode 100644 tests/test_infer_ops/triton/test_layernorm.py

diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py
index d9dc2982d040..63791fe27284 100644
--- a/colossalai/inference/tensor_parallel/policies/bloom.py
+++ b/colossalai/inference/tensor_parallel/policies/bloom.py
@@ -1,7 +1,30 @@
+from functools import partial
+
+import torch
+from torch.nn import LayerNorm
+
 from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
 
 from ..modeling.bloom import BloomInferenceForwards
 
+try:
+    from colossalai.kernel.triton.fused_layernorm import layer_norm
+    HAS_TRITON_NORM = True
+except:
+    print("you should install triton from https://github.com/openai/triton")
+    HAS_TRITON_NORM = False
+
+
+def get_triton_layernorm_forward():
+    if HAS_TRITON_NORM:
+
+        def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
+            return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
+
+        return _triton_layernorm_forward
+    else:
+        return None
+
 
 class BloomModelInferPolicy(BloomForCausalLMPolicy):
 
@@ -14,31 +37,30 @@ def module_policy(self):
         # NOTE set inference mode to shard config
         self.shard_config._infer()
 
-        if self.shard_config.enable_tensor_parallelism:
+        method_replacement = {
+            'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
+            'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
+        }
+        self.append_or_create_method_replacement(description=method_replacement,
+                                                 policy=policy,
+                                                 target_key=BloomForCausalLM)
 
-            method_replacement = {
-                'forward':
-                    BloomInferenceForwards.bloom_for_causal_lm_forward,
-                'prepare_inputs_for_generation':
-                    BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
-            }
-            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=BloomForCausalLM)
+        method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
 
-            method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
-            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=BloomModel)
+        method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
 
-            method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
-            self.append_or_create_method_replacement(description=method_replacement,
-                                                     policy=policy,
-                                                     target_key=BloomBlock)
+        method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
+        self.append_or_create_method_replacement(description=method_replacement,
+                                                 policy=policy,
+                                                 target_key=BloomAttention)
 
-            method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
+        if HAS_TRITON_NORM:
+            infer_method = get_triton_layernorm_forward()
+            method_replacement = {'forward': partial(infer_method)}
             self.append_or_create_method_replacement(description=method_replacement,
                                                      policy=policy,
-                                                     target_key=BloomAttention)
+                                                     target_key=LayerNorm)
 
         return policy
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 75bd4ed80a72..eb0335c01ce2 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -1,4 +1,5 @@
-from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
-from .softmax import softmax
+from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
+from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .softmax import softmax
diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py
new file mode 100644
index 000000000000..99800acfbb92
--- /dev/null
+++ b/colossalai/kernel/triton/fused_layernorm.py
@@ -0,0 +1,83 @@
+import torch
+
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+    # CREDITS: These functions are adapted from the Triton tutorial
+    # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+
+    @triton.jit
+    def _layer_norm_fwd_fused(
+        X,    # pointer to the input
+        Y,    # pointer to the output
+        W,    # pointer to the weights
+        B,    # pointer to the biases
+        stride,    # how much to increase the pointer when moving by 1 row
+        N,    # number of columns in X
+        eps,    # epsilon to avoid division by zero
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        Y += row * stride
+        X += row * stride
+        # Compute mean
+        mean = 0
+        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+            _mean += a
+        mean = tl.sum(_mean, axis=0) / N
+        # Compute variance
+        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+            x = tl.where(cols < N, x - mean, 0.)
+            _var += x * x
+        var = tl.sum(_var, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        # Normalize and apply linear transformation
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            w = tl.load(W + cols, mask=mask)
+            b = tl.load(B + cols, mask=mask)
+            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+            x_hat = (x - mean) * rstd
+            y = x_hat * w + b
+            # Write output
+            tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+    @torch.no_grad()
+    def layer_norm(x, weight, bias, eps):
+        # allocate output
+        y = torch.empty_like(x)
+        # reshape input data into 2D tensor
+        x_arg = x.reshape(-1, x.shape[-1])
+        M, N = x_arg.shape
+        # Less than 64KB per feature: enqueue fused kernel
+        MAX_FUSED_SIZE = 65536 // x.element_size()
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+        if N > BLOCK_SIZE:
+            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        # heuristics for number of warps
+        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+        # enqueue kernel
+        _layer_norm_fwd_fused[(M,)](x_arg,
+                                    y,
+                                    weight,
+                                    bias,
+                                    x_arg.stride(0),
+                                    N,
+                                    eps,
+                                    BLOCK_SIZE=BLOCK_SIZE,
+                                    num_warps=num_warps)
+        return y
diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index 754c158e6279..dad3f9cb295f 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -1,8 +1,9 @@
 import os
+
 import pytest
-from packaging import version
 import torch
 import torch.distributed as dist
+from packaging import version
 from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
 
 import colossalai
@@ -20,10 +21,10 @@
 
 
 def run():
-    model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
+    model_path = "/data3/models/bloom-7b1"
     if os.path.isdir(model_path) is False:
-        return
-
+        return
+
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
 
@@ -54,6 +55,7 @@ def check_engine(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run()
 
+
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
diff --git a/tests/test_infer_ops/triton/test_layernorm.py b/tests/test_infer_ops/triton/test_layernorm.py
new file mode 100644
index 000000000000..15d0fe74c1ed
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_layernorm.py
@@ -0,0 +1,64 @@
+import pytest
+import torch
+from packaging import version
+
+from colossalai.kernel.triton import layer_norm
+from colossalai.testing.utils import parameterize
+from tests.test_infer_ops.triton.utils import benchmark
+
+try:
+    import triton
+    import triton.language as tl
+
+    from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
+@parameterize('M', [2, 4, 8, 16])
+@parameterize('N', [64, 128])
+def test_layer_norm(M, N):
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
+    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+
+    y_triton = layer_norm(x, weight, bias, eps)
+    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
+
+    assert y_triton.shape == y_torch.shape
+    assert y_triton.dtype == y_torch.dtype
+    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
+    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
+
+
+@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
+                    reason="triton requires cuda version to be higher than 11.4")
+@parameterize('M', [4])
+@parameterize('N', [128])
+def test_benchmark(M, N):
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
+    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+
+    latency_1 = benchmark(layer_norm, x, weight, bias, eps)
+    latency_2 = benchmark(torch.nn.functional.layer_norm, x, w_shape, weight, bias, eps)
+    print("the triton op latency is {} ms".format(str(latency_1)))
+    print("the torch op latency is {} ms".format(str(latency_2)))
+
+
+if __name__ == "__main__":
+    test_layer_norm()
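
Usage sketch (not part of the patch above): the patch exports layer_norm from
colossalai.kernel.triton, and the call pattern below mirrors the one used in
tests/test_infer_ops/triton/test_layernorm.py. It assumes Triton is installed, a CUDA
device is available, and fp16 inputs as in the tests; treat it as a minimal
illustration under those assumptions rather than documented API.

    import torch

    from colossalai.kernel.triton import layer_norm

    M, N = 4, 128    # number of rows to normalize, hidden size
    eps = 1e-5
    x = torch.randn(M, N, dtype=torch.float16, device='cuda')
    weight = torch.rand(N, dtype=torch.float16, device='cuda')
    bias = torch.rand(N, dtype=torch.float16, device='cuda')

    # Fused Triton kernel: one program per row, fp32 accumulation, fp16 output.
    y_triton = layer_norm(x, weight, bias, eps)

    # PyTorch reference for a quick sanity check, with the same tolerance as the test.
    y_torch = torch.nn.functional.layer_norm(x, (N,), weight, bias, eps).to(torch.float16)
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)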