[kernel] Add triton layer norm & replace norm for bloom #4609

Merged
62 changes: 42 additions & 20 deletions colossalai/inference/tensor_parallel/policies/bloom.py
@@ -1,7 +1,30 @@
from functools import partial

import torch
from torch.nn import LayerNorm

from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

from ..modeling.bloom import BloomInferenceForwards

try:
    from colossalai.kernel.triton.fused_layernorm import layer_norm
    HAS_TRITON_NORM = True
except ImportError:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_NORM = False


def get_triton_layernorm_forward():
    if HAS_TRITON_NORM:

        def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
            return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)

        return _triton_layernorm_forward
    else:
        return None

class BloomModelInferPolicy(BloomForCausalLMPolicy):

@@ -14,31 +37,30 @@ def module_policy(self):
        # NOTE set inference mode to shard config
        self.shard_config._infer()

        method_replacement = {
            'forward': BloomInferenceForwards.bloom_for_causal_lm_forward,
            'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
        }
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=policy,
                                                 target_key=BloomForCausalLM)

        method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)

        method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)

        method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
        self.append_or_create_method_replacement(description=method_replacement,
                                                 policy=policy,
                                                 target_key=BloomAttention)

        if HAS_TRITON_NORM:
            infer_method = get_triton_layernorm_forward()
            method_replacement = {'forward': partial(infer_method)}
            self.append_or_create_method_replacement(description=method_replacement,
                                                     policy=policy,
                                                     target_key=LayerNorm)

        return policy
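
For readers less familiar with the shardformer policy mechanism, the sketch below shows roughly what the new LayerNorm method replacement amounts to at runtime. It is illustrative only, not code from this PR: the 1024 hidden size, the test tensors, and the manual __get__ binding are assumptions for demonstration.

# Illustrative sketch only (not part of the diff): binding the Triton forward onto a
# single LayerNorm by hand, roughly what the policy does for every LayerNorm in the
# sharded Bloom model once HAS_TRITON_NORM is True.
import torch
from torch.nn import LayerNorm

from colossalai.inference.tensor_parallel.policies.bloom import get_triton_layernorm_forward

norm = LayerNorm(1024).cuda().half()                          # hypothetical hidden size
triton_forward = get_triton_layernorm_forward()
if triton_forward is not None:                                # None when triton is missing
    norm.forward = triton_forward.__get__(norm, LayerNorm)    # bind as an instance method

hidden_states = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
out = norm(hidden_states)    # dispatches to the fused Triton kernel when available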
5 changes: 3 additions & 2 deletions colossalai/kernel/triton/__init__.py
@@ -1,4 +1,5 @@
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .rms_norm import rmsnorm_forward
from .softmax import softmax
83 changes: 83 additions & 0 deletions colossalai/kernel/triton/fused_layernorm.py
@@ -0,0 +1,83 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # CREDITS: These functions are adapted from the Triton tutorial
    # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html

    @triton.jit
    def _layer_norm_fwd_fused(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        B,  # pointer to the biases
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        Y += row * stride
        X += row * stride
        # Compute mean
        mean = 0
        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            _mean += a
        mean = tl.sum(_mean, axis=0) / N
        # Compute variance
        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            x = tl.where(cols < N, x - mean, 0.)
            _var += x * x
        var = tl.sum(_var, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        # Normalize and apply linear transformation
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask)
            b = tl.load(B + cols, mask=mask)
            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
            x_hat = (x - mean) * rstd
            y = x_hat * w + b
            # Write output
            tl.store(Y + cols, y.to(tl.float16), mask=mask)

    @torch.no_grad()
    def layer_norm(x, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        _layer_norm_fwd_fused[(M,)](x_arg,
                                    y,
                                    weight,
                                    bias,
                                    x_arg.stride(0),
                                    N,
                                    eps,
                                    BLOCK_SIZE=BLOCK_SIZE,
                                    num_warps=num_warps)
        return y
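
As a side note on the launch heuristics in layer_norm, the arithmetic below works through one concrete configuration. The hidden size of 4096 and the fp16 dtype are assumptions for illustration, not values taken from this PR.

# Illustrative arithmetic only: how BLOCK_SIZE and num_warps come out for a
# hypothetical fp16 input with hidden size N = 4096.
element_size = 2                               # bytes per fp16 element
MAX_FUSED_SIZE = 65536 // element_size         # 32768 elements fit in the 64KB budget
N = 4096
next_pow2 = 1 << (N - 1).bit_length()          # 4096, same as triton.next_power_of_2(N)
BLOCK_SIZE = min(MAX_FUSED_SIZE, next_pow2)    # 4096
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)  # 8
# The kernel is launched with grid=(M,), i.e. one Triton program per row of the
# flattened (M, N) input; each program loops over its row in BLOCK_SIZE chunks.
print(BLOCK_SIZE, num_warps)                   # 4096 8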
10 changes: 6 additions & 4 deletions tests/test_infer/test_bloom_infer.py
@@ -1,8 +1,9 @@
import os

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM

import colossalai
@@ -20,10 +21,10 @@


def run():
    model_path = "/data3/models/bloom-7b1"
    if os.path.isdir(model_path) is False:
        return

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

@@ -54,6 +55,7 @@ def check_engine(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
64 changes: 64 additions & 0 deletions tests/test_infer_ops/triton/test_layernorm.py
@@ -0,0 +1,64 @@
import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
from tests.test_infer_ops.triton.utils import benchmark

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [2, 4, 8, 16])
@parameterize('N', [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [4])
@parameterize('N', [128])
def test_benchmark(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    latency_1 = benchmark(layer_norm, x, weight, bias, eps)
    latency_2 = benchmark(torch.nn.functional.layer_norm, x, w_shape, weight, bias, eps)
    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_layer_norm()
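
For local experimentation outside the distributed CI launcher, one possible way to run just this file is sketched below; the invocation is hypothetical, not part of the PR, and assumes a CUDA build of PyTorch with triton installed.

# Hypothetical local invocation: runs only the Triton layer-norm tests.
import pytest

pytest.main(["-s", "tests/test_infer_ops/triton/test_layernorm.py"])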