[kernel] Add triton layer norm & replace norm for bloom (#4609)
* add layernorm for inference
* add test for layernorm kernel
* add bloom layernorm replacement policy
* trivial: path
1 parent da77c97 · commit b34e44e
Showing 5 changed files with 198 additions and 26 deletions.
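For context, the fused kernel added in this commit computes standard layer normalization: each row of the input is normalized by its mean and variance, then scaled and shifted by the learned weight and bias. A minimal PyTorch sketch of that per-row math (the function name here is illustrative, not part of the commit):

import torch

def layer_norm_reference(x, weight, bias, eps):
    # Normalize each row over the last dimension, then apply the
    # elementwise affine transform; accumulate in float32 as the kernel does.
    x_f = x.float()
    mean = x_f.mean(dim=-1, keepdim=True)
    var = x_f.var(dim=-1, unbiased=False, keepdim=True)
    x_hat = (x_f - mean) / torch.sqrt(var + eps)
    return (x_hat * weight + bias).to(x.dtype)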
colossalai/kernel/triton/__init__.py:
@@ -1,4 +1,5 @@
-from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
-from .softmax import softmax
+from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
+from .fused_layernorm import layer_norm
 from .rms_norm import rmsnorm_forward
+from .softmax import softmax
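With this change, the fused kernel is re-exported at package level. A minimal usage sketch (assuming a CUDA device with Triton installed; the tensor shapes below are illustrative):

import torch
from colossalai.kernel.triton import layer_norm

hidden = 1024
x = torch.randn(16, hidden, dtype=torch.float16, device='cuda')
weight = torch.ones(hidden, dtype=torch.float16, device='cuda')
bias = torch.zeros(hidden, dtype=torch.float16, device='cuda')
y = layer_norm(x, weight, bias, 1e-5)  # returns a tensor of the same shape as x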
colossalai/kernel/triton/fused_layernorm.py (new file):
@@ -0,0 +1,83 @@
import torch

try:
    import triton
    import triton.language as tl
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
    # CREDITS: These functions are adapted from the Triton tutorial
    # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html

    @triton.jit
    def _layer_norm_fwd_fused(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        B,  # pointer to the biases
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
    ):
        # Map the program id to the row of X and Y it should compute.
        row = tl.program_id(0)
        Y += row * stride
        X += row * stride
        # Compute mean
        mean = 0
        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            _mean += a
        mean = tl.sum(_mean, axis=0) / N
        # Compute variance
        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
            x = tl.where(cols < N, x - mean, 0.)
            _var += x * x
        var = tl.sum(_var, axis=0) / N
        rstd = 1 / tl.sqrt(var + eps)
        # Normalize and apply linear transformation
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            w = tl.load(W + cols, mask=mask)
            b = tl.load(B + cols, mask=mask)
            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
            x_hat = (x - mean) * rstd
            y = x_hat * w + b
            # Write output
            tl.store(Y + cols, y.to(tl.float16), mask=mask)

    @torch.no_grad()
    def layer_norm(x, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        _layer_norm_fwd_fused[(M,)](x_arg,
                                    y,
                                    weight,
                                    bias,
                                    x_arg.stride(0),
                                    N,
                                    eps,
                                    BLOCK_SIZE=BLOCK_SIZE,
                                    num_warps=num_warps)
        return y
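The commit message also mentions a Bloom layer norm replacement policy; that file is not among the diffs shown here. As a rough, hypothetical illustration of the idea only (not the actual policy code in this commit), an existing torch.nn.LayerNorm module could be redirected to the Triton kernel like this:

import torch
import torch.nn as nn

from colossalai.kernel.triton import layer_norm

def use_triton_layer_norm(module: nn.LayerNorm) -> None:
    # Hypothetical sketch: rebind the module's forward so inference calls
    # go through the fused Triton kernel instead of the default implementation.
    def triton_forward(x: torch.Tensor) -> torch.Tensor:
        return layer_norm(x, module.weight, module.bias, module.eps)

    module.forward = triton_forward

# Applied to every LayerNorm in a (hypothetical) Bloom model instance:
# for mod in model.modules():
#     if isinstance(mod, nn.LayerNorm):
#         use_triton_layer_norm(mod)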
New test file for the Triton layer norm kernel:
@@ -0,0 +1,64 @@
import pytest
import torch
from packaging import version

from colossalai.kernel.triton import layer_norm
from colossalai.testing.utils import parameterize
from tests.test_infer_ops.triton.utils import benchmark

try:
    import triton
    import triton.language as tl

    from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [2, 4, 8, 16])
@parameterize('N', [64, 128])
def test_layer_norm(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    y_triton = layer_norm(x, weight, bias, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)

    assert y_triton.shape == y_torch.shape
    assert y_triton.dtype == y_torch.dtype
    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [4])
@parameterize('N', [128])
def test_benchmark(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    latency_1 = benchmark(layer_norm, x, weight, bias, eps)
    latency_2 = benchmark(torch.nn.functional.layer_norm, x, w_shape, weight, bias, eps)
    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_layer_norm()
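The benchmark helper imported from tests.test_infer_ops.triton.utils is not part of this diff, so its implementation is not shown. A plausible sketch of such a latency helper using CUDA events (an assumption about its behavior, not the actual utility) might look like:

import torch

def benchmark(fn, *args, warmup=10, iters=100, **kwargs):
    # Hypothetical sketch: warm up, then time `iters` calls with CUDA events
    # and return the mean latency per call in milliseconds.
    for _ in range(warmup):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters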