[DNM] cherry pick fp8 attn nonsense with hack cream #907

Draft · wants to merge 6 commits into main
1 change: 0 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -73,7 +73,6 @@ def main():
if "tensor_parallelism_size" in dataset.properties
else args.tensor_parallelism_size
)

llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=tensor_parallelism_size,
101 changes: 101 additions & 0 deletions sharktank/sharktank/kernels/attention.py
@@ -10,9 +10,110 @@

__all__ = [
"flash_attention",
"masked_flash_attention",
]


@CustomOp.register(library=LIBRARY)
class masked_flash_attention(CustomOp):

signature = "masked_flash_attention(Tensor q, Tensor k, Tensor v, Tensor? a, Tensor scale) -> (Tensor)"

def select(self, ksel: KernelSelection):
q_desc = ksel.arg_tensor(0) # Shape b, l, d
k_desc = ksel.arg_tensor(1) # Shape b, s, d
v_desc = ksel.arg_tensor(2) # Shape b, s, e
a_desc = ksel.arg_tensor(3) # Shape b, l, s
s_desc = ksel.arg_tensor(4)

q_bs = q_desc.t.shape[:-2]
k_bs = k_desc.t.shape[:-2]
v_bs = v_desc.t.shape[:-2]
a_bs = a_desc.t.shape[:-2]

bs = len(q_bs)

# Note: the kernel collapses the batch dims into a single batch/head dim
torch._check(len(q_bs) == 2, lambda: f"TODO: batch dims {bs} not supported")

q_l, q_d = q_desc.t.shape[-2:]
k_s, k_d = k_desc.t.shape[-2:]
v_s, v_e = v_desc.t.shape[-2:]

torch._check(
q_desc.t.dtype.is_floating_point
and k_desc.t.dtype.is_floating_point
and v_desc.t.dtype.is_floating_point
and s_desc.t.dtype.is_floating_point,
lambda: f"flash_attention: Expected floating point",
)

for q_b, k_b, v_b in zip(q_bs, k_bs, v_bs):
torch._check(
q_b == k_b and q_b == v_b,
lambda: f"expected matching batch dims: {q_b}, {k_b}, {v_b}",
)

torch._check(q_d == k_d, lambda: f"expected matching qk features: {q_d}, {k_d}")

torch._check(k_s == v_s, lambda: f"expected matching kv length: {k_s}, {v_s}")

q_desc.specialize_dims(0, 1, -1)
k_desc.specialize_dims(0, 1, -1)
v_desc.specialize_dims(0, 1, -1)

# Result 0: shape (batch..., l, e); always returned as f32
ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float32).specialize_dims(
0, 1, -1
)

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
q = kb.arg_value(0)
k = kb.arg_value(1)
v = kb.arg_value(2)
a = kb.arg_value(3)
scale = kb.arg_value(4)

q_tensor_type = RankedTensorType(q.type)
scale_tensor_type = RankedTensorType(scale.type)
v_tensor_type = RankedTensorType(v.type)

b1, b2, l, d = q_tensor_type.shape
_, _, s, e = v_tensor_type.shape

# Unspecialized dims will be negative
l = l if l >= 0 else "?"
s = s if s >= 0 else "?"
b = str(int(b1) * int(b2))
i_type_str = str(q_tensor_type.element_type)
scale_type_str = str(scale_tensor_type.element_type)
o_type_str = "f32"

target_function_name = f"sharktank_masked_flash_attention_{b1}_{b2}_{d}_{e}_{i_type_str}_{scale_type_str}_{o_type_str}"
kwargs = {
"b": b,
"b1": b1,
"b2": b2,
"l": l,
"d": d,
"s": s,
"e": e,
"i_dtype": i_type_str,
"scale_dtype": scale_type_str,
"o_dtype": o_type_str,
"func_name": target_function_name,
}
template_file = "masked_flash_attention.mlir"
target_function = inline_template_function(
kb,
template_file,
target_function_name,
**kwargs,
)
kb.yield_results(*call_function(target_function, q, k, v, scale, a))


@CustomOp.register(library=LIBRARY)
class flash_attention(CustomOp):

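A minimal usage sketch of the new op (shapes, dtypes, and values below are illustrative assumptions, not part of the PR; the mask is the 2-D additive mask the kernel expects and the scale is a 0-d tensor):

import torch
from sharktank import kernels

# Illustrative shapes: 1 batch, 8 heads, query length 16, kv length 32, head dim 64.
# Whether this runs eagerly depends on the local IREE/turbine build.
bs, heads, l, s, d = 1, 8, 16, 32, 64
q = torch.randn(bs, heads, l, d, dtype=torch.float16)
k = torch.randn(bs, heads, s, d, dtype=torch.float16)
v = torch.randn(bs, heads, s, d, dtype=torch.float16)
mask = torch.zeros(l, s, dtype=torch.float16)        # additive mask, already squeezed to 2-D
scale = torch.tensor(d ** -0.5, dtype=torch.float16)  # 1/sqrt(head_dim)

# Per the select() above, the result is float32 with shape (bs, heads, l, d).
out = kernels.masked_flash_attention(q, k, v, mask, scale)
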
9 changes: 4 additions & 5 deletions sharktank/sharktank/kernels/batch_matmul_transpose_b.py
@@ -8,7 +8,7 @@

import torch

from iree.compiler.ir import IntegerType
from iree.compiler.ir import IntegerType, FloatType

__all__ = [
"batch_matmul_transpose_b",
@@ -59,9 +59,7 @@ def select(self, ksel: KernelSelection):
lambda: f"batch_matmul_transpose_b: Batch dims must match ({lhs_desc.t.shape} vs {rhs_desc.t.shape})",
)
# Shape batch, m, n
c_desc = ksel.return_new_tensor(
[lhs_batch, lhs_m, rhs_n], dtype=lhs_desc.t.dtype
)
c_desc = ksel.return_new_tensor([lhs_batch, lhs_m, rhs_n], dtype=torch.float32)
specialize_all_known_dims(lhs_desc)
specialize_all_known_dims(rhs_desc)
specialize_all_known_dims(c_desc)
@@ -77,8 +75,9 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
result_desc = ksel.result_descs[0]

# Generate specialization signature and types.
a_asm_type, a_ident, accum_type = unpack_tensor_type(lhs.type)
a_asm_type, a_ident, _ = unpack_tensor_type(lhs.type)
b_asm_type, b_ident, _ = unpack_tensor_type(rhs.type)
accum_type = FloatType.parse("f32")
spec_sig = f"L{a_ident}_R{b_ident}"
template_file = "batch_matmul_transpose_b.mlir"
target_function_name = f"sharktank_batch_matmul_transpose_b_{spec_sig}"
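
For context, the effect of this change can be sketched in eager torch (a reference emulation under assumed shapes, not the generated MLIR): operands stay in their low-precision dtype while the result, like the kernel's accumulator, is float32.

import torch

def batch_matmul_transpose_b_ref(lhs, rhs):
    # lhs: (batch, m, k), rhs: (batch, n, k); result: (batch, m, n) accumulated in float32.
    return torch.matmul(lhs.to(torch.float32), rhs.to(torch.float32).transpose(-1, -2))

lhs = torch.randn(2, 8, 16, dtype=torch.float16)
rhs = torch.randn(2, 4, 16, dtype=torch.float16)
out = batch_matmul_transpose_b_ref(lhs, rhs)
assert out.dtype == torch.float32 and out.shape == (2, 8, 4)
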
62 changes: 62 additions & 0 deletions sharktank/sharktank/kernels/templates/masked_flash_attention.mlir
@@ -0,0 +1,62 @@
// Copyright 2024 Advanced Micro Devices, Inc.
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

!q_type = tensor<{{b1}}x{{b2}}x{{l}}x{{d}}x{{i_dtype}}>
!k_type = tensor<{{b1}}x{{b2}}x{{s}}x{{d}}x{{i_dtype}}>
!v_type = tensor<{{b1}}x{{b2}}x{{s}}x{{e}}x{{i_dtype}}>
!a_type = tensor<{{l}}x{{s}}x{{i_dtype}}>
!trans_v_type = tensor<{{b1}}x{{b2}}x{{e}}x{{s}}x{{i_dtype}}>
!o_type = tensor<{{b1}}x{{b2}}x{{l}}x{{e}}x{{o_dtype}}>
!o_dyn_type = tensor<?x?x?x{{o_dtype}}>
!o_collapsed_type = tensor<{{b}}x{{l}}x{{e}}x{{o_dtype}}>
!q_collapsed_type = tensor<{{b}}x{{l}}x{{d}}x{{i_dtype}}>
!k_collapsed_type = tensor<{{b}}x{{s}}x{{d}}x{{i_dtype}}>
!v_collapsed_type = tensor<{{b}}x{{s}}x{{e}}x{{i_dtype}}>
!a_collapsed_type = tensor<{{l}}x{{s}}x{{i_dtype}}>
!s_type = tensor<{{scale_dtype}}>

module {

util.func private @{{func_name}}(
%q: !q_type,
%k: !k_type,
%v: !v_type,
%s: !s_type,
%a: !a_type) -> !o_type {

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%b0 = arith.constant {{b}} : index


%l = tensor.dim %q, %c2 : !q_type
%e = tensor.dim %v, %c3 : !v_type

%scale = tensor.extract %s[] : !s_type
%empty_dyn = tensor.empty(%b0, %l, %e) : !o_dyn_type
%empty = tensor.cast %empty_dyn : !o_dyn_type to !o_collapsed_type

%collapsed_q = tensor.collapse_shape %q [[0, 1], [2], [3]] : !q_type into !q_collapsed_type
%collapsed_k = tensor.collapse_shape %k [[0, 1], [2], [3]] : !k_type into !k_collapsed_type
%collapsed_v = tensor.collapse_shape %v [[0, 1], [2], [3]] : !v_type into !v_collapsed_type

%atten = iree_linalg_ext.attention {indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>]}
ins(%collapsed_q, %collapsed_k, %collapsed_v, %scale, %a : !q_collapsed_type, !k_collapsed_type, !v_collapsed_type, {{scale_dtype}}, !a_collapsed_type) outs(%empty : !o_collapsed_type) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> !o_collapsed_type
%expanded_o = tensor.expand_shape %atten [[0,1], [2], [3]] output_shape [{{b1}}, {{b2}}, %l, {{e}}] : !o_collapsed_type into !o_type
util.return %expanded_o : !o_type
}
}
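
The shape bookkeeping in this template, restated as an eager torch reference (illustrative sizes; the real computation is the iree_linalg_ext.attention op above): the (b1, b2) batch/head dims are collapsed into one, the attention runs with a single 2-D additive mask and a scalar scale, and the f32 result is expanded back to four dims.

import torch

b1, b2, l, s, d, e = 2, 8, 16, 32, 64, 64
q = torch.randn(b1, b2, l, d)
k = torch.randn(b1, b2, s, d)
v = torch.randn(b1, b2, s, e)
mask = torch.zeros(l, s)    # shared across the collapsed batch*head dim
scale = d ** -0.5

q_c, k_c, v_c = (t.reshape(-1, *t.shape[2:]) for t in (q, k, v))            # tensor.collapse_shape
attn = torch.softmax(q_c @ k_c.transpose(-1, -2) * scale + mask, dim=-1) @ v_c
out = attn.reshape(b1, b2, l, e)                                            # tensor.expand_shape
assert out.shape == (b1, b2, l, e)
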
2 changes: 2 additions & 0 deletions sharktank/sharktank/layers/causal_llm.py
@@ -126,6 +126,8 @@ def attention_mask(

# Combine the causal context mask and input mask.
dtype = self.attention_dtype
print("attention dtype")
print(self.attention_dtype)
_, batch_seq_len = input_mask.shape
causal_mask = causal_context_mask[:, :, :batch_seq_len, :batch_seq_len]
boolean_mask = torch.logical_or(causal_mask, input_mask[:, None, None, :])
8 changes: 6 additions & 2 deletions sharktank/sharktank/layers/linear.py
@@ -59,6 +59,7 @@ def __init__(
if self.q_input is not None and self.qdq_input is not None:
raise AssertionError(f"LinearLayer cannot have both q_input and qdq_input")
self.qdq_output: Optional[QuantizedTensor] = theta.optional_tensor("qdq_output")
self.q_output: Optional[QuantizerTensor] = theta.optional_tensor("q_output")

def forward(self, x):
weight = self.weight
@@ -79,14 +80,17 @@ def forward(self, x):

y = ops.linear(x, weight, bias)
# Unconditionally dequantize.
if self.q_output is not None:
y = self.q_output.quantize(y)
return y.unpack().qs
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
# Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This may not be the right
# level to do this, but for now it's here.
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
if not isinstance(y, QuantizedTensor) and isinstance(x, QuantizedTensor):
if x.unpack().qs.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.bfloat16)
return y
if qdq_output is not None:
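
A sketch of the new q_output fast path using sharktank's quantizer types (the layer name and scale values are made up, and it assumes StaticScaledQuantizer exposes the same quantize()/unpack() surface the diff relies on): the linear result is re-quantized and the raw fp8 payload is returned so the consumer stays in fp8.

import torch
from sharktank.types import StaticScaledQuantizer

q_output = StaticScaledQuantizer(
    name="blk.0.attn_q.q_output",              # hypothetical layer name
    scale=torch.tensor(1.0 / 0.5),
    reciprocal_scale=torch.tensor(0.5),
    dtype=torch.float8_e4m3fnuz,
)
y = torch.randn(4, 16, dtype=torch.float32)    # stand-in for ops.linear(x, weight, bias)
y_fp8 = q_output.quantize(y).unpack().qs       # raw fp8 values handed back to the caller
assert y_fp8.dtype == torch.float8_e4m3fnuz
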
16 changes: 16 additions & 0 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -17,6 +17,7 @@
from .rotary_embedding import RotaryEmbeddingLayer
from .kv_cache import PagedKVCache
from .. import ops
from .. import kernels

__all__ = [
"PagedLlamaAttentionBlock",
@@ -74,6 +75,9 @@ def __init__(
self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor(
"kv_cache.quantizer"
)
self.attention_scale = None
if "attn_scale" in theta.keys:
self.attention_scale = theta("attn_scale").as_torch()

if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
@@ -197,6 +201,18 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
attn_output = ops.matmul(
attn_weights, values
) # (bs, heads, slen, head_dim)
elif self.attention_kernel == "sharktank":
assert self.attention_scale is not None
if attention_mask is not None:
attn_output = kernels.masked_flash_attention(
xq,
keys,
values,
attention_mask.squeeze(0).squeeze(0),
self.attention_scale,
)
else:
attn_output = kernels.flash_attention(xq, keys, values)
else:
attn_output = ops.scaled_dot_product_attention(
q=xq, # [bs, ..., sl, dim]
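
On the squeeze(0).squeeze(0) above: attention_mask is built as (bs, 1, sl, sl) in causal_llm.attention_mask, while the masked kernel takes a single 2-D additive mask, so the hack appears to assume a batch size of 1 at this point. An illustrative check:

import torch

attention_mask = torch.zeros(1, 1, 16, 16)      # (bs=1, 1, sl, sl)
mask_2d = attention_mask.squeeze(0).squeeze(0)  # -> (16, 16), what masked_flash_attention consumes
assert mask_2d.shape == (16, 16)
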
45 changes: 30 additions & 15 deletions sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -210,6 +210,7 @@ def quantize_weight(
weight_quant_zero_point,
)
# we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences
# scales are multiplied by two to account for the difference between fnuz and fn
if input_quant_scale is not None:
updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer(
name=new_layer_name + ".q_input",
@@ -218,10 +219,10 @@
dtype=torch.float8_e4m3fnuz,
)
if output_quant_scale is not None:
updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer(
name=new_layer_name + ".qdq_output",
scale=1.0 / output_quant_scale,
reciprocal_scale=output_quant_scale,
updated_tensors[new_layer_name + ".q_output"] = StaticScaledQuantizer(
name=new_layer_name + ".q_output",
scale=1.0 / (output_quant_scale * 2.0),
reciprocal_scale=output_quant_scale * 2.0,
dtype=torch.float8_e4m3fnuz,
)

@@ -258,15 +259,29 @@ def update_norm_layer(
sub_name = layer_name + "." + sub
new_name = hf_to_gguf(sub_name) + ".weight"
single_replace(quant_theta, sub_name, new_name, updated_tensors)
kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch()
layer_idx = layer_name.split(".")[-1]
new_name = f"blk.{layer_idx}.kv_cache"
updated_tensors[new_name] = StaticScaledQuantizer(
name=new_name + ".quantizer",
scale=1.0 / (kv_cache_scale * 2.0),
reciprocal_scale=kv_cache_scale * 2.0,
dtype=torch.float8_e4m3fnuz,
)
if "kv_cache_scale" in quant_theta(layer_name, "self_attn").keys:
kv_cache_scale = (
quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch()
)
new_name = f"blk.{layer_idx}.kv_cache"
updated_tensors[new_name] = StaticScaledQuantizer(
name=new_name + ".quantizer",
scale=1.0 / (kv_cache_scale * 2.0),
reciprocal_scale=kv_cache_scale * 2.0,
dtype=torch.float8_e4m3fnuz,
)
if "prob_output_scale" in quant_theta(layer_name, "self_attn").keys:
prob_output_scale = (
quant_theta(layer_name, "self_attn").tensor("prob_output_scale").as_torch()
* 2.0
)
new_name = f"blk.{layer_idx}.attn_scale"
updated_tensors[new_name] = DefaultPrimitiveTensor(
name=new_name, data=prob_output_scale
)
print("added attn_scale", new_name)
print(prob_output_scale)


def single_replace(
Expand Down Expand Up @@ -298,16 +313,16 @@ def main(argv):
type=str,
default="7b",
help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.",
choices=["7b", "70b"],
choices=["7b", "70b", "405b"],
)
args = cli.parse(parser, args=argv)

config_json_path: Path = args.config_json
params_path: Path = args.params
# TODO: find a way to get this programmatically so we don't have to flag for it
split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024]
num_layers = 32 if args.model_base == "7b" else 80

layers_per_base = {"7b": 32, "70b": 80, "405b": 126}
num_layers = layers_per_base[args.model_base]
# Construct the pre-transform dataset.
dataset_props = _get_dataset_props(_load_json(config_json_path))
with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
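
A worked check of the "scales are multiplied by two" adjustment (the scale value is hypothetical): Quark calibrates against the OCP float8_e4m3fn range (max 448), but the tensors are stored here as float8_e4m3fnuz (max 240), whose larger exponent bias makes same-bit-pattern values half as large, so doubling the divisor keeps the largest calibrated value in range.

import torch

fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0
fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0

calibrated_scale = 3.5                               # hypothetical per-tensor scale from the Quark checkpoint
x_max = calibrated_scale * fn_max                    # largest magnitude the calibration expected
assert x_max / (calibrated_scale * 2.0) <= fnuz_max  # 224.0 <= 240.0: fits after doubling
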