fix module level API docstring (#2869)
* correct module level api docstring

* flake8 format correction

* fix broken links
ZailiWang authored May 11, 2024
1 parent 71d6e31 commit d3c5244
Showing 4 changed files with 414 additions and 237 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/llm.rst
@@ -30,14 +30,14 @@ Verified for distributed inference mode via DeepSpeed

*Note*: The above verified models (including other models in the same model family, like "codellama/CodeLlama-7b-hf" from LLAMA family) are well supported with all optimizations like indirect access KV cache, fused ROPE, and prepacked TPP Linear (fp32/bf16). Work is in progress to better support the models in the tables with various data types. In addition, more models will be optimized in the future.

Please check `LLM best known practice <../../examples/cpu/inference/python/llm>`_ for instructions to install/setup environment and example scripts.
Please check `LLM best known practice <https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0%2Bcpu/examples/cpu/inference/python/llm>`_ for instructions to install/setup environment and example scripts.

Module Level Optimization API for customized LLM (Prototype)
------------------------------------------------------------

In the past year, LLMs have been flourishing with many open-source models contributed to the community, while researchers are building their own LLMs from transformer blocks with variants in implementation details. To help LLM researchers and developers improve their productivity, Intel® Extension for PyTorch* provides module level optimizations for commonly used LLM modules and functionalities, which are operators or certain operator combinations in nature.

Please check `LLM module level optimization practice <../../examples/cpu/inference/python/llm-modeling>`_ to better understand how to use `module level APIs <api_doc.html#llm-module-level-optimizations>`_ to optimize your LLM and achieve better performance.
Please check `LLM module level optimization practice <https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0%2Bcpu/examples/cpu/inference/python/llm-modeling>`_ to better understand how to use `module level APIs <api_doc.html#llm-module-level-optimizations-prototype>`_ to optimize your LLM and achieve better performance.
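As an illustration, a hand-written norm layer in a custom LLM can call the functional API directly. The sketch below is a hypothetical wrapper (the class name, defaults, and wrapping pattern are assumptions, not code from this repository); `ipex.llm.functional.rms_norm` is the module level API documented later in this commit.

.. code-block:: python

    import torch
    import intel_extension_for_pytorch as ipex

    class CustomRMSNorm(torch.nn.Module):
        # Hypothetical norm layer for a custom decoder block.
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            # Swap the hand-written RMSNorm math for the fused functional API.
            return ipex.llm.functional.rms_norm(
                hidden_states, self.weight, self.variance_epsilon
            )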

Demos
-----
154 changes: 94 additions & 60 deletions intel_extension_for_pytorch/llm/functional/fusions.py
@@ -20,25 +20,31 @@ def rotary_embedding(
):
r"""
Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
on the `query ` or `key` before their multi-head attention computation.
on the `query` or `key` before their multi-head attention computation.
Args:
- query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
[batch size, sequence length, num_head/num_kv_head, head_dim]
or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
- sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
- rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
- head_dim (int) : head dim from the input shape.
- rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
so the offset is 1.
if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
so the offset is rotary_dim/2.
- position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
for the input. The shape should be [batch size, sequence length].
query, key (torch.Tensor) : inputs to be applied with position embeddings,
taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
generated to be applied on query/key.
rotary_ndims (int): the rotary dimension, e.g., 64 for GPT-J; equal to the head size for Llama.
head_dim (int): head dim from the input shape.
rotary_half (bool): if False (e.g., GPT-J 6B/ChatGLM), cos/sin is applied to the neighboring 2 elements,
so the offset is 1;
if True (e.g., Llama), cos/sin is applied to the neighboring rotary_dim elements,
so the offset is rotary_dim/2.
position_ids (torch.Tensor): Default is None and optional if sin/cos is provided.
The corresponding position_ids for the input. The shape should be [batch size, sequence length].
Return
- query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
or [num_tokens, num_head/num_kv_head, head_dim].
query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
or [num_tokens, num_head/num_kv_head, head_dim].
"""

return RotaryEmbedding.apply_function(
query, key, sin, cos, rotary_dim, rotary_half, position_ids
)
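A minimal usage sketch of the call above, assuming the positional order shown in `apply_function` (query, key, sin, cos, rotary_dim, rotary_half, position_ids) and a GPT-J style interleaved sin/cos layout; the exact sin/cos layout must match the model's own rotary implementation.

.. code-block:: python

    import torch
    import intel_extension_for_pytorch as ipex

    batch, seq_len, num_head, head_dim = 1, 32, 16, 256
    rotary_dim = 64  # partial rotary, e.g., GPT-J

    query = torch.randn(batch, seq_len, num_head, head_dim)
    key = torch.randn(batch, seq_len, num_head, head_dim)

    # sin/cos of shape [num_tokens, rotary_dim]; the interleaved layout is an assumption.
    pos = torch.arange(seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(pos, inv_freq)                   # [seq_len, rotary_dim / 2]
    sin = torch.sin(freqs).repeat_interleave(2, dim=-1)  # [seq_len, rotary_dim]
    cos = torch.cos(freqs).repeat_interleave(2, dim=-1)

    # rotary_half=False: cos/sin applied to the neighboring 2 elements (GPT-J style).
    query, key = ipex.llm.functional.rotary_embedding(
        query, key, sin, cos, rotary_dim, False, None
    )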
@@ -48,12 +54,14 @@ def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
r"""
Applies RMSnorm on the input (hidden states).
(see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
Args:
- hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
- weight (torch.Tensor): the weight to apply RMSnorm.
- eps (float) : the variance_epsilon to apply RMSnorm.
hidden_states (torch.Tensor): the input tensor to apply RMSNorm.
weight (torch.Tensor): the weight to apply RMSNorm.
eps (float): the variance_epsilon to apply RMSNorm.
"""

return RMSNorm.apply_function(hidden_states, weight, eps)
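A minimal usage sketch (shapes are illustrative), with a reference computation matching the linked LlamaRMSNorm for comparison.

.. code-block:: python

    import torch
    import intel_extension_for_pytorch as ipex

    hidden_states = torch.randn(1, 32, 4096)  # [batch, seq, hidden]
    weight = torch.ones(4096)                 # per-channel RMSNorm weight
    eps = 1e-6

    out = ipex.llm.functional.rms_norm(hidden_states, weight, eps)

    # Reference computation (as in the linked modeling_llama.py):
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    ref = weight * hidden_states * torch.rsqrt(variance + eps)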


@@ -67,12 +75,14 @@ def fast_layer_norm(
r"""
Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
on the input (hidden states).
Args:
- hidden_states(torch.Tensor) : the input tensor to apply normalization.
- normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
- weight (torch.Tensor): the weight to apply normalization.
- bias (torch.Tensor): an additive bias for normalization.
- eps (float): a value added to the denominator for numerical stability.
hidden_states (torch.Tensor): the input tensor to apply normalization.
normalized_shape (int or list or torch.Size): input shape from an
expected input of size.
weight (torch.Tensor): the weight to apply normalization.
bias (torch.Tensor): an additive bias for normalization.
eps (float): a value added to the denominator for numerical stability.
"""

@@ -103,33 +113,49 @@ def indirect_access_kv_cache_attention(
buffers (key and value use different buffers) to store all key/value hidden states and beam index information.
It can use beam index history to decide which beam should be used at a timestamp, and this information will
generate an offset to access the kv_cache buffer.
Data Format:
- The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
forward
- query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
- scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
- layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
beam-idx: history beam idx, shape:(max_seq, beam*batch);
seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
- head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
- attention_mask(torch.Tensor): Attention mask information.
- text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
The shape of the pre-allocated key (value) buffer is [max_seq, beam*batch, head_num, head_size];
the hidden state of key/value, which has the shape of [beam*batch, head_num, head_size], is stored token by token.
All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
Args:
query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
scale_attn (float): scale used by the attention layer; should be sqrt(head_size).
layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
- key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
- value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
- beam-idx: history beam idx, shape:(max_seq, beam*batch);
- seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
head_mask (torch.Tensor): Head mask tensor, which is not yet supported by the kernel.
attention_mask (torch.Tensor): Attention mask information.
text_max_length (int): the max length of kv cache to be used for generation
(allocate the pre-cache buffer).
Return:
- attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
- attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
- new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
attn_output: weighted value which is the output of scaled dot product;
shape (beam*batch, seq_len, head_num, head_size).
attn_weights: the output tensor of the first matmul in scaled dot product,
which is not supported by the kernel now.
new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
Notes:
- How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
.. highlight:: python
.. code-block:: python
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
@@ -141,6 +167,7 @@ def _reorder_cache(
return past_key_values
"""

return IndirectAccessKVCacheAttention.apply_function(
query,
key,
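A shape-only sketch of the data format described in the docstring above; it allocates the pre-cache buffers but does not call the fused kernel, whose full argument list is not shown in this excerpt.

.. code-block:: python

    import torch

    max_seq, beam, batch, head_num, head_size = 128, 4, 1, 32, 128

    # Pre-allocated key/value caches: [max_seq, beam*batch, head_num, head_size]
    key_cache = torch.zeros(max_seq, beam * batch, head_num, head_size)
    value_cache = torch.zeros(max_seq, beam * batch, head_num, head_size)

    # Beam index history for every timestamp: [max_seq, beam*batch]
    beam_idx = torch.zeros(max_seq, beam * batch, dtype=torch.long)

    # Sequence info tensor: [1, 1, max_seq, max_seq]
    seq_info = torch.zeros(1, 1, max_seq, max_seq)

    # layer_past tuple in the order documented above
    layer_past = (seq_info, key_cache, value_cache, beam_idx)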
@@ -174,23 +201,30 @@ def varlen_attention(
):
r"""
Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
(see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
and accept the variant (different) sequence length among the query, key and value.
(see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
and accepts variant (different) sequence lengths among the query, key and value.
This module does not have args for `module init`.
`forward()`
Args:
module init: this module does not have args for module init
forward:
- query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
- out (torch.Tensor): buffer to get the results, the shape is the same as query.
- seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
- seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
- max_seqlen_q (int): max/total sequence length of query.
- max_seqlen_k (int): max/total sequence length of key.
- pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
- softmax_scale (float): scaling factor applied is prior to softmax.
- is_causal (bool): whether to apply causal attention masking, default is True.
query (torch.Tensor): shape [query_tokens, num_head, head_size],
where tokens is total sequence length among batch size.
key (torch.Tensor): shape [key_tokens, num_head, head_size],
where tokens is total sequence length among batch size.
value (torch.Tensor): shape [value_tokens, num_head, head_size],
where tokens is total sequence length among batch size.
out (torch.Tensor): buffer to get the results, the shape is the same as query.
seqlen_q (torch.Tensor): shape [batch_size + 1],
cumulative offsets pointing to the current query_tokens within the total sequence length.
seqlen_k (torch.Tensor): shape [batch_size + 1],
cumulative offsets pointing to the current key_tokens within the total sequence length.
max_seqlen_q (int): max/total sequence length of query.
max_seqlen_k (int): max/total sequence length of key.
pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
softmax_scale (float): scaling factor applied prior to softmax.
is_causal (bool): whether to apply causal attention masking, default is True.
"""
return VarlenAttention.apply_function(
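To illustrate the packed (variable-length) layout the varlen_attention arguments above describe, the sketch below builds the cumulative sequence-length tensors from two sequences of different lengths; the call itself is left as a comment because the full signature is not shown in this excerpt.

.. code-block:: python

    import torch

    # Two sequences of different lengths, packed along one token dimension.
    lens = torch.tensor([5, 9], dtype=torch.int32)
    num_head, head_size = 16, 64
    total_tokens = int(lens.sum())

    query = torch.randn(total_tokens, num_head, head_size)
    key = torch.randn(total_tokens, num_head, head_size)
    value = torch.randn(total_tokens, num_head, head_size)
    out = torch.empty_like(query)

    # seqlen_q / seqlen_k: [batch_size + 1] cumulative offsets into the token dimension.
    seqlen_q = torch.cat([torch.zeros(1, dtype=torch.int32), torch.cumsum(lens, 0).to(torch.int32)])
    seqlen_k = seqlen_q.clone()
    max_seqlen_q = int(lens.max())
    max_seqlen_k = int(lens.max())

    # These tensors, together with pdropout, softmax_scale and is_causal, are what
    # ipex.llm.functional.varlen_attention expects; check the installed version for
    # the exact (possibly longer) argument list before calling it.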