moved top_p and top_k into sampler
guocuimi committed Dec 8, 2023
1 parent 35375ae commit b9afdb9
Showing 8 changed files with 202 additions and 351 deletions.
14 changes: 7 additions & 7 deletions src/engine/worker.cpp
@@ -78,15 +78,15 @@ OutputParameters Worker::execute_model(
// create and call logits processors
auto logits_processor =
LogitsProcessor::create(sampling_params, dtype_, device_);
logits = logits_processor->forward(d_params.token_ids,
d_params.token_counts,
d_params.token_ids_lens,
logits);
// apply logits processors to logits in-place
logits_processor->forward(logits,
d_params.token_ids,
d_params.token_counts,
d_params.token_ids_lens);

// create and call sampler
auto sampler = std::make_unique<Sampler>(
sampling_params.do_sample, sampling_params.seeds, device_);
auto next_tokens = sampler->sample(logits);
auto sampler = std::make_unique<Sampler>(sampling_params, dtype_, device_);
auto next_tokens = sampler->forward(logits);

// prepare output parameters
OutputParameters output_params;
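
For context, this is a minimal sketch of the Sampler interface implied by the two new call sites above. The actual header (src/sampling/sampler.h) is not part of this excerpt, so the parameter type name and member layout are assumptions:

// sampler.h (sketch inferred from worker.cpp usage; not the actual header)
#pragma once

#include <torch/torch.h>

namespace llm {

class Sampler {
 public:
  // Takes the full sampling parameters so top-k/top-p filtering can live
  // here alongside the final token selection.
  Sampler(const SamplingParameters& params,
          torch::ScalarType dtype,
          const torch::Device& device);

  // logits: [num_seqs, vocab_size]; returns the next token id per sequence.
  torch::Tensor forward(const torch::Tensor& logits) const;
};

}  // namespace llm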
10 changes: 10 additions & 0 deletions src/sampling/CMakeLists.txt
@@ -34,4 +34,14 @@ cc_test(
DEPS
:logits_processor
GTest::gtest_main
)

cc_test(
NAME
sampler_test
SRCS
sampler_test.cpp
DEPS
:sampler
GTest::gtest_main
)
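
The test file itself is not shown in this commit view; here is a minimal sketch of what sampler_test.cpp might contain, assuming the Sampler interface sketched earlier and a greedy (do_sample = false) path. The field layout of SamplingParameters is hypothetical:

#include <gtest/gtest.h>
#include <torch/torch.h>

#include "sampler.h"

namespace llm {

// Greedy sampling should reduce to argmax over the vocabulary dimension.
TEST(SamplerTest, GreedySelectsArgmax) {
  SamplingParameters params;   // hypothetical field layout
  params.do_sample = {false};  // one sequence, greedy decoding
  Sampler sampler(params, torch::kFloat32, torch::kCPU);

  auto logits = torch::tensor({{0.1f, 2.0f, 0.5f}});
  auto next_tokens = sampler.forward(logits);
  EXPECT_EQ(next_tokens.item<int64_t>(), 1);  // index of the largest logit
}

}  // namespace llm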
17 changes: 0 additions & 17 deletions src/sampling/logits_processor.cpp
@@ -4,9 +4,6 @@

#include <algorithm>
#include <memory>
#include <string>

#include "common/logging.h"

namespace llm {
std::unique_ptr<LogitsProcessor> LogitsProcessor::create(
@@ -43,20 +40,6 @@ std::unique_ptr<LogitsProcessor> LogitsProcessor::create(
params.temperatures, dtype, device));
}

if (std::any_of(params.top_k.begin(), params.top_k.end(), [](int64_t t) {
return t != 0;
})) {
processors.push_back(
std::make_unique<TopKLogitsProcessor>(params.top_k, dtype, device));
}

if (std::any_of(params.top_p.begin(), params.top_p.end(), [](float t) {
return t != 1.0;
})) {
processors.push_back(
std::make_unique<TopPLogitsProcessor>(params.top_p, dtype, device));
}

return std::make_unique<LogitsProcessorList>(std::move(processors));
}

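
The std::any_of gating deleted above presumably moves into the sampler along with the processors themselves; a hedged sketch of the same checks on the other side (the helper names here are hypothetical, mirroring the logic removed from LogitsProcessor::create):

#include <algorithm>
#include <cstdint>
#include <vector>

// Only enable each filter when at least one sequence requests a
// non-default value, as the deleted checks above did.
inline bool needs_top_k(const std::vector<int64_t>& top_k) {
  return std::any_of(
      top_k.begin(), top_k.end(), [](int64_t k) { return k != 0; });
}

inline bool needs_top_p(const std::vector<float>& top_p) {
  return std::any_of(
      top_p.begin(), top_p.end(), [](float p) { return p != 1.0f; });
}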
212 changes: 29 additions & 183 deletions src/sampling/logits_processor.h
@@ -1,13 +1,7 @@
#pragma once
#include <ATen/core/TensorBody.h>
#include <ATen/ops/any.h>
#include <torch/torch.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <tuple>
#include <vector>

#include "kernels/sampling/sampling_kernels.h"
@@ -64,8 +58,6 @@ inline void apply_frequency_presence_penalty(
// 1. frequency and presence penalty
// 2. repetition penalty
// 3. temperature
// 4. top-k
// 5. top-p

// inspired by transformers LogitsProcessor:
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L44
@@ -81,15 +73,15 @@ class LogitsProcessor {
// used in frequency and presence penalty for now
// logits: [num_seqs, vocab_size]
// the logits to be processed
virtual torch::Tensor forward(const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens,
const torch::Tensor& logits) const = 0;
virtual void forward(torch::Tensor& logits,
const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens) const = 0;

// operator() allows us to use the module as a function.
template <typename... Args>
torch::Tensor operator()(Args&&... args) {
return this->forward(::std::forward<Args>(args)...);
void operator()(Args&&... args) {
this->forward(::std::forward<Args>(args)...);
}

// factory method to create a logits processor
@@ -104,16 +96,13 @@ class LogitsProcessorList : public LogitsProcessor {
LogitsProcessorList(std::vector<std::unique_ptr<LogitsProcessor>> processors)
: processors_(std::move(processors)) {}

torch::Tensor forward(const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens,
const torch::Tensor& logits) const override {
torch::Tensor output = logits;
void forward(torch::Tensor& logits,
const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens) const override {
for (const auto& processor : processors_) {
output =
processor->forward(token_ids, token_counts, token_ids_lens, output);
processor->forward(logits, token_ids, token_counts, token_ids_lens);
}
return output;
}

private:
@@ -141,28 +130,25 @@ class FrequencyPresencePenaltyLogitsProcessor : public LogitsProcessor {
.unsqueeze(1);
}

torch::Tensor forward(const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens,
const torch::Tensor& logits) const override {
auto logits_ = logits;
void forward(torch::Tensor& logits,
const torch::Tensor& token_ids,
const torch::Tensor& token_counts,
const torch::Tensor& token_ids_lens) const override {
if (logits.is_cuda()) {
kernel::apply_frequency_presence_penalty(logits_,
kernel::apply_frequency_presence_penalty(logits,
token_ids,
token_counts,
token_ids_lens,
frequency_penalties_,
presence_penalties_);
} else {
detail::apply_frequency_presence_penalty(logits_,
detail::apply_frequency_presence_penalty(logits,
token_ids,
token_counts,
token_ids_lens,
frequency_penalties_,
presence_penalties_);
}

return logits_;
};

private:
@@ -181,19 +167,17 @@ class RepetitionPenaltyLogitsProcessor : public LogitsProcessor {
}

// token_ids, [num_seqs, max_num_tokens] LongTensor
torch::Tensor forward(const torch::Tensor& token_ids,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& token_ids_lens,
const torch::Tensor& logits) const override {
auto logits_ = logits;
void forward(torch::Tensor& logits,
const torch::Tensor& token_ids,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& token_ids_lens) const override {
if (logits.is_cuda()) {
kernel::apply_repetition_penalty(
logits_, token_ids, token_ids_lens, penalties_);
logits, token_ids, token_ids_lens, penalties_);
} else {
detail::apply_repetition_penalty(
logits_, token_ids, token_ids_lens, penalties_);
logits, token_ids, token_ids_lens, penalties_);
}
return logits_;
}

private:
@@ -218,157 +202,19 @@ class TemperatureLogitsProcessor : public LogitsProcessor {
torch::where(temperatures_ == 0, torch::tensor(1.0), temperatures_);
}

torch::Tensor forward(const torch::Tensor& /*token_ids*/,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& /*token_ids_lens*/,
const torch::Tensor& logits) const override {
auto logits_ = logits;
void forward(torch::Tensor& logits,
const torch::Tensor& /*token_ids*/,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& /*token_ids_lens*/) const override {
if (logits.is_cuda()) {
kernel::apply_temperature_penalty(logits_, temperatures_);
kernel::apply_temperature_penalty(logits, temperatures_);
} else {
detail::apply_temperature_penalty(logits_, temperatures_);
detail::apply_temperature_penalty(logits, temperatures_);
}
return logits_;
}

private:
torch::Tensor temperatures_;
};

// TODO: move topk and topp into sampler
class TopPLogitsProcessor : public LogitsProcessor {
public:
TopPLogitsProcessor(
const std::vector<float>& top_p,
torch::ScalarType dtype,
const torch::Device& device,
float filter_value = -std::numeric_limits<float>::infinity(),
int64_t min_tokens_to_keep = 1)
: filter_value(filter_value), min_tokens_to_keep(min_tokens_to_keep) {
top_p_opposite =
1.0 -
torch::tensor(top_p, torch::dtype(dtype).device(device)).unsqueeze(1);
}

torch::Tensor forward(const torch::Tensor& /*token_ids*/,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& /*token_ids_lens*/,
const torch::Tensor& logits) const override {
// sort the logits in descending order
auto [sorted_logits, sorted_indices] =
logits.sort(/*dim=*/1, /*descending=*/false);

// calculate cumulative probabilities
torch::Tensor cumulative_probs =
sorted_logits.softmax(/*dim=*/-1).cumsum(/*dim=*/-1);

// remove tokens with cumulative probability above top_p
torch::Tensor sorted_indices_to_remove = cumulative_probs <= top_p_opposite;
sorted_indices_to_remove.index_put_(
{torch::indexing::Slice(),
torch::indexing::Slice(-min_tokens_to_keep, torch::indexing::None)},
false);

// scatter the modified indices back to logits
torch::Tensor indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove);
torch::Tensor warped_scores =
logits.masked_fill_(indices_to_remove, filter_value);
return warped_scores;
}

private:
// [num_seqs, 1] FloatTensor
torch::Tensor top_p_opposite;
// the value used for filtering, all logits will be set to this value if they
// are filtered
float filter_value;
// the minimum number of tokens to keep
int64_t min_tokens_to_keep;
};
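
The ascending sort plus top_p_opposite comparison above is equivalent to the usual descending formulation: the entries removed are exactly the low-probability tail whose cumulative mass fits inside 1 - top_p. A tiny standalone check of that masking rule (a sketch, not part of the commit):

#include <torch/torch.h>

#include <iostream>

int main() {
  // Probabilities already sorted ascending: {0.2, 0.3, 0.5}; top_p = 0.7.
  auto cum = torch::tensor({0.2f, 0.3f, 0.5f}).cumsum(/*dim=*/0);
  // Drop entries whose cumulative mass from the bottom is <= 1 - top_p.
  auto remove = cum <= (1.0f - 0.7f);
  std::cout << remove << '\n';  // {1, 0, 0}: only the 0.2-mass token is cut
  return 0;
}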

class TopKLogitsProcessor : public LogitsProcessor {
public:
// top_k: input is 1-based, 0 means no filtering or disable filtering
TopKLogitsProcessor(
const std::vector<int64_t>& top_k,
torch::ScalarType /*dtype*/,
const torch::Device& device,
float filter_value = -std::numeric_limits<float>::infinity(),
int64_t min_tokens_to_keep = 1)
: filter_value_(filter_value),
max_top_k_(*std::max_element(top_k.begin(), top_k.end())),
min_tokens_to_keep_(min_tokens_to_keep) {
std::vector<int64_t> adjusted_top_k;
// need to use int8_t for bool tensor
// std::vector<bool> is a special case in c++ stl and can't be directly
// used to initialize a bool tensor
std::vector<int8_t> disabled;
adjusted_top_k.reserve(top_k.size());
disabled.reserve(top_k.size());
for (auto val : top_k) {
// adjust top_k to be 0-based
adjusted_top_k.push_back(std::max(val, min_tokens_to_keep) - 1);
disabled.push_back(val == 0 ? 1 : 0);
}

top_k_ = torch::tensor(adjusted_top_k,
torch::dtype(torch::kInt64).device(device))
.unsqueeze(1);

if (std::any_of(disabled.begin(), disabled.end(), [](int8_t v) {
return v == 1;
})) {
top_k_disabled_mask_ =
torch::tensor(disabled, torch::dtype(torch::kBool).device(device))
.unsqueeze(1);
}
}

torch::Tensor forward(const torch::Tensor& /*token_ids*/,
const torch::Tensor& /*token_counts*/,
const torch::Tensor& /*token_ids_lens*/,
const torch::Tensor& logits) const override {
torch::Tensor top_k = top_k_;
auto max_top_k = max_top_k_;

// if max_top_k > vocab_size, then we need to clamp the top_k values
const auto vocab_size = logits.size(/*dim=*/-1);
if (vocab_size < max_top_k_) {
max_top_k = vocab_size;
// adjust top_k to be 0-based
top_k = top_k.clamp_max(/*max=*/vocab_size - 1);
}

// get the kth score for each sequence
auto [topk_scores, _] = logits.topk(/*k=*/max_top_k);
torch::Tensor kth_scores = topk_scores.gather(/*dim=*/1, /*index=*/top_k);

if (top_k_disabled_mask_.defined()) {
// use a very low value for the top-k scores to disable filtering
kth_scores.masked_fill_(/*mask=*/top_k_disabled_mask_,
/*value=*/filter_value_);
}

// 'remove' tokens with logits < kth score by setting them to filter_value
torch::Tensor indices_to_remove = logits < kth_scores;
logits.masked_fill_(/*mask=*/indices_to_remove, /*value=*/filter_value_);
return logits;
}

private:
// [num_seqs, 1] IntTensor 0-based
torch::Tensor top_k_;
// [num_seqs, 1] BoolTensor to disable top_k filtering for some sequences with
// top_k = 0
torch::Tensor top_k_disabled_mask_;
// the value used for filtering, all logits will be set to this value if they
// are filtered
float filter_value_;
// the maximum value of top_k
int64_t max_top_k_;
// the minimum number of tokens to keep
int64_t min_tokens_to_keep_;
};

} // namespace llm
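
To make the new in-place contract concrete, here is a minimal custom processor written against the interface above; it is an illustrative sketch, not part of this commit:

#include <torch/torch.h>

#include <vector>

namespace llm {

// Adds a fixed per-token bias to the logits using the new in-place forward.
class LogitBiasLogitsProcessor : public LogitsProcessor {
 public:
  LogitBiasLogitsProcessor(const std::vector<float>& bias,
                           torch::ScalarType dtype,
                           const torch::Device& device) {
    // [1, vocab_size] so the bias broadcasts across sequences.
    bias_ = torch::tensor(bias, torch::dtype(dtype).device(device))
                .unsqueeze(/*dim=*/0);
  }

  void forward(torch::Tensor& logits,
               const torch::Tensor& /*token_ids*/,
               const torch::Tensor& /*token_counts*/,
               const torch::Tensor& /*token_ids_lens*/) const override {
    // Mutate logits directly; nothing is returned under the new contract.
    logits.add_(bias_);
  }

 private:
  torch::Tensor bias_;
};

}  // namespace llm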