Skip to content

Commit

Permalink
fix mx triton kernel after PyTorch triton pin change (#431)
Browse files Browse the repository at this point in the history
Summary:

The Triton pin used by PyTorch was updated recently in:
pytorch/pytorch#126098

In the new Triton version, jitted functions can only access global
variables of type `tl.constexpr`. Given the current structure of the code,
and the fact that these constants are also used by non-Triton code paths,
the simplest fix is to stop using globals in the MX Triton kernels
altogether. This PR lifts all of these constants to kernel function
arguments, which the host-side wrappers now pass in explicitly.

Test Plan:

```
pytest test/prototype/mx_formats/test_custom_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo authored Jun 24, 2024
1 parent 37c348e commit c2cf973
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.5.0.dev20240620+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
Expand Down
143 changes: 127 additions & 16 deletions torchao/prototype/mx_formats/custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,19 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
import triton.language as tl

@triton.jit
def _fp4_packed_to_bf16(x_packed):
def _fp4_packed_to_bf16(
x_packed,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
):
"""
Input: a tensor of packed fp4 values
Output: a tensor of bfloat16 values
Expand All @@ -123,7 +135,7 @@ def _fp4_packed_to_bf16(x_packed):
# output = x_unpacked.to(tl.float32)

# save the sign
sign_f4 = x & SIGN_MASK_F4
sign_f4 = x & sign_mask_f4

# set everything to positive, will add sign back at the end
x_pos = x ^ sign_f4
Expand All @@ -138,25 +150,25 @@ def _fp4_packed_to_bf16(x_packed):
denormal_mask = x_pos == 1

# calculate the new exponent and shift it to bits 2:9 of the result
exp_biased_f4 = x_pos >> MBITS_F4_E2M1
exp_biased_f32 = exp_biased_f4 - F4_E2M1_EXP_BIAS + F32_EXP_BIAS
exp_biased_f32 = exp_biased_f32.to(tl.int32) << MBITS_F32
exp_biased_f4 = x_pos >> mbits_f4_e2m1
exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias
exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32

# shift the mantissa to bits 10:32 of the result
mantissa_f4 = x_pos & MANTISSA_MASK_F4
mantissa_f32 = mantissa_f4.to(tl.int32) << (MBITS_F32 - MBITS_F4_E2M1)
mantissa_f4 = x_pos & mantissa_mask_f4
mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)
output = mantissa_f32

# combine the pieces
result = exp_biased_f32 | mantissa_f32
# result[zero_mask] = ZERO_BITS_F32
result = tl.where(zero_mask, ZERO_BITS_F32, result)
result = tl.where(zero_mask, zero_bits_f32, result)
# result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
result = tl.where(denormal_mask, ZERO_POINT_FIVE_BITS_F32, result)
result = tl.where(denormal_mask, zero_point_five_bits_f32, result)

# add sign back
sign_f32 = sign_f4.to(tl.int32) << (
MBITS_F32 - MBITS_F4_E2M1 + EBITS_F32 - EBITS_F4_E2M1
mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1
)
result = result | sign_f32

Expand All @@ -174,6 +186,16 @@ def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
sign_mask_f4: tl.constexpr,
mantissa_mask_f4: tl.constexpr,
mbits_f4_e2m1: tl.constexpr,
ebits_f4_e2m1: tl.constexpr,
f4_e2m1_exp_bias: tl.constexpr,
mbits_f32: tl.constexpr,
ebits_f32: tl.constexpr,
f32_exp_bias: tl.constexpr,
zero_bits_f32: tl.constexpr,
zero_point_five_bits_f32: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
Expand All @@ -187,7 +209,19 @@ def triton_f4_to_bf16_kernel(

# packed uint8
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
output = _fp4_packed_to_bf16(x_packed)
output = _fp4_packed_to_bf16(
x_packed,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
)

# set up output offsets
block_start_out = pid * BLOCK_SIZE_OUT
Expand All @@ -213,6 +247,18 @@ def triton_f4_to_scaled_bf16_kernel(
output_ptr,
n_elements_in,
mx_block_size: tl.constexpr,
sign_mask_f4: tl.constexpr,
mantissa_mask_f4: tl.constexpr,
mbits_f4_e2m1: tl.constexpr,
ebits_f4_e2m1: tl.constexpr,
f4_e2m1_exp_bias: tl.constexpr,
mbits_f32: tl.constexpr,
ebits_f32: tl.constexpr,
f32_exp_bias: tl.constexpr,
zero_bits_f32: tl.constexpr,
zero_point_five_bits_f32: tl.constexpr,
e8m0_exponent_bias: tl.constexpr,
e8m0_exponent_nan_val: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
Expand All @@ -227,7 +273,19 @@ def triton_f4_to_scaled_bf16_kernel(
mask_in = offsets_in < n_elements_in
# packed uint8
x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)
output = _fp4_packed_to_bf16(x_packed)
output = _fp4_packed_to_bf16(
x_packed,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
)

# load scale
block_start_s = pid * BLOCK_SIZE_S
Expand All @@ -236,9 +294,9 @@ def triton_f4_to_scaled_bf16_kernel(
s = tl.load(s_ptr + offsets_s, mask=mask_s)

# create the scale in bf16
s_offset = s.to(tl.int16) - E8M0_EXPONENT_BIAS
s_offset = s.to(tl.int16) - e8m0_exponent_bias
s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
s_fp = tl.where(s != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))

# multiply output by scale
# TODO(later): see if manipulating the exponent instead of fp
Expand All @@ -263,6 +321,16 @@ def triton_f4_to_bf16_kernel(
x_ptr,
output_ptr,
n_elements_in,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
BLOCK_SIZE_IN,
):
raise AssertionError("unsupported without triton")
Expand All @@ -273,6 +341,18 @@ def triton_f4_to_scaled_bf16_kernel(
output_ptr,
n_elements_in,
mx_block_size,
sign_mask_f4,
mantissa_mask_f4,
mbits_f4_e2m1,
ebits_f4_e2m1,
f4_e2m1_exp_bias,
mbits_f32,
ebits_f32,
f32_exp_bias,
zero_bits_f32,
zero_point_five_bits_f32,
e8m0_exponent_bias,
e8m0_exponent_nan_val,
BLOCK_SIZE_IN,
):
raise AssertionError("unsupported without triton")
Expand All @@ -294,7 +374,22 @@ def triton_f4_to_bf16(x: torch.Tensor):
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
) # noqa: E731,E501
triton_f4_to_bf16_kernel[grid](x, output, n_elements_in, BLOCK_SIZE_IN=512)
triton_f4_to_bf16_kernel[grid](
x,
output,
n_elements_in,
sign_mask_f4=SIGN_MASK_F4,
mantissa_mask_f4=MANTISSA_MASK_F4,
mbits_f4_e2m1=MBITS_F4_E2M1,
ebits_f4_e2m1=EBITS_F4_E2M1,
f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
mbits_f32=MBITS_F32,
ebits_f32=EBITS_F32,
f32_exp_bias=F32_EXP_BIAS,
zero_bits_f32=ZERO_BITS_F32,
zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
BLOCK_SIZE_IN=512,
)
return output


Expand All @@ -318,7 +413,23 @@ def triton_f4_to_scaled_bf16(
triton.cdiv(n_elements_in, meta["BLOCK_SIZE_IN"]),
)
triton_f4_to_scaled_bf16_kernel[grid](
x, s_e8m0, output, n_elements_in, mx_block_size
x,
s_e8m0,
output,
n_elements_in,
mx_block_size,
sign_mask_f4=SIGN_MASK_F4,
mantissa_mask_f4=MANTISSA_MASK_F4,
mbits_f4_e2m1=MBITS_F4_E2M1,
ebits_f4_e2m1=EBITS_F4_E2M1,
f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,
mbits_f32=MBITS_F32,
ebits_f32=EBITS_F32,
f32_exp_bias=F32_EXP_BIAS,
zero_bits_f32=ZERO_BITS_F32,
zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,
e8m0_exponent_bias=E8M0_EXPONENT_BIAS,
e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,
)
return output

Expand Down

0 comments on commit c2cf973

Please sign in to comment.