enable logits processor kernels for FrequencyPresencePenalty and RepetitionPenalty.
guocuimi committed Dec 7, 2023
1 parent f9755e7 commit 9257567
Showing 12 changed files with 303 additions and 108 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -85,7 +85,6 @@ find_package(Python REQUIRED COMPONENTS Development)

find_package(Jemalloc)
if(Jemalloc_FOUND)
message(STATUS "jemalloc found, linking to jemalloc")
link_libraries(Jemalloc::jemalloc)
endif()

4 changes: 2 additions & 2 deletions cmake/FindJemalloc.cmake
@@ -46,7 +46,7 @@ if(NOT Jemalloc_FOUND)
if(Jemalloc_FIND_REQUIRED)
message(FATAL_ERROR "Cannot find jemalloc!")
else()
message(STATUS "jemalloc is not found!")
message(STATUS "Jemalloc is not found!")
endif()
else()
if(Jemalloc_FOUND AND NOT TARGET Jemalloc::jemalloc)
@@ -56,6 +56,6 @@ else()
INTERFACE_INCLUDE_DIRECTORIES "${Jemalloc_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "m;stdc++;Threads::Threads;dl"
)
message(STATUS "Using jemalloc: " ${Jemalloc_LIBRARY})
message(STATUS "Using Jemalloc: " ${Jemalloc_LIBRARY})
endif()
endif()
20 changes: 13 additions & 7 deletions src/engine/utils.cpp
@@ -1,6 +1,7 @@
#include "utils.h"

#include <torch/torch.h>
#include <torch/types.h>

#include <vector>

@@ -58,7 +59,8 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
std::vector<int32_t> last_token_idxes;

// track the token ids and counts in the batch
std::vector<std::vector<int32_t>> token_ids_vec;
std::vector<std::vector<int64_t>> token_ids_vec;
std::vector<int32_t> token_ids_lens_vec;
std::vector<std::vector<int32_t>> token_counts_vec;
size_t max_unique_tokens = 0;

@@ -88,8 +90,8 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
flatten_positions_vec.push_back(i);
}

std::vector<int32_t>& ids = token_ids_vec.emplace_back();
std::vector<int32_t>& counts = token_counts_vec.emplace_back();
auto& ids = token_ids_vec.emplace_back();
auto& counts = token_counts_vec.emplace_back();

const auto& seq_token_counts = sequence->token_counts();
const auto unique_tokens = seq_token_counts.size();
@@ -99,6 +101,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
ids.push_back(token_id);
counts.push_back(count);
}
token_ids_lens_vec.push_back(static_cast<int32_t>(unique_tokens));
max_unique_tokens = std::max(max_unique_tokens, unique_tokens);

num_prompt_tokens += num_tokens;
@@ -130,8 +133,8 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
flatten_tokens_vec.push_back(seq_token_ids.back());
flatten_positions_vec.push_back(num_tokens - 1);

std::vector<int32_t>& ids = token_ids_vec.emplace_back();
std::vector<int32_t>& counts = token_counts_vec.emplace_back();
auto& ids = token_ids_vec.emplace_back();
auto& counts = token_counts_vec.emplace_back();

const auto& seq_token_counts = sequence->token_counts();
const auto unique_tokens = seq_token_counts.size();
@@ -141,6 +144,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
ids.push_back(token_id);
counts.push_back(count);
}
token_ids_lens_vec.push_back(static_cast<int32_t>(unique_tokens));
max_unique_tokens = std::max(max_unique_tokens, unique_tokens);

context_lens.push_back(num_tokens);
@@ -160,7 +164,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
auto token_ids_tensor =
torch::empty({static_cast<int64_t>(token_ids_vec.size()),
static_cast<int64_t>(max_unique_tokens)},
torch::kLong);
torch::kInt64);
auto token_counts_tensor =
torch::empty({static_cast<int64_t>(token_counts_vec.size()),
static_cast<int64_t>(max_unique_tokens)},
@@ -169,7 +173,8 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
auto& ids = token_ids_vec[i];
// padding token ids to the same length
ids.resize(max_unique_tokens, /*pad_id=*/0);
token_ids_tensor.index_put_({i, Slice()}, torch::tensor(ids, torch::kLong));
token_ids_tensor.index_put_({i, Slice()},
torch::tensor(ids, torch::kInt64));

auto& counts = token_counts_vec[i];
counts.resize(max_unique_tokens, /*pad_id=*/0);
@@ -200,6 +205,7 @@ void Utils::prepare_inputs(const std::vector<Sequence*>& batch,
input_params->last_token_idxes = torch::tensor(last_token_idxes, torch::kInt);
input_params->token_ids = token_ids_tensor;
input_params->token_counts = token_counts_tensor;
input_params->token_ids_lens = torch::tensor(token_ids_lens_vec, torch::kInt);
}

} // namespace llm
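
Note on the new inputs: prepare_inputs now records, for every sequence, its unique token ids (as int64), their occurrence counts, and the real row length (token_ids_lens), then pads every row out to max_unique_tokens. The sketch below is a hedged illustration, not part of the commit — build_penalty_tensors and the toy sequences are made up — but it mirrors the padding logic added above.

#include <torch/torch.h>

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not in this repo): pad per-sequence unique token
// ids/counts to a common width and keep each row's real length, mirroring
// the logic added to Utils::prepare_inputs.
struct PenaltyTensors {
  torch::Tensor token_ids;       // [num_seq, max_unique_tokens] int64
  torch::Tensor token_counts;    // [num_seq, max_unique_tokens] int32
  torch::Tensor token_ids_lens;  // [num_seq] int32
};

PenaltyTensors build_penalty_tensors(
    std::vector<std::vector<int64_t>> token_ids_vec,
    std::vector<std::vector<int32_t>> token_counts_vec) {
  using torch::indexing::Slice;
  size_t max_unique_tokens = 0;
  std::vector<int32_t> lens;
  for (const auto& ids : token_ids_vec) {
    lens.push_back(static_cast<int32_t>(ids.size()));
    max_unique_tokens = std::max(max_unique_tokens, ids.size());
  }

  const auto num_seq = static_cast<int64_t>(token_ids_vec.size());
  auto ids_tensor = torch::empty(
      {num_seq, static_cast<int64_t>(max_unique_tokens)}, torch::kInt64);
  auto counts_tensor = torch::empty(
      {num_seq, static_cast<int64_t>(max_unique_tokens)}, torch::kInt);
  for (int64_t i = 0; i < num_seq; ++i) {
    auto& ids = token_ids_vec[i];
    auto& counts = token_counts_vec[i];
    // pad with token id 0 / count 0; the real length lives in lens
    ids.resize(max_unique_tokens, /*pad_id=*/0);
    counts.resize(max_unique_tokens, /*pad_count=*/0);
    ids_tensor.index_put_({i, Slice()}, torch::tensor(ids, torch::kInt64));
    counts_tensor.index_put_({i, Slice()}, torch::tensor(counts, torch::kInt));
  }
  return {ids_tensor, counts_tensor, torch::tensor(lens, torch::kInt)};
}

int main() {
  // two made-up sequences with 3 and 2 unique tokens respectively
  auto t = build_penalty_tensors({{11, 42, 7}, {5, 9}}, {{2, 1, 1}, {3, 1}});
  // token_ids:      [[11, 42, 7], [5, 9, 0]]   (second row padded)
  // token_counts:   [[ 2,  1, 1], [3, 1, 0]]
  // token_ids_lens: [3, 2]
  std::cout << t.token_ids << "\n"
            << t.token_counts << "\n"
            << t.token_ids_lens << std::endl;
}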
6 changes: 4 additions & 2 deletions src/engine/worker.cpp
@@ -78,8 +78,10 @@ OutputParameters Worker::execute_model(
// create and call logits processors
auto logits_processor =
LogitsProcessor::create(sampling_params, dtype_, device_);
logits = logits_processor->forward(
d_params.token_ids, d_params.token_counts, logits);
logits = logits_processor->forward(d_params.token_ids,
d_params.token_counts,
d_params.token_ids_lens,
logits);

// create and call sampler
auto sampler = std::make_unique<Sampler>(
9 changes: 5 additions & 4 deletions src/engine/worker.h
@@ -50,10 +50,11 @@ class Worker final {
const std::vector<int64_t>& value_cache_shape);

// Run the model on the given input. blocking call
OutputParameters execute_model(torch::Tensor flatten_tokens, // [num_tokens]
torch::Tensor flatten_positions, // [num_tokens]
const InputParameters& params,
const SamplingParameters& sampling_params);
OutputParameters execute_model(
torch::Tensor flatten_tokens, // [num_tokens]
torch::Tensor flatten_positions, // [num_tokens]
const InputParameters& params,
const SamplingParameters& sampling_params);

// initialize model, cache manager. async call
folly::SemiFuture<bool> init_model_async(torch::ScalarType dtype,
66 changes: 33 additions & 33 deletions src/kernels/sampling/penalty_kernels.cu
@@ -57,32 +57,32 @@ void apply_temperature_penalty(torch::Tensor& logits,
template <typename T>
__global__ void apply_repetition_penalty_kernel(
T* __restrict__ logits,
const int* __restrict__ token_ids,
const long* __restrict__ token_ids,
const int* __restrict__ token_ids_lens,
const T* __restrict__ penalities,
const int* __restrict__ seq_lens,
int max_seq_len,
int vocab_size) {
const int tid = threadIdx.x;
// batch idx
const int bid = blockIdx.x;
const float penalty = penalities[bid];
const int seq_len = seq_lens[bid];
const int len = token_ids_lens[bid];
// move the pointer to the start of the batch
logits += bid * vocab_size;

for (int i = tid; i < seq_len; i += blockDim.x) {
const int token_id = token_ids[bid * max_seq_len + i];
for (int i = tid; i < len; i += blockDim.x) {
const long token_id = token_ids[bid * max_seq_len + i];
const float logit = logits[token_id];
assert(token_id < vocab_size);
// assert(token_id < vocab_size);
// apply repetition penalty
logits[token_id] = logit < 0.0f ? logit * penalty : logit / penalty;
}
}

void apply_repetition_penalty(torch::Tensor& logits,
torch::Tensor token_ids,
torch::Tensor seq_lens,
torch::Tensor penalities) {
const torch::Tensor& token_ids,
const torch::Tensor& token_ids_lens,
const torch::Tensor& penalities) {
DCHECK(logits.is_contiguous()) << "logits tensor must be contiguous";
DCHECK(token_ids.is_contiguous()) << "token_ids tensor must be contiguous";
DCHECK(penalities.is_contiguous()) << "penalities tensor must be contiguous";
@@ -102,9 +102,9 @@ void apply_repetition_penalty(torch::Tensor& logits,
apply_repetition_penalty_kernel<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
logits.data_ptr<scalar_t>(),
token_ids.data_ptr<int>(),
token_ids.data_ptr<long>(),
token_ids_lens.data_ptr<int>(),
penalities.data_ptr<scalar_t>(),
seq_lens.data_ptr<int>(),
max_seq_len,
vocab_size);
});
@@ -113,46 +113,46 @@
template <typename T>
__global__ void apply_frequency_presence_penalty_kernel(
T* __restrict__ logits,
const int* __restrict__ token_ids,
const long* __restrict__ token_ids,
const int* __restrict__ token_counts,
const T* __restrict__ frequency_penalities,
const T* __restrict__ presence_penalities,
const int* __restrict__ seq_lens,
const int* __restrict__ token_ids_lens,
const T* __restrict__ frequency_penalties,
const T* __restrict__ presence_penalties,
int max_seq_len,
int vocab_size) {
const int tid = threadIdx.x;
// batch idx
const int bid = blockIdx.x;
const int seq_len = seq_lens ? seq_lens[bid] : max_seq_len;
const int len = token_ids_lens[bid];
// move the pointer to the start of the batch
logits += bid * vocab_size;

for (int i = tid; i < seq_len; i += blockDim.x) {
for (int i = tid; i < len; i += blockDim.x) {
const int idx = bid * max_seq_len + i;
const int token_id = token_ids[idx];
const long token_id = token_ids[idx];
const int token_count = token_counts[idx];
assert(token_id < vocab_size);
// assert(token_id < vocab_size);
if (token_count > 0) {
// apply frequency and presence penalities
// apply frequency then presence penalities
float logit = logits[token_id];
logit -= (token_count * (float)frequency_penalities[bid]);
logit -= presence_penalities[bid];
logit -= (token_count * (float)frequency_penalties[bid]);
logit -= presence_penalties[bid];
logits[token_id] = logit;
}
}
}

void apply_frequency_presence_penalty(torch::Tensor& logits,
torch::Tensor token_ids,
torch::Tensor token_counts,
torch::Tensor seq_lens,
torch::Tensor frequency_penalities,
torch::Tensor presence_penalities) {
const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens,
const torch::Tensor& frequency_penalties,
const torch::Tensor& presence_penalties) {
DCHECK(logits.is_contiguous()) << "logits tensor must be contiguous";
DCHECK(token_ids.is_contiguous()) << "token_ids tensor must be contiguous";
DCHECK(frequency_penalities.is_contiguous())
DCHECK(frequency_penalties.is_contiguous())
<< "penalities tensor must be contiguous";
DCHECK(presence_penalities.is_contiguous())
DCHECK(presence_penalties.is_contiguous())
<< "penalities tensor must be contiguous";
DCHECK(logits.size(0) == token_ids.size(0))
<< "logits and token_ids must have the same batch size";
@@ -171,11 +171,11 @@ void apply_frequency_presence_penalty(torch::Tensor& logits,
apply_frequency_presence_penalty_kernel<scalar_t>
<<<grid, block, smem_size, at::cuda::getCurrentCUDAStream()>>>(
logits.data_ptr<scalar_t>(),
token_ids.data_ptr<int>(),
token_ids.data_ptr<long>(),
token_counts.data_ptr<int>(),
frequency_penalities.data_ptr<scalar_t>(),
presence_penalities.data_ptr<scalar_t>(),
seq_lens.defined() ? seq_lens.data_ptr<int>() : nullptr,
token_ids_lens.data_ptr<int>(),
frequency_penalties.data_ptr<scalar_t>(),
presence_penalties.data_ptr<scalar_t>(),
max_seq_len,
vocab_size);
});
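
To make the arithmetic of the two kernels above easy to check, here is a hedged CPU reference — not code from the commit; apply_penalties_reference is a hypothetical name and it folds both updates into one loop purely for illustration. Like the kernels, it only visits the first token_ids_lens[bid] entries of each padded row.

#include <cstdint>
#include <vector>

// Hypothetical CPU reference (not in this repo) for the penalty kernels above.
// logits is row-major: logits[bid * vocab_size + token_id].
void apply_penalties_reference(std::vector<float>& logits,
                               const std::vector<int64_t>& token_ids,       // [batch, max_seq_len]
                               const std::vector<int32_t>& token_counts,    // [batch, max_seq_len]
                               const std::vector<int32_t>& token_ids_lens,  // [batch]
                               const std::vector<float>& repetition_penalties,
                               const std::vector<float>& frequency_penalties,
                               const std::vector<float>& presence_penalties,
                               int max_seq_len,
                               int vocab_size) {
  const int batch = static_cast<int>(token_ids_lens.size());
  for (int bid = 0; bid < batch; ++bid) {
    float* row = logits.data() + static_cast<size_t>(bid) * vocab_size;
    // only the first token_ids_lens[bid] entries of the padded row are real
    for (int i = 0; i < token_ids_lens[bid]; ++i) {
      const int64_t token_id = token_ids[bid * max_seq_len + i];
      const int32_t count = token_counts[bid * max_seq_len + i];
      float logit = row[token_id];
      // repetition penalty: divide positive logits, multiply negative ones
      const float rp = repetition_penalties[bid];
      logit = logit < 0.0f ? logit * rp : logit / rp;
      if (count > 0) {
        // frequency penalty scales with the occurrence count;
        // presence penalty is a flat subtraction for any seen token
        logit -= count * frequency_penalties[bid];
        logit -= presence_penalties[bid];
      }
      row[token_id] = logit;
    }
  }
}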
16 changes: 8 additions & 8 deletions src/kernels/sampling/sampling_kernels.h
@@ -10,19 +10,19 @@ void apply_temperature_penalty(torch::Tensor& logits,
// token_ids are unique token ids for each sequence.
// the order of token ids does not matter.
void apply_repetition_penalty(torch::Tensor& logits,
torch::Tensor token_ids,
torch::Tensor seq_lens,
torch::Tensor penalities);
const torch::Tensor& token_ids,
const torch::Tensor& token_ids_lens,
const torch::Tensor& penalities);

// token_ids are unique token ids for each sequence.
// token_counts are the number of times corresponding token appears in the
// sequence.
void apply_frequency_presence_penalty(torch::Tensor& logits,
torch::Tensor token_ids,
torch::Tensor token_counts,
torch::Tensor seq_lens,
torch::Tensor frequency_penalities,
torch::Tensor presence_penalities);
const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens,
const torch::Tensor& frequency_penalties,
const torch::Tensor& presence_penalties);

// calculate softmax in place
void invoke_softmax(torch::Tensor& logits);
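
A minimal call-site sketch for the two declarations above — not part of the commit; the shapes follow the header comments, the values are made up, and it assumes a CUDA build with these declarations visible in the current namespace.

#include <torch/torch.h>

// #include "sampling_kernels.h"  // declarations shown above

void penalty_kernels_usage_example() {
  const auto device = torch::kCUDA;
  const int64_t num_seq = 2;
  const int64_t vocab_size = 8;

  // [num_seq, vocab_size] float logits, modified in place by the kernels
  auto logits = torch::randn({num_seq, vocab_size},
                             torch::dtype(torch::kFloat).device(device));

  // unique token ids per sequence, padded to max_unique_tokens = 3
  auto token_ids = torch::tensor({1, 5, 2, 4, 6, 0},
                                 torch::dtype(torch::kInt64).device(device))
                       .view({num_seq, 3});
  auto token_counts = torch::tensor({2, 1, 1, 3, 1, 0},
                                    torch::dtype(torch::kInt).device(device))
                          .view({num_seq, 3});
  // only the first token_ids_lens[i] entries of row i are real
  auto token_ids_lens =
      torch::tensor({3, 2}, torch::dtype(torch::kInt).device(device));

  // one penalty value per sequence, same dtype as logits
  auto repetition = torch::tensor(
      {1.2f, 1.0f}, torch::dtype(torch::kFloat).device(device));
  auto frequency = torch::tensor(
      {0.5f, 0.0f}, torch::dtype(torch::kFloat).device(device));
  auto presence = torch::tensor(
      {0.3f, 0.0f}, torch::dtype(torch::kFloat).device(device));

  apply_repetition_penalty(logits, token_ids, token_ids_lens, repetition);
  apply_frequency_presence_penalty(
      logits, token_ids, token_counts, token_ids_lens, frequency, presence);
}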
4 changes: 2 additions & 2 deletions src/layers/attention.cpp
@@ -161,8 +161,8 @@ torch::Tensor prepare_kv_head_mapping(int64_t n_heads,
int64_t n_kv_heads,
const torch::Device& device) {
// prepare kv_head_mapping
auto kv_head_mapping = torch::arange(
0, n_kv_heads, torch::TensorOptions().dtype(torch::kInt).device(device));
auto kv_head_mapping =
torch::arange(0, n_kv_heads, torch::dtype(torch::kInt).device(device));
const auto num_group = n_heads / n_kv_heads;
if (num_group > 1) {
kv_head_mapping = kv_head_mapping.repeat_interleave(/*repeats=*/num_group);
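
The attention.cpp change (and the kv_cache_test.cpp changes below) only swap an explicit torch::TensorOptions() chain for the shorter torch::dtype(...)/torch::device(...) factories; both spell the same options. A small sketch, assuming nothing beyond libtorch and runnable on CPU:

#include <torch/torch.h>

// torch::dtype() / torch::device() just begin a TensorOptions chain without
// naming TensorOptions explicitly; the resulting options are identical.
void tensor_options_shorthand_example() {
  const torch::Device device(torch::kCPU);  // CPU so the example runs anywhere

  auto verbose = torch::arange(
      0, 4, torch::TensorOptions().dtype(torch::kInt).device(device));
  auto shorthand = torch::arange(0, 4, torch::dtype(torch::kInt).device(device));

  TORCH_CHECK(verbose.dtype() == shorthand.dtype());
  TORCH_CHECK(verbose.device() == shorthand.device());
}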
25 changes: 12 additions & 13 deletions src/memory/kv_cache_test.cpp
@@ -28,8 +28,8 @@ TEST(KVCacheTest, Basic) {

// set key and value cache for the given slot_ids
for (int32_t i = 0; i < num_blocks * block_size; ++i) {
torch::Tensor slot_ids = torch::tensor(
{i}, torch::TensorOptions().dtype(torch::kInt).device(device));
torch::Tensor slot_ids =
torch::tensor({i}, torch::dtype(torch::kInt).device(device));
torch::Tensor keys =
torch::ones({1, num_kv_heads, head_dim}, /*device=*/device) * i;
torch::Tensor values =
Expand All @@ -39,8 +39,8 @@ TEST(KVCacheTest, Basic) {

// get key and value cache for the given slot_ids
for (int32_t i = 0; i < num_blocks * block_size; ++i) {
torch::Tensor slot_ids = torch::tensor(
{i}, torch::TensorOptions().dtype(torch::kInt).device(device));
torch::Tensor slot_ids =
torch::tensor({i}, torch::dtype(torch::kInt).device(device));
auto [keys, values] = kv_cache.get_kv_cache(slot_ids);
auto desired =
torch::ones({1, num_kv_heads, head_dim}, /*device=*/device) * i;
@@ -65,10 +65,10 @@ TEST(KVCacheTest, Random) {

torch::Tensor key_cache =
torch::rand({num_blocks, num_kv_heads, head_dim / x, block_size, x},
/*device=*/device);
/*device=*/device);
torch::Tensor value_cache =
torch::rand({num_blocks, num_kv_heads, head_dim, block_size},
/*device=*/device);
/*device=*/device);

KVCache kv_cache(key_cache, value_cache);

@@ -78,15 +78,14 @@
const int sample_size = std::min(num_blocks * block_size, 10);
const int num_slots = i % sample_size + 1;
torch::Tensor slot_ids =
torch::randperm(
num_blocks * block_size,
torch::TensorOptions().dtype(torch::kInt).device(device))
torch::randperm(num_blocks * block_size,
torch::dtype(torch::kInt).device(device))
.index({Slice(0, num_slots)});

torch::Tensor keys = torch::rand({num_slots, num_kv_heads, head_dim},
torch::TensorOptions().device(device));
torch::Tensor values = torch::rand({num_slots, num_kv_heads, head_dim},
torch::TensorOptions().device(device));
torch::Tensor keys =
torch::rand({num_slots, num_kv_heads, head_dim}, torch::device(device));
torch::Tensor values =
torch::rand({num_slots, num_kv_heads, head_dim}, torch::device(device));

kv_cache.set_kv_cache_cuda(slot_ids, keys, values);

7 changes: 6 additions & 1 deletion src/models/input_parameters.h
@@ -22,6 +22,7 @@ struct InputParameters {
params.last_token_idxes = last_token_idxes.to(device);
params.token_ids = token_ids.to(device);
params.token_counts = token_counts.to(device);
params.token_ids_lens = token_ids_lens.to(device);
return params;
}

@@ -78,9 +79,13 @@ struct InputParameters {
// [num_seq, max_unique_tokens] LongTensor
torch::Tensor token_ids;

// the count of each token in the prompt.
// the count of each token in each sequence.
// [num_seq, max_unique_tokens] IntTensor
torch::Tensor token_counts;

// the number of unique tokens in each sequence.
// [num_seq] IntTensor
torch::Tensor token_ids_lens;
};

} // namespace llm
