Support HQT on VLLM #59
Changes from all commits
4b9b955
f3ffc8c
8ffc3d0
f5f0972
c521c4d
64c8c7f
2e291c5
9d0fbb7
09e0078
24847a9
90c2527
f7c2157
608123b
@@ -14,8 +14,7 @@
 import vllm.hpu.utils as hpu_utils
 
-# FIXME: For some reason splitting value causes DFAs on G3. This needs to be debugged
-PA_SPLIT_VALUE_DEFAULT = '0' if (htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi3) else '1'
+PA_SPLIT_VALUE_DEFAULT = '1'
 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', PA_SPLIT_VALUE_DEFAULT) == '1')
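For context, the flag above is read once at import time and later decides whether the attention weights are split per cache block before the value matmul. A minimal, self-contained sketch of that gating pattern (the demo script around the two variables is illustrative, not part of the PR):

```python
import os

# Same pattern as the diff: default '1' (splitting enabled), overridable via env var.
PA_SPLIT_VALUE_DEFAULT = '1'
PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', PA_SPLIT_VALUE_DEFAULT) == '1')

if __name__ == '__main__':
    # Run as `PA_SPLIT_VALUE=0 python pa_split_demo.py` to disable value splitting.
    print(f"PA_SPLIT_VALUE enabled: {PA_SPLIT_VALUE}")
```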
@@ -38,8 +37,13 @@ def fetch_from_cache(cache, blocks, permutations):
     return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
 
 
-@hpu_utils.with_mark_steps
-def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None:
+def permute_cache(cache, permutations):
+    return [v.permute(permutations) for v in cache]
+
+
+def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None,
+                       qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache,
+                       keys_permute=permute_cache) -> None:
     seq_len = block_tables.size(1)
     batch_size, query_heads, _ = query.shape
     _, _, kv_heads, _ = key_cache.shape

Reviewer comment (on the removed decorator): why is @hpu_utils.with_mark_steps removed here?
Author reply: we want all conversions to/from hf8 to happen in the same graph.
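For readers unfamiliar with the decorator being dropped: with_mark_steps-style helpers typically force a graph break around the wrapped call. Below is a sketch of that pattern, assuming habana_frameworks.torch.core.mark_step(); the actual implementation in vllm.hpu.utils may differ. Removing it here lets the hf8 cast ops introduced by HQT stay in the same compiled graph as the attention math, per the reply above.

```python
import functools

import habana_frameworks.torch.core as htcore


def with_mark_steps(fn):
    """Sketch only: trigger HPU graph execution before and after the wrapped call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        htcore.mark_step()   # flush pending ops into their own graph
        result = fn(*args, **kwargs)
        htcore.mark_step()   # close the graph containing fn's ops
        return result
    return wrapper
```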
@@ -51,26 +55,27 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block
             .view(batch_size, 1, 1, -1))
     query.mul_(scale)
     query = query.unsqueeze(-2)
-    keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))
+    keys = keys_fetch_func(key_cache, block_tables, (0, 2, 1, 3))
     if query_heads != kv_heads:
         query = query.unflatten(1, (kv_heads, -1))
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
+        keys = keys_permute(keys, (0, 1, 2, 4, 3))
         mask = mask.unsqueeze(2)
+    else:
+        keys = keys_permute(keys, (0, 1, 3, 2))
+    attn_weights = [qk_matmul_op(query, k) for k in keys]
+    attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf),
+                              dim=-1)
-
-    attn_weights = [torch.matmul(query, k) for k in keys]
-    attn_weights = (torch.cat(attn_weights, dim=-1)
-                    .masked_fill(mask, min_inf)
-                    .softmax(dim=-1))
 
-    values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3))
+    values = values_fetch_func(value_cache, block_tables, (0, 2, 1, 3))
     if PA_SPLIT_VALUE:
         attn_weights = attn_weights.split(block_size, dim=-1)
     else:
         values = [torch.cat(values, dim=-2)]
         attn_weights = [attn_weights]
     if query_heads != kv_heads:
         values = [v.unflatten(1, (kv_heads, 1)) for v in values]
-    attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)]
+    attn_weights = [kv_matmul_op(a, v) for a, v in zip(attn_weights, values)]
     if query_heads != kv_heads:
         attn_weights = [a.flatten(1, 2) for a in attn_weights]
     attn_weights = sum(attn_weights)
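To make the key/value gathering above concrete, here is a standalone shape walk-through of fetch_from_cache and permute_cache with toy sizes (the sizes are arbitrary choices for illustration; the helper bodies are copied from the diff):

```python
import torch


def fetch_from_cache(cache, blocks, permutations):
    # One gathered-and-permuted tensor per block position, as in the diff.
    return [cache.index_select(0, blocks[:, i]).permute(permutations)
            for i in range(blocks.size(1))]


def permute_cache(cache, permutations):
    # Extra permutation applied to every fetched block, as in the diff.
    return [v.permute(permutations) for v in cache]


# Toy sizes, chosen only for illustration.
num_blocks, block_size, kv_heads, head_dim = 16, 4, 2, 8
batch_size, blocks_per_seq = 3, 5

key_cache = torch.randn(num_blocks, block_size, kv_heads, head_dim)
block_tables = torch.randint(0, num_blocks, (batch_size, blocks_per_seq))

keys = fetch_from_cache(key_cache, block_tables, (0, 2, 1, 3))
print(keys[0].shape)  # (batch, kv_heads, block_size, head_dim) -> torch.Size([3, 2, 4, 8])

# The MHA branch then transposes the last two dims so query @ key works directly.
keys = permute_cache(keys, (0, 1, 3, 2))
print(keys[0].shape)  # (batch, kv_heads, head_dim, block_size) -> torch.Size([3, 2, 8, 4])
```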
Reviewer comment: @madamczykhabana can you check if this is fine? Aren't we making HabanaPagedAttention.write_to_paged_cache useless here?
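The practical effect of the new keyword arguments is that a quantization backend such as HQT can inject its own measured/quantized kernels without patching this file. A hypothetical sketch of such a wrapper follows; the ScaledMatmul class and the commented call are assumptions for illustration, not HQT's actual API.

```python
import torch


class ScaledMatmul:
    """Hypothetical stand-in for a quantized matmul op; a real HQT op would
    insert hf8 casts and scale handling here."""

    def __init__(self, scale_a: float = 1.0, scale_b: float = 1.0):
        self.scale_a = scale_a
        self.scale_b = scale_b

    def __call__(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # This sketch only rescales the operands before a plain matmul.
        return torch.matmul(a * self.scale_a, b * self.scale_b)


# Illustrative call site (argument values elided); only the keyword names
# come from the new signature in this PR:
# paged_attention_v1(query, key_cache, value_cache, head_mapping, scale,
#                    block_tables, context_lens, block_size, alibi_slopes,
#                    qk_matmul_op=ScaledMatmul(), kv_matmul_op=ScaledMatmul(),
#                    softmax_op=torch.softmax,
#                    keys_fetch_func=fetch_from_cache,
#                    values_fetch_func=fetch_from_cache)
```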