diff --git a/docs/tutorials/llm.rst b/docs/tutorials/llm.rst
index 4cb02e6a0..e9690b677 100644
--- a/docs/tutorials/llm.rst
+++ b/docs/tutorials/llm.rst
@@ -30,14 +30,14 @@ Verified for distributed inference mode via DeepSpeed
 
 *Note*: The above verified models (including other models in the same model family, like "codellama/CodeLlama-7b-hf" from LLAMA family) are well supported with all optimizations like indirect access KV cache, fused ROPE, and prepacked TPP Linear (fp32/bf16). We are working in progress to better support the models in the tables with various data types. In addition, more models will be optimized in the future.
 
-Please check `LLM best known practice <../../examples/cpu/inference/python/llm>`_ for instructions to install/setup environment and example scripts.
+Please check `LLM best known practice `_ for instructions to install/setup environment and example scripts.
 
 Module Level Optimization API for customized LLM (Prototype)
 ------------------------------------------------------------
 
 In the past year, LLM has been flourishing with many open-sourced models contributed to the community, while researchers are building their own LLMs from transformer blocks with variants in implementation details. To help LLM researchers and developers improve their productivity, Intel® Extension for PyTorch* provides module level optimizations for commonly used LLM modules and functionalities, which are operators or certain operator combinations in nature.
 
-Please check `LLM module level optimization practice <../../examples/cpu/inference/python/llm-modeling>`_ to better understand how to use `module level APIs `_ to optimize your LLM and achieve better performance.
+Please check `LLM module level optimization practice `_ to better understand how to use `module level APIs `_ to optimize your LLM and achieve better performance.
 
 Demos
 -----
diff --git a/intel_extension_for_pytorch/llm/functional/fusions.py b/intel_extension_for_pytorch/llm/functional/fusions.py
index 7251bc525..d58a670f6 100644
--- a/intel_extension_for_pytorch/llm/functional/fusions.py
+++ b/intel_extension_for_pytorch/llm/functional/fusions.py
@@ -20,25 +20,31 @@ def rotary_embedding(
 ):
     r"""
     Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
-    on the `query ` or `key` before their multi-head attention computation.
+    on the ``query`` or ``key`` before their multi-head attention computation.
+
     Args:
-    - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
-      [batch size, sequence length, num_head/num_kv_head, head_dim]
-      or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
-    - sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
-    - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
-    - head_dim (int) : head dim from the input shape.
-    - rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
-      so the offset is 1.
-      if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
-      so the offset is rotary_dim/2.
-    - position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
-      for the input. The shape should be [batch size, sequence length].
+        query, key (torch.Tensor): inputs to be applied with position embeddings,
+            taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
+            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor
+            generated to be applied on query/key.
+        rotary_ndims (int): the rotary dimension, e.g., 64 for GPTJ, head size for LLaMA.
+        head_dim (int): head dim from the input shape.
+        rotary_half (bool): if False (e.g., GPT-J 6B/ChatGLM), cos/sin is applied to the neighboring 2 elements,
+            so the offset is 1.
+
+            if True (e.g., for llama), cos/sin is applied to the neighboring rotary_dim elements,
+            so the offset is rotary_dim/2.
+
+        position_ids (torch.Tensor): Optional (default None) if sin/cos is provided;
+            the corresponding position_ids for the input. The shape should be [batch size, sequence length].
+
     Return
-    - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
-      or [num_tokens, num_head/num_kv_head, head_dim].
+        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        or [num_tokens, num_head/num_kv_head, head_dim].
     """
+
     return RotaryEmbedding.apply_function(
         query, key, sin, cos, rotary_dim, rotary_half, position_ids
     )
 
@@ -48,12 +54,14 @@ def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float):
     r"""
     Applies RMSnorm on the input (hidden states).
     (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
+
     Args:
-    - hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
-    - weight (torch.Tensor): the weight to apply RMSnorm.
-    - eps (float) : the variance_epsilon to apply RMSnorm.
+        hidden_states (torch.Tensor): the input tensor to apply RMSNorm.
+        weight (torch.Tensor): the weight to apply RMSNorm.
+        eps (float): the variance_epsilon to apply RMSNorm.
 
     """
+
     return RMSNorm.apply_function(hidden_states, weight, eps)
 
@@ -67,12 +75,14 @@ def fast_layer_norm(
     r"""
     Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
     on the input (hidden states).
+
     Args:
-    - hidden_states(torch.Tensor) : the input tensor to apply normalization.
-    - normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
-    - weight (torch.Tensor): the weight to apply normalization.
-    - bias (torch.Tensor): an additive bias for normalization.
-    - eps (float): a value added to the denominator for numerical stability.
+        hidden_states (torch.Tensor): the input tensor to apply normalization.
+        normalized_shape ((int or list) or torch.Size): input shape from an
+            expected input of size.
+        weight (torch.Tensor): the weight to apply normalization.
+        bias (torch.Tensor): an additive bias for normalization.
+        eps (float): a value added to the denominator for numerical stability.
 
     """
 
@@ -103,33 +113,49 @@ def indirect_access_kv_cache_attention(
     buffers(key and value use different buffers) to store all key/value hidden states and beam index information.
     It can use beam index history to decide which beam should be used by a timestamp and this information will
     generate an offset to access the kv_cache buffer.
+
     Data Format:
-    - The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
-      the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
-      All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
-    - forward
-    - query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-    - scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
-    - layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
-      key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-      value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-      beam-idx: history beam idx, shape:(max_seq, beam*batch);
-      seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
-    - head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
-    - attention_mask(torch.Tensor): Attention mask information.
-    - text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
+
+    The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size];
+    the hidden state of key/value, which has the shape of [beam*batch, head_num, head_size], is stored token by token.
+    All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
+
+    Args:
+        query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        scale_attn (float): scale used by the attention layer. Should be sqrt(head_size).
+        layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
+
+            - key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - beam-idx: history beam idx, shape:(max_seq, beam*batch);
+
+            - seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
+
+        head_mask (torch.Tensor): Head mask tensor, which is not supported by the kernel yet.
+        attention_mask (torch.Tensor): Attention mask information.
+        text_max_length (int): the max length of kv cache to be used for generation
+            (allocate the pre-cache buffer).
 
     Return:
-    - attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
-    - attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
-    - new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
+        attn_output: weighted value, the output of the scaled dot product;
+            shape: (beam*batch, seq_len, head_num, head_size).
+
+        attn_weights: the output tensor of the first matmul in the scaled dot product,
+            which is not supported by the kernel for now.
+
+        new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
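+
+    The following is a minimal sketch of the pre-allocated buffers described in
+    the Data Format section above (the sizes here are illustrative only):
+
+    .. highlight:: python
+    .. code-block:: python
+
+        # hypothetical sizes, chosen only to illustrate the layout
+        max_seq, beam, batch, head_num, head_size = 2048, 4, 1, 32, 128
+        key_cache = torch.zeros(max_seq, beam * batch, head_num, head_size)
+        value_cache = torch.zeros(max_seq, beam * batch, head_num, head_size)
+        # beam index history for every timestamp
+        beam_idx = torch.zeros(max_seq, beam * batch, dtype=torch.long)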
     Notes:
-    - How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
-      see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+        How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on the llama model;
+        see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318):
+
+    .. highlight:: python
+    .. code-block:: python
+
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
@@ -141,6 +167,7 @@ def _reorder_cache(
         return past_key_values
 
     """
+
     return IndirectAccessKVCacheAttention.apply_function(
         query,
         key,
@@ -174,23 +201,30 @@ def varlen_attention(
 ):
     r"""
     Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
-    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
-    and accept the variant (different) sequence length among the query, key and value.
+    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
+    and accepts variant (different) sequence lengths among the query, key and value.
+
+    This module does not have args for `module init`.
+
+    `forward()`
+
     Args:
-        module init: this module does not have args for module init
-        forward:
-        - query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - out (torch.Tensor): buffer to get the results, the shape is the same as query.
-        - seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
-        - seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
-        - max_seqlen_q (int): max/total sequence length of query.
-        - max_seqlen_k (int): max/total sequence length of key.
-        - pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
-        - softmax_scale (float): scaling factor applied is prior to softmax.
-        - is_causal (bool): whether to apply causal attention masking, default is True.
+        query (torch.Tensor): shape [query_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        key (torch.Tensor): shape [key_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        value (torch.Tensor): shape [value_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        out (torch.Tensor): buffer to get the results; the shape is the same as query.
+        seqlen_q (torch.Tensor): shape [batch_size + 1],
+            points to the current query_tokens within the total sequence length.
+        seqlen_k (torch.Tensor): shape [batch_size + 1],
+            points to the current key_tokens within the total sequence length.
+        max_seqlen_q (int): max/total sequence length of query.
+        max_seqlen_k (int): max/total sequence length of key.
+        pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
+        softmax_scale (float): scaling factor applied prior to softmax.
+        is_causal (bool): whether to apply causal attention masking, default is True.
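+
+    Examples:
+        >>> # a minimal sketch mirroring the VarlenAttention module example;
+        >>> # tensor sizes are illustrative, and the positional argument order
+        >>> # follows the documented list above
+        >>> query = torch.randn(32, 16, 256)
+        >>> key = torch.randn(32, 16, 256)
+        >>> value = torch.randn(32, 16, 256)
+        >>> out = torch.empty_like(query)
+        >>> seqlen_q = torch.tensor([0, 32], dtype=torch.int32)
+        >>> seqlen_k = torch.tensor([0, 32], dtype=torch.int32)
+        >>> max_seqlen_q = 32
+        >>> max_seqlen_k = 32
+        >>> pdropout = 0.0
+        >>> softmax_scale = 0.5
+        >>> ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
+        ...     seqlen_k, max_seqlen_q, max_seqlen_k, pdropout, softmax_scale)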
""" return VarlenAttention.apply_function( diff --git a/intel_extension_for_pytorch/llm/modules/linear_fusion.py b/intel_extension_for_pytorch/llm/modules/linear_fusion.py index 380cf8de4..26e4c99d3 100644 --- a/intel_extension_for_pytorch/llm/modules/linear_fusion.py +++ b/intel_extension_for_pytorch/llm/modules/linear_fusion.py @@ -53,12 +53,21 @@ def init_on_device(self, x, op_type): class LinearSilu(IPEXLinearFusion): r""" Applies a linear transformation to the `input` data, and then apply PyTorch SILU - (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) on the result: + (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) + on the result: + + .. highlight:: python + .. code-block:: python + result = torch.nn.functional.silu(linear(input)) + Args: - linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with silu. + linear (torch.nn.Linear module) : the original torch.nn.Linear + module to be fused with silu. + Shape: Input and output shapes are the same as torch.nn.Linear. + Examples: >>> # module init: >>> linear_module = torch.nn.Linear(4096, 4096) @@ -66,6 +75,7 @@ class LinearSilu(IPEXLinearFusion): >>> # module forward: >>> input = torch.randn(4096, 4096) >>> result = ipex_fusion(input) + """ def __init__(self, linear): @@ -80,15 +90,25 @@ def forward(self, x): class Linear2SiluMul(IPEXLinear2Fusion): r""" - Applies two linear transformation to the `input` data (`linear_s` and `linear_m`), then apply PyTorch SILU - (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) on the result from `linear_s` - , and multiplies the result from `linear_m`: + Applies two linear transformation to the `input` data (`linear_s` and + `linear_m`), then apply PyTorch SILU + (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) + on the result from `linear_s`, and multiplies the result from `linear_m`: + + .. highlight:: python + .. code-block:: python + result = torch.nn.functional.silu(linear_s(input)) * linear_m(input) + Args: - linear_s (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with silu. - linear_m (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with mul. + linear_s (torch.nn.Linear module) : the original torch.nn.Linear + module to be fused with silu. + linear_m (torch.nn.Linear module) : the original torch.nn.Linear + module to be fused with mul. + Shape: Input and output shapes are the same as torch.nn.Linear. + Examples: >>> # module init: >>> linear_s_module = torch.nn.Linear(4096, 4096) @@ -97,6 +117,7 @@ class Linear2SiluMul(IPEXLinear2Fusion): >>> # module forward: >>> input = torch.randn(4096, 4096) >>> result = ipex_fusion(input) + """ def __init__(self, linear_s, linear_m): @@ -112,12 +133,21 @@ def forward(self, x): class LinearRelu(IPEXLinearFusion): r""" Applies a linear transformation to the `input` data, and then apply PyTorch RELU - (see https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html) on the result: + (see https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html) + on the result: + + .. highlight:: python + .. code-block:: python + result = torch.nn.functional.relu(linear(input)) + Args: - linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with relu. + linear (torch.nn.Linear module) : the original torch.nn.Linear module + to be fused with relu. + Shape: Input and output shapes are the same as torch.nn.Linear. 
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -125,6 +155,7 @@
         >>> # module forward:
         >>> input = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input)
+
     """
 
     def __init__(self, linear):
@@ -142,11 +173,19 @@ class LinearNewGelu(IPEXLinearFusion):
     Applies a linear transformation to the `input` data, and then apply NewGELUActivation
     (see https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L50)
     on the result:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = NewGELUActivation(linear(input))
+
     Args:
-        linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with new_gelu.
+        linear (torch.nn.Linear module): the original torch.nn.Linear module
+            to be fused with new_gelu.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -154,6 +193,7 @@
         >>> # module forward:
         >>> input = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input)
+
     """
 
     def __init__(self, linear):
@@ -169,12 +209,21 @@ def forward(self, x):
 class LinearGelu(IPEXLinearFusion):
     r"""
     Applies a linear transformation to the `input` data, and then apply PyTorch GELU
-    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html) on the result:
+    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html)
+    on the result:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = torch.nn.functional.gelu(linear(input))
+
     Args:
-        linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with gelu.
+        linear (torch.nn.Linear module): the original torch.nn.Linear
+            module to be fused with gelu.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -182,6 +231,7 @@
         >>> # module forward:
         >>> input = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input)
+
     """
 
     def __init__(self, linear):
@@ -199,12 +249,19 @@ class LinearSiluMul(IPEXLinearFusion):
     Applies a linear transformation to the `input` data, then apply PyTorch SILU
     (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) on the result,
     and multiplies the result by `other`:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = torch.nn.functional.silu(linear(input)) * other
+
     Args:
         linear (torch.nn.Linear module) : the original torch.nn.Linear module to
-        be fused with silu and mul.
+            be fused with silu and mul.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -213,6 +270,7 @@
         >>> input = torch.randn(4096, 4096)
         >>> other = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input, other)
+
     """
 
     def __init__(self, linear):
@@ -227,12 +285,21 @@ def forward(self, x, y):
 class LinearMul(IPEXLinearFusion):
     r"""
-    Applies a linear transformation to the `input` data, and then multiplies the result by `other`:
+    Applies a linear transformation to the `input` data, and then multiplies
+    the result by `other`:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = linear(input) * other
+
     Args:
-        linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with mul.
+        linear (torch.nn.Linear module): the original torch.nn.Linear module
+            to be fused with mul.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -241,6 +308,7 @@
         >>> input = torch.randn(4096, 4096)
         >>> other = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input, other)
+
     """
 
     def __init__(self, linear):
@@ -255,12 +323,21 @@ def forward(self, x, y):
 class LinearAdd(IPEXLinearFusion):
     r"""
-    Applies a linear transformation to the `input` data, and then add the result by `other`:
+    Applies a linear transformation to the `input` data,
+    and then adds `other` to the result:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = linear(input) + other
+
     Args:
-        linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with add.
+        linear (torch.nn.Linear module): the original torch.nn.Linear
+            module to be fused with add.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -269,6 +346,7 @@
         >>> input = torch.randn(4096, 4096)
        >>> other = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input, other)
+
     """
 
     def __init__(self, linear):
@@ -283,12 +361,21 @@ def forward(self, x, y):
 class LinearAddAdd(IPEXLinearFusion):
     r"""
-    Applies a linear transformation to the `input` data, and then add the result by `other_1` and `other_2`:
+    Applies a linear transformation to the `input` data,
+    and then adds `other_1` and `other_2` to the result:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        result = linear(input) + other_1 + other_2
+
     Args:
-        linear (torch.nn.Linear module) : the original torch.nn.Linear module to be fused with add and add.
+        linear (torch.nn.Linear module): the original torch.nn.Linear
+            module to be fused with add and add.
+
     Shape:
         Input and output shapes are the same as torch.nn.Linear.
+
     Examples:
         >>> # module init:
         >>> linear_module = torch.nn.Linear(4096, 4096)
@@ -298,6 +385,7 @@
         >>> other_1 = torch.randn(4096, 4096)
         >>> other_2 = torch.randn(4096, 4096)
         >>> result = ipex_fusion(input, other_1, other_2)
+
     """
 
     def __init__(self, linear):
diff --git a/intel_extension_for_pytorch/llm/modules/mha_fusion.py b/intel_extension_for_pytorch/llm/modules/mha_fusion.py
index fc9d3e62b..99a66d6da 100644
--- a/intel_extension_for_pytorch/llm/modules/mha_fusion.py
+++ b/intel_extension_for_pytorch/llm/modules/mha_fusion.py
@@ -7,30 +7,34 @@ class RotaryEmbedding(nn.Module):
     r"""
     [module init and forward] Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864)
-    on the `query ` or `key` before their multi-head attention computation.
+    on the ``query`` or ``key`` before their multi-head attention computation.
+
+    `module init`
+
+    Args:
+        max_position_embeddings (int): size (max) of the position embeddings.
+        pos_embd_dim (int): dimension of the position embeddings.
+        base (int): Default: 10000. Base to generate the frequency of position embeddings.
+        backbone (str): Default: None. The exact transformers model backbone
+            (e.g., "GPTJForCausalLM", get from model.config.architectures[0],
+            see https://huggingface.co/EleutherAI/gpt-j-6b/blob/main/config.json#L4).
+
+    `forward()`
+
     Args:
-        module init:
-        - max_position_embeddings (int): size (max) of the position embeddings.
-        - pos_embd_dim (int): dimension of the position embeddings.
-        - base (int) : Default: 10000. Base to generate the frequency of position embeddings.
-        - backbone (str): Default: None. The exact transformers model backbone
-          (e.g., "GPTJForCausalLM", get from model.config.architectures[0],
-          see https://huggingface.co/EleutherAI/gpt-j-6b/blob/main/config.json#L4).
 
-        forward:
-        - input (torch.Tensor) : input to be applied with position embeddings,
-          taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
-          (as well as the output shape).
-        - position_ids (torch.Tensor): the according position_ids for the input.
-          The shape should be [batch size, sequence length. In some cases,
-          there is only one element which the past_kv_length, and position id
-          can be constructed by past_kv_length + current_position.
-        - num_head (int) : head num from the input shape.
-        - head_dim (int) : head dim from the input shape.
-        - offset (int) : the offset value. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
-          so the offset is 1. For llama, cos/sin is applied to the neighboring rotary_dim elements,
-          so the offset is rotary_dim/2.
-        - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
+        input (torch.Tensor): input to be applied with position embeddings,
+            taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim]
+            (as well as the output shape).
+        position_ids (torch.Tensor): the corresponding position_ids for the input.
+            The shape should be [batch size, sequence length]. In some cases,
+            there is only one element, which is the past_kv_length, and the position id
+            can be constructed by past_kv_length + current_position.
+        num_head (int): head num from the input shape.
+        head_dim (int): head dim from the input shape.
+        offset (int): the offset value, e.g., for GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
+            so the offset is 1; for llama, cos/sin is applied to the neighboring rotary_dim elements,
+            so the offset is rotary_dim/2.
+        rotary_ndims (int): the rotary dimension, e.g., 64 for GPTJ, head size for LLaMA.
 
     Examples:
         >>> # module init:
@@ -40,25 +44,29 @@
         >>> position_ids = torch.arange(32).unsqueeze(0)
         >>> query_rotery = rope_module(query, position_ids, 16, 256, 1, 64)
 
-    [Direct function call] This module also provides a `.apply_function` function call to be used on query and key
-    at the same time without initializing the module (assume rotary embedding
-    sin/cos values are provided).
+    [Direct function call] This module also provides a `.apply_function` function call
+    to be used on query and key at the same time without initializing the module
+    (assuming rotary embedding sin/cos values are provided).
+
+    `apply_function()`
+
     Args:
-    - query, key (torch.Tensor) : inputs to be applied with position embeddings, taking shape of
-      [batch size, sequence length, num_head/num_kv_head, head_dim]
-      or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
-    - sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
-    - rotary_ndims (int): the rotary dimension. e.g., 64 for GPTJ. head size for LLama.
-    - head_dim (int) : head dim from the input shape.
-    - rotary_half (bool) : if False. e.g., GPT-J 6B/ChatGLM, cos/sin is applied to the neighboring 2 elements,
-      so the offset is 1.
-      if True, e.g., for llama, cos/sin is applied to the neighboring rotary_dim elements,
-      so the offset is rotary_dim/2.
-    - position_ids (torch.Tensor): Default is None and optional if sin/cos is provided. the according position_ids
-      for the input. The shape should be [batch size, sequence length].
-    Return
-    - query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
-      or [num_tokens, num_head/num_kv_head, head_dim].
+        query, key (torch.Tensor): inputs to be applied with position embeddings, taking shape of
+            [batch size, sequence length, num_head/num_kv_head, head_dim]
+            or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).
+        sin/cos (torch.Tensor): [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.
+        rotary_ndims (int): the rotary dimension, e.g., 64 for GPTJ, head size for LLaMA.
+        head_dim (int): head dim from the input shape.
+        rotary_half (bool): if False (e.g., GPT-J 6B/ChatGLM), cos/sin is applied to the neighboring 2 elements,
+            so the offset is 1.
+            if True (e.g., for llama), cos/sin is applied to the neighboring rotary_dim elements,
+            so the offset is rotary_dim/2.
+        position_ids (torch.Tensor): Optional (default None) if sin/cos is provided; the corresponding position_ids
+            for the input. The shape should be [batch size, sequence length].
+
+    Return:
+        query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim]
+        or [num_tokens, num_head/num_kv_head, head_dim].
 
     """
 
@@ -149,16 +157,20 @@ class FastLayerNorm(nn.Module):
     r"""
     [module init and forward] Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)
     on the input (hidden states).
+
+    `module init`
+
     Args:
-        module init:
-        - normalized_shape ((int or list) or torch.Size) input shape from an expected input of size.
-        - eps (float): a value added to the denominator for numerical stability.
-        - weight (torch.Tensor): the weight of Layernorm to apply normalization.
-        - bias (torch.Tensor): an additive bias for normalization.
+        normalized_shape ((int or list) or torch.Size): input shape from an expected input of size.
+        eps (float): a value added to the denominator for numerical stability.
+        weight (torch.Tensor): the weight of Layernorm to apply normalization.
+        bias (torch.Tensor): an additive bias for normalization.
 
-        forward:
-        - hidden_states (torch.Tensor) : input to be applied Layernorm, usually taking shape of
-          [batch size, sequence length, hidden_size] (as well as the output shape).
+    `forward()`
+
+    Args:
+        hidden_states (torch.Tensor): input to apply Layernorm on, usually taking the shape of
+            [batch size, sequence length, hidden_size] (as well as the output shape).
 
     Examples:
         >>> # module init:
@@ -169,13 +181,16 @@
         >>> result = layernorm_module(input)
 
     [Direct function call] This module also provides a `.apply_function` function call to apply fast layernorm
-    without initializing the module.
+    without initializing the module.
+
+    `apply_function()`
+
     Args:
-    - hidden_states(torch.Tensor) : the input tensor to apply normalization.
-    - normalized_shape (int or list) or torch.Size) input shape from an expected input of size.
-    - weight (torch.Tensor): the weight to apply normalization.
+        hidden_states (torch.Tensor): the input tensor to apply normalization.
+        normalized_shape ((int or list) or torch.Size): input shape from an expected input of size.
+        weight (torch.Tensor): the weight to apply normalization.
-    - bias (torch.Tensor): an additive bias for normalization.
-    - eps (float): a value added to the denominator for numerical stability.
+        bias (torch.Tensor): an additive bias for normalization.
+        eps (float): a value added to the denominator for numerical stability.
 
     """
 
@@ -217,16 +232,20 @@ class RMSNorm(nn.Module):
     r"""
     [module init and forward] Applies RMSnorm on the input (hidden states).
     (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)
+
+    `module init`
+
     Args:
-        module init:
-        - hidden_size (int) : the size of the hidden states.
-        - eps (float) : the variance_epsilon to apply RMSnorm, default using 1e-6.
-        - weight (torch.Tensor): the weight to apply RMSnorm, default None and will use `torch.ones(hidden_size)`.
+        hidden_size (int): the size of the hidden states.
+        eps (float): the variance_epsilon to apply RMSNorm, default 1e-6.
+        weight (torch.Tensor): the weight to apply RMSNorm, default None
+            and will use `torch.ones(hidden_size)`.
 
-        forward:
-        - hidden_states (torch.Tensor) : input to be applied RMSnorm, usually taking shape of
-          [batch size, sequence length, hidden_size]
-          (as well as the output shape).
+    `forward()`
+
+    Args:
+        hidden_states (torch.Tensor): input to apply RMSNorm on, usually taking the shape of
+            [batch size, sequence length, hidden_size] (as well as the output shape).
 
     Examples:
         >>> # module init:
@@ -235,12 +254,15 @@
         >>> input = torch.randn(1, 32, 4096)
         >>> result = rmsnorm_module(input)
 
-    [Direct function call] This module also provides a `.apply_function` function call to apply RMSNorm without
-    initializing the module.
+    [Direct function call] This module also provides a `.apply_function` function
+    call to apply RMSNorm without initializing the module.
+
+    `apply_function()`
+
     Args:
-    - hidden_states(torch.Tensor) : the input tensor to apply RMSNorm.
-    - weight (torch.Tensor): the weight to apply RMSnorm.
-    - eps (float) : the variance_epsilon to apply RMSnorm.
+        hidden_states (torch.Tensor): the input tensor to apply RMSNorm.
+        weight (torch.Tensor): the weight to apply RMSNorm.
+        eps (float): the variance_epsilon to apply RMSNorm.
 
     """
 
@@ -271,23 +293,31 @@ class VarlenAttention(nn.Module):
     r"""
     [module init and forward] Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value
-    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
-    and accept the variant (different) sequence length among the query, key and value.
+    (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html),
+    and accepts variant (different) sequence lengths among the query, key and value.
+
+    This module does not have args for `module init`.
+
+    `forward()`
 
     Args:
-        module init: this module does not have args for module init
-        forward:
-        - query (torch.Tensor): shape [query_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - key (torch.Tensor): shape [key_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - value (torch.Tensor): shape [value_tokens, num_head, head_size], where tokens is total sequence length among batch size.
-        - out (torch.Tensor): buffer to get the results, the shape is the same as query.
-        - seqlen_q (torch.Tensor): shape [batch_size + 1], points the current query_tokens among total sequence length.
-        - seqlen_k (torch.Tensor): shape [batch_size + 1], points the current key_tokens among total sequence length.
-        - max_seqlen_q (int): max/total sequence length of query.
-        - max_seqlen_k (int): max/total sequence length of key.
-        - pdropout (float): dropout probability; if greater than 0.0, dropout is applied, default is 0.0.
-        - softmax_scale (float): scaling factor applied is prior to softmax.
-        - is_causal (bool): whether to apply causal attention masking, default is True.
+        query (torch.Tensor): shape [query_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        key (torch.Tensor): shape [key_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        value (torch.Tensor): shape [value_tokens, num_head, head_size],
+            where tokens is total sequence length among batch size.
+        out (torch.Tensor): buffer to get the results; the shape is the same as query.
+        seqlen_q (torch.Tensor): shape [batch_size + 1], points to the
+            current query_tokens within the total sequence length.
+        seqlen_k (torch.Tensor): shape [batch_size + 1], points to the
+            current key_tokens within the total sequence length.
+        max_seqlen_q (int): max/total sequence length of query.
+        max_seqlen_k (int): max/total sequence length of key.
+        pdropout (float): dropout probability; if greater than 0.0,
+            dropout is applied, default is 0.0.
+        softmax_scale (float): scaling factor applied prior to softmax.
+        is_causal (bool): whether to apply causal attention masking, default is True.
 
     Examples:
         >>> # module init:
@@ -305,10 +335,10 @@
         >>> softmax_scale = 0.5
         >>> varlenAttention_module(query, key, value, out, seqlen_q, seqlen_k, max_seqlen_q, max_seqlen_k, pdropout, softmax_scale)
 
-    [Direct function call] This module also provides a `.apply_function` function call to apply VarlenAttention without
-    initializing the module.
-    Args:
-    - The parameters are the same as the forward call.
+    [Direct function call] This module also provides a `.apply_function`
+    function call to apply VarlenAttention without initializing the module.
+
+    The parameters of `apply_function()` are the same as the `forward()` call.
 
     """
 
@@ -399,58 +429,65 @@ class PagedAttention:
     for key/value cache. The basic logic as following figure. Firstly, The DRAM buffer which includes num_blocks
     are pre-allocated to store key or value cache. For every block, block_size tokens can be stored. In the forward
     pass, the cache manager will firstly allocate some slots from this buffer and use reshape_and_cache API to store
-    the key/value and then use single_query_cached_kv_attention API to do the scale-dot-product of MHA.
+    the key/value and then use single_query_cached_kv_attention API to do the scale-dot-product of MHA.
     The block is basic allocation unit of paged attention and the token intra-block are stored one-by-one.
     The block tables are used to map the logical block of sequence into the physical block.
 
     [class method]: reshape_and_cache
-    ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+        ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
     This operator is used to store the key/value token states into the pre-allocated kv_cache buffers of paged attention.
+
     Args:
-    - key (torch.Tensor): The keytensor. The shape should be [num_seqs, num_heads, head_size].
-    - value (torch.Tensor): The value tensor. The shape should be [num_seqs, num_heads, head_size].
-    - key_cache (torch.Tensor): The pre-allocated buffer to store the key cache. The shape should be
-      [num_blocks, block_size, num_heads, head_size].
-    - value_cache (torch.Tensor): The pre-allocated buffer to store the value cache. The shape should be
-      [num_blocks, block_size, num_heads, head_size].
-    - slot_mapping (torch.Tensor): It stores the position to store the key/value in the pre-allocated buffers.
-      The shape should be the number of sequences. For sequence _i_, the slot_mapping[i]//block_number
-      can get the block index, and the slot_mapping%block_size can get the offset of this block.
+        key (torch.Tensor): The key tensor. The shape should be [num_seqs, num_heads, head_size].
+        value (torch.Tensor): The value tensor. The shape should be [num_seqs, num_heads, head_size].
+        key_cache (torch.Tensor): The pre-allocated buffer to store the key cache.
+            The shape should be [num_blocks, block_size, num_heads, head_size].
+        value_cache (torch.Tensor): The pre-allocated buffer to store the value cache.
+            The shape should be [num_blocks, block_size, num_heads, head_size].
+        slot_mapping (torch.Tensor): It stores the position to store the key/value in the pre-allocated buffers.
+            The shape should be the number of sequences. For sequence ``i``, ``slot_mapping[i] // block_size``
+            can get the block index, and ``slot_mapping[i] % block_size`` can get the offset within this block.
 
     [class method]: single_query_cached_kv_attention
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        key_cache,
-        value_cache,
-        head_mapping,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        alibi_slopes
-    )
+
+    .. highlight:: python
+    .. code-block:: python
+
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes
+        )
 
     This operator is used to calculate the scale-dot-product based on the paged attention.
+
     Args:
-    - out (torch.Tensor): The output tensor with shape of [num_seqs, num_heads, head_size]. where the num_seqs
-      is the number of the sequence in this batch. The num_heads means the number of query
-      head. head_size means the head dimension.
+        out (torch.Tensor): The output tensor with shape of [num_seqs, num_heads, head_size],
+            where num_seqs is the number of sequences in this batch, num_heads
+            is the number of query heads, and head_size is the head dimension.
+        query (torch.Tensor): The query tensor. The shape should be [num_seqs, num_heads, head_size].
+        key_cache (torch.Tensor): The pre-allocated buffer to store the key cache.
+            The shape should be [num_blocks, block_size, num_heads, head_size].
+        value_cache (torch.Tensor): The pre-allocated buffer to store the value cache.
+            The shape should be [num_blocks, block_size, num_heads, head_size].
+        head_mapping (torch.Tensor): The mapping from the query head to the kv head.
+            The shape should be the number of query heads.
+        scale (float): The scale used by the scale-dot-product.
+            In general, it is: ``float(1.0 / (head_size ** 0.5))``.
+        block_tables (torch.Tensor): The mapping table used to map the logical sequence
+            to the physical sequence. The shape should be [num_seqs, max_num_blocks_per_seq].
+        context_lens (torch.Tensor): The sequence length for every sequence. The size is [num_seqs].
+        block_size (int): The block size, which means the number of tokens in every block.
+        max_context_len (int): The max sequence length.
+        alibi_slopes (torch.Tensor, optional): the alibi slope with the shape of (num_heads).
 
     """
 
@@ -511,37 +548,55 @@ class IndirectAccessKVCacheAttention(nn.Module):
     buffers(key and value use different buffers) to store all key/value hidden states and beam index information.
     It can use beam index history to decide which beam should be used by a timestamp and this information will
     generate an offset to access the kv_cache buffer.
+
     Data Format:
-    - The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size],
-      the hidden state of key/value which is the shape of [beam*batch, head_num, head_size] is stored token by token.
-      All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
-    [Module init and forward]
+
+    The shape of the pre-allocated key(value) buffer is [max_seq, beam*batch, head_num, head_size];
+    the hidden state of key/value, which has the shape of [beam*batch, head_num, head_size], is stored token by token.
+    All beam idx information of every timestamp is also stored in a Tensor with the shape of [max_seq, beam*batch].
+
+    `module init`
+
+    Args:
+        text_max_length (int): the max length of kv cache to be used
+            for generation (allocate the pre-cache buffer).
+
+    `forward()`
+
     Args:
-        module init
-        - text_max_length (int) : the max length of kv cache to be used for generation (allocate the pre-cache buffer).
-        forward
-        - query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
-        - scale_attn (float):scale used by the attention layer. should be the sqrt(head_size).
-        - layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
-          key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-          value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
-          beam-idx: history beam idx, shape:(max_seq, beam*batch);
-          seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
-        - head_mask (torch.Tensor): Head mask tensor which is not supported by kernel yet.
-        - attention_mask(torch.Tensor): Attention mask information.
+        query (torch.Tensor): Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        key (torch.Tensor): Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        value (torch.Tensor): Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).
+        scale_attn (float): scale used by the attention layer. Should be ``sqrt(head_size)``.
+        layer_past (tuple(torch.Tensor)): tuple(seq_info, key_cache, value_cache, beam-idx).
+
+            - key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);
+
+            - beam-idx: history beam idx, shape:(max_seq, beam*batch);
+
+            - seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).
+
+        head_mask (torch.Tensor): Head mask tensor, which is not supported by the kernel yet.
+        attention_mask (torch.Tensor): Attention mask information.
 
     Return:
-        - attn_output: weighted value which is the output of scale dot product. shape (beam*batch, seq_len, head_num, head_size).
-        - attn_weights: The output tensor of the first matmul in scale dot product which is not supported by kernel now.
-        - new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
+        attn_output: Weighted value, the output of the scaled dot product;
+            shape: (beam*batch, seq_len, head_num, head_size).
+
+        attn_weights: The output tensor of the first matmul in the scaled dot product,
+            which is not supported by the kernel for now.
+
+        new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).
 
     Notes:
-        - How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model
-          see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)
+        How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on the llama model;
+        see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318):
+
+    .. highlight:: python
+    .. code-block:: python
+
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
@@ -552,10 +607,10 @@ def _reorder_cache(
             layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
         return past_key_values
 
-    [Direct function call] This module also provides a `.apply_function` function call to apply IndirectAccessKVCacheAttention
-    without initializing the module.
-    Args:
-    - The parameters are the same as the forward call.
+    [Direct function call] This module also provides a `.apply_function` function call
+    to apply IndirectAccessKVCacheAttention without initializing the module.
+
+    The parameters of `apply_function()` are the same as the `forward()` call.
 
     """
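A minimal end-to-end sketch of the module level API documented in this patch
(the `ipex.llm.modules` import path and the shapes are taken from the docstring
examples above; sizes are illustrative):

.. highlight:: python
.. code-block:: python

    import torch
    import intel_extension_for_pytorch as ipex

    # fuse an existing torch.nn.Linear with SiLU, as described in LinearSilu
    linear_module = torch.nn.Linear(4096, 4096)
    ipex_fusion = ipex.llm.modules.LinearSilu(linear_module)

    # input and output shapes are the same as torch.nn.Linear
    input = torch.randn(4096, 4096)
    result = ipex_fusion(input)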