From 7735b24301aa98caf14d9428a3b4c143365316cf Mon Sep 17 00:00:00 2001 From: Tongxuan Liu Date: Fri, 22 Dec 2023 11:09:30 +0800 Subject: [PATCH] [refactor] rename Executor to ThreadPool. (#36) --- src/common/CMakeLists.txt | 4 +-- src/common/{executor.cpp => threadpool.cpp} | 10 +++---- src/common/{executor.h => threadpool.h} | 18 ++++++------ ...{executor_test.cpp => threadpool_test.cpp} | 28 +++++++++---------- src/engine/worker.cpp | 10 +++---- src/engine/worker.h | 4 +-- .../continuous_batching_scheduler.cpp | 6 ++-- src/scheduler/continuous_batching_scheduler.h | 4 +-- src/server/handlers/chat_handler.cpp | 2 +- src/server/handlers/chat_handler.h | 6 ++-- src/server/handlers/completion_handler.cpp | 2 +- src/server/handlers/completion_handler.h | 6 ++-- 12 files changed, 50 insertions(+), 50 deletions(-) rename src/common/{executor.cpp => threadpool.cpp} (80%) rename src/common/{executor.h => threadpool.h} (58%) rename src/common/{executor_test.cpp => threadpool_test.cpp} (70%) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 9aa9bcee..2ef12b75 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -9,11 +9,11 @@ cc_library( metrics.h slice.h concurrent_queue.h - executor.h + threadpool.h pretty_print.h json_reader.h SRCS - executor.cpp + threadpool.cpp pretty_print.cpp json_reader.cpp DEPS diff --git a/src/common/executor.cpp b/src/common/threadpool.cpp similarity index 80% rename from src/common/executor.cpp rename to src/common/threadpool.cpp index 295803ec..168fb26d 100644 --- a/src/common/executor.cpp +++ b/src/common/threadpool.cpp @@ -1,4 +1,4 @@ -#include "executor.h" +#include "threadpool.h" #include #include @@ -6,13 +6,13 @@ #include "concurrent_queue.h" namespace llm { -Executor::Executor(size_t num_threads) { +ThreadPool::ThreadPool(size_t num_threads) { for (size_t i = 0; i < num_threads; ++i) { threads_.emplace_back([this]() { internal_loop(); }); } } -Executor::~Executor() { +ThreadPool::~ThreadPool() 
{ // push nullptr to the queue to signal threads to exit for (size_t i = 0; i < threads_.size(); ++i) { queue_.push(nullptr); @@ -24,14 +24,14 @@ Executor::~Executor() { } // schedule a runnable to be executed -void Executor::schedule(Runnable runnable) { +void ThreadPool::schedule(Runnable runnable) { if (runnable == nullptr) { return; } queue_.push(std::move(runnable)); } -void Executor::internal_loop() { +void ThreadPool::internal_loop() { while (true) { Runnable runnable = queue_.pop(); if (runnable == nullptr) { diff --git a/src/common/executor.h b/src/common/threadpool.h similarity index 58% rename from src/common/executor.h rename to src/common/threadpool.h index 9ee8aee8..516df73a 100644 --- a/src/common/executor.h +++ b/src/common/threadpool.h @@ -6,25 +6,25 @@ namespace llm { -class Executor final { +class ThreadPool final { public: - // a runnable is an object intended to be executed by the executor + // a runnable is an object intended to be executed by the threadpool // it must be invokable with no arguments and return void. 
using Runnable = folly::Function; // constructors - Executor() : Executor(1) {} + ThreadPool() : ThreadPool(1) {} // disable copy/move constructor and assignment - Executor(const Executor&) = delete; - Executor& operator=(const Executor&) = delete; - Executor(Executor&&) = delete; - Executor& operator=(Executor&&) = delete; + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; + ThreadPool(ThreadPool&&) = delete; + ThreadPool& operator=(ThreadPool&&) = delete; - explicit Executor(size_t num_threads); + explicit ThreadPool(size_t num_threads); // destructor - ~Executor(); + ~ThreadPool(); // schedule a runnable to be executed void schedule(Runnable runnable); diff --git a/src/common/executor_test.cpp b/src/common/threadpool_test.cpp similarity index 70% rename from src/common/executor_test.cpp rename to src/common/threadpool_test.cpp index ca123728..b1ad3ff0 100644 --- a/src/common/executor_test.cpp +++ b/src/common/threadpool_test.cpp @@ -1,4 +1,4 @@ -#include "executor.h" +#include "threadpool.h" #include #include @@ -6,17 +6,17 @@ namespace llm { -TEST(ExecutorTest, ScheduleEmptyTask) { - Executor executor(1); +TEST(ThreadPoolTest, ScheduleEmptyTask) { + ThreadPool threadpool(1); absl::Notification notification; - executor.schedule(nullptr); + threadpool.schedule(nullptr); } -TEST(ExecutorTest, ScheduleTask) { - Executor executor(1); +TEST(ThreadPoolTest, ScheduleTask) { + ThreadPool threadpool(1); absl::Notification notification; bool called = false; - executor.schedule([&called, &notification]() { + threadpool.schedule([&called, &notification]() { called = true; notification.Notify(); }); @@ -24,13 +24,13 @@ TEST(ExecutorTest, ScheduleTask) { EXPECT_TRUE(called); } -TEST(ExecutorTest, ScheduleMultipleTasks) { - Executor executor(1); +TEST(ThreadPoolTest, ScheduleMultipleTasks) { + ThreadPool threadpool(1); std::vector completed_tasks; absl::Notification notification; // run frist task - executor.schedule([&completed_tasks,
&notification]() { + threadpool.schedule([&completed_tasks, &notification]() { completed_tasks.emplace_back("first"); if (completed_tasks.size() == 2) { absl::SleepFor(absl::Milliseconds(100)); } @@ -39,7 +39,7 @@ TEST(ExecutorTest, ScheduleMultipleTasks) { }); // run second task - executor.schedule([&completed_tasks, &notification]() { + threadpool.schedule([&completed_tasks, &notification]() { completed_tasks.emplace_back("second"); if (completed_tasks.size() == 2) { notification.Notify(); } @@ -52,13 +52,13 @@ EXPECT_EQ(completed_tasks[1], "second"); } -TEST(ExecutorTest, MultipleThreads) { - Executor executor(4); +TEST(ThreadPoolTest, MultipleThreads) { + ThreadPool threadpool(4); std::atomic_uint32_t counter = 0; absl::Notification notification; for (int i = 0; i < 10; ++i) { - executor.schedule([&counter, &notification]() { + threadpool.schedule([&counter, &notification]() { absl::SleepFor(absl::Milliseconds(100)); counter++; if (counter == 10) { diff --git a/src/engine/worker.cpp b/src/engine/worker.cpp index aae29075..39ecbe8b 100644 --- a/src/engine/worker.cpp +++ b/src/engine/worker.cpp @@ -9,7 +9,7 @@ #include #include -#include "common/executor.h" +#include "common/threadpool.h" #include "common/logging.h" #include "model_loader/state_dict.h" #include "models/input_parameters.h" @@ -101,7 +101,7 @@ folly::SemiFuture Worker::execute_model_async( const SamplingParameters& sampling_params) { folly::Promise promise; auto future = promise.getSemiFuture(); - executor_.schedule([this, + threadpool_.schedule([this, tokens = flatten_tokens, positions = flatten_positions, parameters = params, @@ -121,7 +121,7 @@ folly::SemiFuture Worker::init_model_async(torch::ScalarType dtype, const QuantArgs& quant_args) { folly::Promise promise; auto future = promise.getSemiFuture(); - executor_.schedule([this, + threadpool_.schedule([this, dtype, &args, &quant_args, @@ -137,7 +137,7 @@ folly::SemiFuture Worker::init_kv_cache_async( const std::vector&
value_cache_shape) { folly::Promise promise; auto future = promise.getSemiFuture(); - executor_.schedule([this, + threadpool_.schedule([this, &key_cache_shape, &value_cache_shape, promise = std::move(promise)]() mutable { @@ -152,7 +152,7 @@ folly::SemiFuture Worker::load_state_dict_async( const StateDict& state_dict) { folly::Promise promise; auto future = promise.getSemiFuture(); - executor_.schedule( + threadpool_.schedule( [this, &state_dict, promise = std::move(promise)]() mutable { // load the model weights from state_dict within the working thread this->load_state_dict(state_dict); diff --git a/src/engine/worker.h b/src/engine/worker.h index 501cf61c..ab0fefef 100644 --- a/src/engine/worker.h +++ b/src/engine/worker.h @@ -3,7 +3,7 @@ #include #include -#include "common/executor.h" +#include "common/threadpool.h" #include "model_loader/state_dict.h" #include "models/args.h" #include "models/causal_lm.h" @@ -83,7 +83,7 @@ class Worker final { private: // working thread - Executor executor_; + ThreadPool threadpool_; // dtype of the model torch::ScalarType dtype_; diff --git a/src/scheduler/continuous_batching_scheduler.cpp b/src/scheduler/continuous_batching_scheduler.cpp index 0a17d29e..c4f7b8f5 100644 --- a/src/scheduler/continuous_batching_scheduler.cpp +++ b/src/scheduler/continuous_batching_scheduler.cpp @@ -59,8 +59,8 @@ void ContinuousBatchingScheduler::on_request_finish(Request* request) { block_manager_->release_slots_for_request(request); // take over the ownership of the request std::unique_ptr finished_request(request); - response_executor_.schedule([tokenizer = tokenizer_.get(), - request = std::move(finished_request)]() { + response_threadpool_.schedule([tokenizer = tokenizer_.get(), + request = std::move(finished_request)]() { if (request->stream) { // just finish the request request->on_finish("", FinishReason::NONE, Status()); @@ -83,7 +83,7 @@ void ContinuousBatchingScheduler::on_sequence_stream(Sequence* seq) { num_tokens_to_output >= 
FLAGS_streaming_token_buffer_size) { const auto finish_reason = seq->finish_reason(); // output the delta text til the end of the sequence to the client - response_executor_.schedule( + response_threadpool_.schedule( [seq, tokenizer = tokenizer_.get(), end = num_tokens, finish_reason]() { const auto detla = seq->decode_delta_text(end, *tokenizer); if (!detla.empty() || finish_reason != FinishReason::NONE) { diff --git a/src/scheduler/continuous_batching_scheduler.h b/src/scheduler/continuous_batching_scheduler.h index ea072d6b..052283d0 100644 --- a/src/scheduler/continuous_batching_scheduler.h +++ b/src/scheduler/continuous_batching_scheduler.h @@ -66,8 +66,8 @@ class ContinuousBatchingScheduler final : public Scheduler { // low. std::deque preemptable_candidates_; - // the executor to handle responses - Executor response_executor_; + // the threadpool to handle responses + ThreadPool response_threadpool_; }; } // namespace llm diff --git a/src/server/handlers/chat_handler.cpp b/src/server/handlers/chat_handler.cpp index 6126dc83..e04a899b 100644 --- a/src/server/handlers/chat_handler.cpp +++ b/src/server/handlers/chat_handler.cpp @@ -258,7 +258,7 @@ ChatHandler::ChatHandler(Scheduler* scheduler, const Engine* engine) } void ChatHandler::chat_async(ChatCallData* call_data) { - converter_executor_.schedule([this, call_data = call_data]() { + converter_threadpool_.schedule([this, call_data = call_data]() { if (!verify_request_arguments(call_data)) { // request is not valid, finish with error return; diff --git a/src/server/handlers/chat_handler.h b/src/server/handlers/chat_handler.h index 3ec4ae5e..4da7e5a6 100644 --- a/src/server/handlers/chat_handler.h +++ b/src/server/handlers/chat_handler.h @@ -5,7 +5,7 @@ #include "server/call_data.h" #include "chat.grpc.pb.h" -#include "common/executor.h" +#include "common/threadpool.h" #include "engine/engine.h" #include "models/args.h" #include "scheduler/scheduler.h" @@ -31,8 +31,8 @@ class ChatHandler final { // model args 
ModelArgs model_args_; - // converter executor - Executor converter_executor_; + // converter threadpool + ThreadPool converter_threadpool_; }; } // namespace llm diff --git a/src/server/handlers/completion_handler.cpp b/src/server/handlers/completion_handler.cpp index d376a870..ede95542 100644 --- a/src/server/handlers/completion_handler.cpp +++ b/src/server/handlers/completion_handler.cpp @@ -238,7 +238,7 @@ CompletionHandler::CompletionHandler(Scheduler* scheduler, const Engine* engine) } void CompletionHandler::complete_async(CompletionCallData* call_data) { - converter_executor_.schedule([this, call_data = call_data]() { + converter_threadpool_.schedule([this, call_data = call_data]() { if (!verify_request_arguments(call_data)) { // request is not valid, finish with error return; diff --git a/src/server/handlers/completion_handler.h b/src/server/handlers/completion_handler.h index c6509678..769ce913 100644 --- a/src/server/handlers/completion_handler.h +++ b/src/server/handlers/completion_handler.h @@ -4,7 +4,7 @@ #include #include "server/call_data.h" -#include "common/executor.h" +#include "common/threadpool.h" #include "completion.grpc.pb.h" #include "engine/engine.h" #include "models/args.h" @@ -32,8 +32,8 @@ class CompletionHandler final { // model args ModelArgs model_args_; - // converter executor - Executor converter_executor_; + // converter threadpool + ThreadPool converter_threadpool_; }; } // namespace llm