Streamingllm #1489

Merged: 35 commits, Sep 5, 2024

Changes from all commits (35 commits)
bc4d9f7
feat: avoid patch query
chenzhuofu Aug 24, 2024
e41f374
chore: separate apply_pos_encoding from compute_qkv
chenzhuofu Aug 25, 2024
5783cf1
chore: remove unused ptr
chenzhuofu Aug 27, 2024
ea580f7
fix: memory pointer alignment
chenzhuofu Aug 27, 2024
be93e5c
chore: minor smplification
chenzhuofu Aug 27, 2024
b6bcd4e
feat: StreamingCacheInfo
chenzhuofu Aug 27, 2024
f2634a9
feat: add streamingCache-related meta params
chenzhuofu Aug 28, 2024
828b1b8
chore: more acurate definition
chenzhuofu Aug 28, 2024
7e71229
chore: minor
chenzhuofu Aug 28, 2024
a2041ea
feat: add streamingCacheInfo
chenzhuofu Aug 28, 2024
694cedf
feat: apply_pos_encoding & update_qkv_cache, add offset control
chenzhuofu Aug 28, 2024
810721e
chore: minor rename
chenzhuofu Aug 28, 2024
70a4c2d
feat: kernel implementation for streaming cache usage
chenzhuofu Aug 30, 2024
8face9c
feat: implement position encoding for streaming cache
chenzhuofu Aug 31, 2024
e94598f
fix: params should add (de)serialization method
chenzhuofu Aug 31, 2024
f0d56ec
chore: reduce kv cache size
chenzhuofu Aug 31, 2024
419e0f8
chore: minor
chenzhuofu Sep 1, 2024
fb31261
fix: output misalignment
chenzhuofu Sep 1, 2024
fdf7b86
chore: minor
chenzhuofu Sep 1, 2024
77aa1af
fix: speculative decoding update_custom_mask only consider mask withi…
chenzhuofu Sep 1, 2024
425d770
fix: barrier_flag initial value
chenzhuofu Sep 2, 2024
263d9d6
doc: attention meta info
chenzhuofu Sep 2, 2024
689dbd6
docs: minor
chenzhuofu Sep 2, 2024
fe5a8ad
Added indexing support for streaming cache.
chenzhuofu Sep 3, 2024
3a1cf30
Merge branch 'streamingllm' of github.com:flexflow/FlexFlow into stre…
chenzhuofu Sep 3, 2024
fc4c1cd
docs: minor
chenzhuofu Sep 3, 2024
e55cc6e
chore: minor rename
chenzhuofu Sep 3, 2024
8f056af
feat: add streaming-llm logic to attention
chenzhuofu Sep 3, 2024
813e43f
fix: typo
chenzhuofu Sep 3, 2024
e1477d4
fix: minor bugs in streaming llm
chenzhuofu Sep 3, 2024
2f9ef18
fix: minor runtime bug
chenzhuofu Sep 4, 2024
b5eeb26
chore: minor output
chenzhuofu Sep 4, 2024
13850bb
fix: minor offset transition bug
chenzhuofu Sep 5, 2024
30d17a2
chore: minor
chenzhuofu Sep 5, 2024
61177ee
style: format code
chenzhuofu Sep 5, 2024
2 changes: 1 addition & 1 deletion config/config.linux
@@ -111,7 +111,7 @@ function get_build_configs() {
BUILD_CONFIGS="FF_CUDA_ARCH=${FF_CUDA_ARCH} FF_HIP_ARCH=${FF_HIP_ARCH} CUDA_DIR=${CUDA_DIR} CUDNN_DIR=${CUDNN_DIR} CUBLAS_DIR=${CUBLAS_DIR} CURAND_DIR=${CURAND_DIR} NCCL_DIR=${NCCL_DIR} FF_USE_PYTHON=${FF_USE_PYTHON} BUILD_LEGION_ONLY=${BUILD_LEGION_ONLY} FF_GASNET_CONDUIT=${FF_GASNET_CONDUIT} UCX_DIR=${UCX_DIR} FF_LEGION_NETWORKS=${FF_LEGION_NETWORKS} FF_BUILD_ALL_EXAMPLES=${FF_BUILD_ALL_EXAMPLES} FF_BUILD_ALL_INFERENCE_EXAMPLES=${FF_BUILD_ALL_INFERENCE_EXAMPLES} FF_BUILD_UNIT_TESTS=${FF_BUILD_UNIT_TESTS} FF_USE_PREBUILT_NCCL=${FF_USE_PREBUILT_NCCL} FF_USE_PREBUILT_LEGION=${FF_USE_PREBUILT_LEGION} FF_USE_ALL_PREBUILT_LIBRARIES=${FF_USE_ALL_PREBUILT_LIBRARIES} FF_USE_AVX2=${FF_USE_AVX2} FF_MAX_DIM=${FF_MAX_DIM} ROCM_PATH=${ROCM_PATH} FF_GPU_BACKEND=${FF_GPU_BACKEND} INSTALL_DIR=${INSTALL_DIR}"
}

patch -p0 $(dirname $0)/../deps/raft/cpp/include/raft/matrix/detail/select_radix.cuh $(dirname $0)/../config/raft.patch
patch -p0 --batch $(dirname $0)/../deps/raft/cpp/include/raft/matrix/detail/select_radix.cuh $(dirname $0)/../config/raft.patch

if [[ -n "$1" && ( "$1" == "CMAKE_FLAGS" || "$1" == "CUDA_PATH" ) ]]; then
. $(dirname $0)/config.inc
15 changes: 9 additions & 6 deletions include/flexflow/attention_config.h
@@ -20,6 +20,11 @@
namespace FlexFlow {

constexpr uint32_t kPagesize = 64;

inline int round_up_pages(int const num_elements) {
return (num_elements + kPagesize - 1) / kPagesize;
}
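// Illustrative note (not from the PR): with kPagesize = 64, the helper above
// rounds a token count up to whole pages, e.g.
//   round_up_pages(64)  == 1
//   round_up_pages(65)  == 2
//   round_up_pages(128) == 2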

#define DISPATCH_HEADDIM(head_dim, HEAD_DIM, ...) \
switch (head_dim) { \
case 64: { \
@@ -93,9 +98,8 @@ class AttentionMetaData {
}
size_t batch_size = BatchConfig::max_requests_per_batch();
size_t max_num_pages =
(BatchConfig::max_spec_tree_token_num() +
BatchConfig::max_sequence_length() + kPagesize - 1) /
kPagesize;
round_up_pages(BatchConfig::max_spec_tree_token_num() +
BatchConfig::max_sequence_length());
size_t indices_size = std::max(
(batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024);
size_t custom_mask_size = BatchConfig::max_requests_per_batch() *
@@ -132,9 +136,8 @@
"Insufficient memory size for attention metadata");
size_t batch_size = BatchConfig::max_requests_per_batch();
size_t max_num_pages =
(BatchConfig::max_spec_tree_token_num() +
BatchConfig::max_sequence_length() + kPagesize - 1) /
kPagesize;
round_up_pages(BatchConfig::max_spec_tree_token_num() +
BatchConfig::max_sequence_length());
size_t indices_size = std::max(
(batch_size + 1) * 4 + max_num_pages * batch_size, 1ul * 1024 * 1024);
size_t custom_mask_size = BatchConfig::max_requests_per_batch() *
39 changes: 38 additions & 1 deletion include/flexflow/batch_config.h
@@ -27,6 +27,35 @@ class InferenceResult;
using BatchConfigFuture = Legion::Future;
using InferenceResultFuture = Legion::Future;

/*
 * StreamingCacheInfo is a class that manages the streaming kv cache for the
 * attention operator (https://arxiv.org/abs/2309.17453); we use it in the
 * draft model. It maintains a fixed-content *sink* cache and a fixed-size
 * *window* cache. The *sink* cache is the foremost part of the original kv
 * cache, while the *window* cache is the backmost part and is updated in a
 * rolling fashion. The information is kept per request. Note that the
 * position encoding of q & k changes every iteration (relative position), so
 * we store the *pre-pos-encoding* kv values in the cache.
 */
class StreamingCacheInfo {
public:
StreamingCacheInfo();
StreamingCacheInfo(int sink_cache_size, int window_cache_size);
StreamingCacheInfo(StreamingCacheInfo const &other);

StreamingCacheInfo &operator=(StreamingCacheInfo const &other);

void commit_cache(int len);
void reset_cache();
int global_2_cache_index(int global_index);

public:
int sink_cache_size, window_cache_size;
// Meta info of the window cache; commit_len helps determine whether the
// window has been filled up.
int window_back, commit_len;
};
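/*
 * Illustrative sketch (not from the PR): a minimal guess at the index mapping
 * this interface suggests, assuming tokens beyond the sink wrap around inside
 * the rolling window. The PR's actual global_2_cache_index may differ.
 */
inline int global_2_cache_index_sketch(StreamingCacheInfo const &info,
                                       int global_index) {
  if (global_index < info.sink_cache_size) {
    return global_index; // sink tokens keep their original slots
  }
  // later tokens are written round-robin into the rolling window
  return info.sink_cache_size +
         (global_index - info.sink_cache_size) % info.window_cache_size;
}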

class BatchConfig {
public:
using RequestGuid = size_t;
@@ -41,6 +70,7 @@ class BatchConfig {
static int max_verify_tokens_per_batch();
static int max_spec_tree_token_num();
static int max_sequence_length();
static int get_max_tree_depth();
friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc);
void print() const;
void save_to_file(std::string const &filename) const;
@@ -50,14 +80,19 @@
// Maximum possible values for different parameters
// These maximum values are used for copying BatchConfig
// across workers
inline static int const MAX_NUM_REQUESTS = 64;
inline static int const MAX_NUM_REQUESTS = 8;
inline static int const MAX_NUM_TOKENS = 1024;
inline static int const MAX_SPEC_TREE_TOKEN_NUM = 128;
inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 4;
inline static int const MAX_TREE_DEPTH = 16;
inline static int const MAX_TREE_WIDTH = 64;
inline static int const MAX_K_LOGITS = 16;

// Constants for the streaming kv cache
inline static int const SINK_SIZE = 4;
// sink size + window size + draft depth shouldn't exceed this value
inline static int const MAX_STREAMING_POS = 2048;
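// Illustrative note (not from the PR), assuming the draft depth above refers
// to MAX_TREE_DEPTH:
//   SINK_SIZE (4) + window size + MAX_TREE_DEPTH (16) <= MAX_STREAMING_POS (2048)
//   => the rolling window can hold at most 2028 positions.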

int num_tokens = 0;
int num_available_requests = 0;
bool prompt_phase = false;
@@ -69,6 +104,7 @@
int first_token_index_in_request = -1;
int first_token_offset_in_batch = -1;
int num_tokens_in_batch = 0;
int padding = 0; // Padding for memory pointer alignment
};

struct PerTokenInfo {
@@ -150,6 +186,7 @@

BitMask causalMask[MAX_NUM_REQUESTS];
PerRequestInfo requestsInfo[MAX_NUM_REQUESTS];
StreamingCacheInfo streamingCacheInfo[MAX_NUM_REQUESTS];
PerTokenInfo tokensInfo[MAX_NUM_TOKENS];
CommittedTokensInfo committed_tokens[MAX_NUM_TOKENS];
bool request_available[MAX_NUM_REQUESTS];
1 change: 1 addition & 0 deletions include/flexflow/config.h
@@ -86,6 +86,7 @@ struct FFHandler {
size_t batch_config_metadata_size =
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BatchConfig::request_available) + sizeof(BatchConfig::causalMask) +
sizeof(BatchConfig::streamingCacheInfo) +
sizeof(BatchConfig::committed_tokens) + sizeof(int);
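// Illustrative note (not from the PR): with the four int members declared for
// StreamingCacheInfo in batch_config.h and MAX_NUM_REQUESTS = 8, the new
// sizeof(BatchConfig::streamingCacheInfo) term adds roughly
// 4 * sizeof(int) * 8 = 128 bytes here, assuming no extra struct padding.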

void *offload_reserve_space;
4 changes: 4 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -448,6 +448,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention(
float scaling_factor,
bool qk_prod_scaling,
bool position_bias,
bool streaming_cache,
char const *name);

flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention(
@@ -468,6 +469,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention(
float scaling_factor,
bool qk_prod_scaling,
bool position_bias,
bool streaming_cache,
char const *name);

flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify(
@@ -509,6 +511,7 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention(
float scaling_factor,
bool qk_prod_scaling,
bool position_bias,
bool streaming_cache,
char const *name);

flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention(
@@ -530,6 +533,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention(
float scaling_factor,
bool qk_prod_scaling,
bool position_bias,
bool streaming_cache,
char const *name);

flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify(
4 changes: 4 additions & 0 deletions include/flexflow/model.h
@@ -724,6 +724,7 @@ class FFModel {
float scaling_factor = 1.0f,
bool qk_prod_scaling = true,
bool position_bias = false,
bool streaming_cache = false,
char const *name = NULL);
Tensor
spec_inc_multihead_self_attention(Tensor const input,
@@ -742,6 +743,7 @@ class FFModel {
float scaling_factor = 1.0f,
bool qk_prod_scaling = true,
bool position_bias = false,
bool streaming_cache = false,
char const *name = NULL);
Tensor inc_multihead_self_attention_verify(
Tensor const input,
@@ -778,6 +780,7 @@
float scaling_factor = 1.0f,
bool qk_prod_scaling = true,
bool position_bias = false,
bool streaming_cache = false,
char const *name = NULL);
Tensor
spec_inc_multiquery_self_attention(Tensor const input,
@@ -797,6 +800,7 @@
float scaling_factor = 1.0f,
bool qk_prod_scaling = true,
bool position_bias = false,
bool streaming_cache = false,
char const *name = NULL);
Tensor inc_multiquery_self_attention_verify(
Tensor const input,
18 changes: 14 additions & 4 deletions include/flexflow/ops/inc_multihead_self_attention.h
@@ -2,6 +2,7 @@
#define _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_H

#include "flexflow/accessor.h"
#include "flexflow/batch_config.h"
#include "flexflow/device.h"
#include "flexflow/fftype.h"
#include "flexflow/inference.h"
@@ -47,6 +48,7 @@ class IncMultiHeadSelfAttention : public Op {
bool allocate_weights,
DataType _quantization_type,
bool _offload,
bool _streaming_cache,
int _tensor_parallelism_degree,
char const *name);
IncMultiHeadSelfAttention(FFModel &model,
@@ -69,6 +71,7 @@ class IncMultiHeadSelfAttention : public Op {
bool allocate_weights,
DataType _quantization_type,
bool _offload,
bool _streaming_cache,
int _tensor_parallelism_degree,
char const *name);
IncMultiHeadSelfAttention(FFModel &model,
@@ -131,7 +134,7 @@ class IncMultiHeadSelfAttention : public Op {
int hidden_size, qk_dim, v_dim, o_dim;
int qoSeqLength, kvSeqLength;
DataType quantization_type;
bool offload;
bool offload, streaming_cache;
};

class IncMultiHeadSelfAttentionMeta : public OpMeta {
@@ -165,7 +168,8 @@
int _num_q_heads,
int _num_kv_heads,
DataType _quantization_type,
bool _offload);
bool _offload,
bool _streaming_cache);
~IncMultiHeadSelfAttentionMeta(void);

public:
@@ -184,14 +188,20 @@
bool *position_bias;
float scaling_factor;
void *weight_ptr, *bias_ptr; // for weight offload
void *devQKVProjArray, *queryTmp, *kvCache;
void *devQKVProjArray, *queryTmp;
half *outputTmp;
void *qk_prods, *qk_prods_softmax;
void *kvCache;
bool streaming_cache;
// When the streaming cache is enabled, relative positions change every
// iteration, so we need the buffer below to store the pre-pos-encoding
// key/value entries for the sink and window.
void *streamingPrePosEncBuf;
void *attn_heads;
char *quantized_weight_ptr;
BatchConfig::PerTokenInfo *token_infos;
BatchConfig::PerRequestInfo *request_infos;
bool *request_available;
StreamingCacheInfo *streaming_cache_infos;
DataType quantization_type;
bool offload;
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
2 changes: 1 addition & 1 deletion include/flexflow/ops/inc_multihead_self_attention_params.h
@@ -15,7 +15,7 @@ struct IncMultiHeadSelfAttentionParams {
bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding,
scaling_query, qk_prod_scaling, position_bias;
DataType quantization_type;
bool offload;
bool offload, streaming_cache;
char name[MAX_OPNAME];
bool is_valid(ParallelTensorShape const &) const;
};
@@ -44,6 +44,8 @@ void pre_build_weight(IncMultiHeadSelfAttentionMeta const *m,
DataType data_type,
ffStream_t stream);

// [For the tokens in batch]
// Compute the qkv projection for the tokens in the batch.
template <typename DT>
void compute_qkv(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
@@ -54,10 +56,56 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m,
DT const *bias_ptr,
ffStream_t stream);

// [For the tokens in batch]
// Apply the position embedding to q & k.
// Note that this is only used for the tokens in the current batch.
// For other key tokens, e.g. those in the streaming cache, we need another
// kernel to apply the position embedding.
template <typename DT>
void update_qkv_cache(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);
void apply_pos_encoding_to_tokens_in_batch(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
DT *output_ptr,
cudaStream_t stream);

// [For the tokens in streaming cache]
// Apply the position embedding to the k projection in the streaming cache.
// Note that before the position encoding, the projection is moved *in order*
// into the kv memory used by the attention kernel, so this operation is
// applied where kvCache points.
template <typename DT>
void apply_pos_encoding_to_streaming_proj(
IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);

// [For the tokens in batch]
// Update the kv cache and compact the q array.
// Source: qkv projection array of the tokens in the batch.
// Destination: q & kv pointers used by the attention kernel.
// Note that q & k here hold the values after position encoding has been applied.
template <typename DT>
void update_qkv_in_batch(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);

// [For the tokens in streaming cache]
// Convert the out-of-order cache to in-order relative positions.
// Source: pre-pos-encoding kv values in the streaming cache.
// Destination: kv pointer used by the attention kernel.
template <typename DT>
void update_kv_in_streaming_cache(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);

// [For the tokens in batch]
// Commit the kv values to the streaming cache.
// Source: qkv projection array of the tokens in the batch.
// Destination: pre-pos-encoding kv values in the streaming cache.
template <typename DT>
void commit_kv(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
cudaStream_t stream);
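// Illustrative sketch (not from the PR): one plausible per-iteration ordering
// of the helpers above for a streaming-cache request, inferred from the doc
// comments; the actual dispatch code may differ, and argument lists are
// abbreviated.
//
//   compute_qkv(...);                             // project batch tokens to q/k/v
//   commit_kv<DT>(m, bc, stream);                 // stash pre-pos-encoding k/v into sink + window
//   apply_pos_encoding_to_tokens_in_batch<DT>(m, bc, output_ptr, stream);
//   update_qkv_in_batch<DT>(m, bc, stream);       // positioned q/kv -> attention kernel buffers
//   update_kv_in_streaming_cache<DT>(m, bc, stream);        // reorder cached kv into kvCache
//   apply_pos_encoding_to_streaming_proj<DT>(m, bc, stream); // encode cached k at current relative positions
//   // ...then the attention kernel consumes the compacted q and kvCache.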

template <typename DT>
void produce_output(IncMultiHeadSelfAttentionMeta const *m,
3 changes: 3 additions & 0 deletions include/flexflow/ops/spec_inc_multihead_self_attention.h
@@ -42,6 +42,7 @@ class SpecIncMultiHeadSelfAttention : public Op {
bool _qk_prod_scaling,
bool _position_bias,
bool allocate_weights,
bool _streaming_cache,
char const *name);
SpecIncMultiHeadSelfAttention(FFModel &model,
ParallelTensor const _input,
@@ -61,6 +62,7 @@
bool _qk_prod_scaling,
bool _position_bias,
bool allocate_weights,
bool _streaming_cache,
char const *name);
SpecIncMultiHeadSelfAttention(FFModel &model,
SpecIncMultiHeadSelfAttention const &other,
@@ -124,6 +126,7 @@
qk_prod_scaling, position_bias;
int hidden_size, qk_dim, v_dim, o_dim;
int qoSeqLength, kvSeqLength;
bool streaming_cache;
};

class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta {
@@ -13,6 +13,7 @@ struct SpecIncMultiHeadSelfAttentionParams {
float dropout, scaling_factor;
bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding,
scaling_query, qk_prod_scaling, position_bias;
bool streaming_cache;
char name[MAX_OPNAME];
bool is_valid(ParallelTensorShape const &) const;
};