Skip to content

Commit

Permalink
[refactor] rename Executor to ThreadPool. (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
liutongxuan authored Dec 22, 2023
1 parent f99cf5e commit 7735b24
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ cc_library(
metrics.h
slice.h
concurrent_queue.h
executor.h
threadpool.h
pretty_print.h
json_reader.h
SRCS
executor.cpp
threadpool.cpp
pretty_print.cpp
json_reader.cpp
DEPS
Expand Down
10 changes: 5 additions & 5 deletions src/common/executor.cpp → src/common/threadpool.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#include "executor.h"
#include "threadpool.h"

#include <functional>
#include <thread>

#include "concurrent_queue.h"

namespace llm {
Executor::Executor(size_t num_threads) {
ThreadPool::ThreadPool(size_t num_threads) {
for (size_t i = 0; i < num_threads; ++i) {
threads_.emplace_back([this]() { internal_loop(); });
}
}

Executor::~Executor() {
ThreadPool::~ThreadPool() {
// push nullptr to the queue to signal threads to exit
for (size_t i = 0; i < threads_.size(); ++i) {
queue_.push(nullptr);
Expand All @@ -24,14 +24,14 @@ Executor::~Executor() {
}

// schedule a runnable to be executed
void Executor::schedule(Runnable runnable) {
void ThreadPool::schedule(Runnable runnable) {
if (runnable == nullptr) {
return;
}
queue_.push(std::move(runnable));
}

void Executor::internal_loop() {
void ThreadPool::internal_loop() {
while (true) {
Runnable runnable = queue_.pop();
if (runnable == nullptr) {
Expand Down
18 changes: 9 additions & 9 deletions src/common/executor.h → src/common/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@

namespace llm {

class Executor final {
class ThreadPool final {
public:
// a runnable is an object intended to be executed by the executor
// a runnable is an object intended to be executed by the threadpool
// it must be invokable with no arguments and return void.
using Runnable = folly::Function<void()>;

// constructors
Executor() : Executor(1) {}
ThreadPool() : ThreadPool(1) {}

// disable copy/move constructor and assignment
Executor(const Executor&) = delete;
Executor& operator=(const Executor&) = delete;
Executor(Executor&&) = delete;
Executor& operator=(Executor&&) = delete;
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
ThreadPool(ThreadPool&&) = delete;
ThreadPool& operator=(ThreadPool&&) = delete;

explicit Executor(size_t num_threads);
explicit ThreadPool(size_t num_threads);

// destructor
~Executor();
~ThreadPool();

// schedule a runnable to be executed
void schedule(Runnable runnable);
Expand Down
28 changes: 14 additions & 14 deletions src/common/executor_test.cpp → src/common/threadpool_test.cpp
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
#include "executor.h"
#include "threadpool.h"

#include <absl/synchronization/notification.h>
#include <absl/time/clock.h>
#include <gtest/gtest.h>

namespace llm {

TEST(ExecutorTest, ScheduleEmptyTask) {
Executor executor(1);
TEST(ThreadPoolTest, ScheduleEmptyTask) {
ThreadPool threadpool(1);
absl::Notification notification;
executor.schedule(nullptr);
threadpool.schedule(nullptr);
}

TEST(ExecutorTest, ScheduleTask) {
Executor executor(1);
TEST(ThreadPoolTest, ScheduleTask) {
ThreadPool threadpool(1);
absl::Notification notification;
bool called = false;
executor.schedule([&called, &notification]() {
threadpool.schedule([&called, &notification]() {
called = true;
notification.Notify();
});
notification.WaitForNotification();
EXPECT_TRUE(called);
}

TEST(ExecutorTest, ScheduleMultipleTasks) {
Executor executor(1);
TEST(ThreadPoolTest, ScheduleMultipleTasks) {
ThreadPool threadpool(1);
std::vector<std::string> completed_tasks;
absl::Notification notification;

// run frist task
executor.schedule([&completed_tasks, &notification]() {
threadpool.schedule([&completed_tasks, &notification]() {
completed_tasks.emplace_back("first");
if (completed_tasks.size() == 2) {
absl::SleepFor(absl::Milliseconds(100));
Expand All @@ -39,7 +39,7 @@ TEST(ExecutorTest, ScheduleMultipleTasks) {
});

// run second task
executor.schedule([&completed_tasks, &notification]() {
threadpool.schedule([&completed_tasks, &notification]() {
completed_tasks.emplace_back("second");
if (completed_tasks.size() == 2) {
notification.Notify();
Expand All @@ -52,13 +52,13 @@ TEST(ExecutorTest, ScheduleMultipleTasks) {
EXPECT_EQ(completed_tasks[1], "second");
}

TEST(ExecutorTest, MultipleThreads) {
Executor executor(4);
TEST(ThreadPoolTest, MultipleThreads) {
ThreadPool threadpool(4);
std::atomic_uint32_t counter = 0;
absl::Notification notification;

for (int i = 0; i < 10; ++i) {
executor.schedule([&counter, &notification]() {
threadpool.schedule([&counter, &notification]() {
absl::SleepFor(absl::Milliseconds(100));
counter++;
if (counter == 10) {
Expand Down
10 changes: 5 additions & 5 deletions src/engine/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <memory>
#include <utility>

#include "common/executor.h"
#include "common/threadpool.h"
#include "common/logging.h"
#include "model_loader/state_dict.h"
#include "models/input_parameters.h"
Expand Down Expand Up @@ -101,7 +101,7 @@ folly::SemiFuture<OutputParameters> Worker::execute_model_async(
const SamplingParameters& sampling_params) {
folly::Promise<OutputParameters> promise;
auto future = promise.getSemiFuture();
executor_.schedule([this,
threadpool_.schedule([this,
tokens = flatten_tokens,
positions = flatten_positions,
parameters = params,
Expand All @@ -121,7 +121,7 @@ folly::SemiFuture<bool> Worker::init_model_async(torch::ScalarType dtype,
const QuantArgs& quant_args) {
folly::Promise<bool> promise;
auto future = promise.getSemiFuture();
executor_.schedule([this,
threadpool_.schedule([this,
dtype,
&args,
&quant_args,
Expand All @@ -137,7 +137,7 @@ folly::SemiFuture<bool> Worker::init_kv_cache_async(
const std::vector<int64_t>& value_cache_shape) {
folly::Promise<bool> promise;
auto future = promise.getSemiFuture();
executor_.schedule([this,
threadpool_.schedule([this,
&key_cache_shape,
&value_cache_shape,
promise = std::move(promise)]() mutable {
Expand All @@ -152,7 +152,7 @@ folly::SemiFuture<folly::Unit> Worker::load_state_dict_async(
const StateDict& state_dict) {
folly::Promise<folly::Unit> promise;
auto future = promise.getSemiFuture();
executor_.schedule(
threadpool_.schedule(
[this, &state_dict, promise = std::move(promise)]() mutable {
// load the model weights from state_dict within the working thread
this->load_state_dict(state_dict);
Expand Down
4 changes: 2 additions & 2 deletions src/engine/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <folly/futures/Future.h>
#include <torch/torch.h>

#include "common/executor.h"
#include "common/threadpool.h"
#include "model_loader/state_dict.h"
#include "models/args.h"
#include "models/causal_lm.h"
Expand Down Expand Up @@ -83,7 +83,7 @@ class Worker final {

private:
// working thread
Executor executor_;
ThreadPool threadpool_;

// dtype of the model
torch::ScalarType dtype_;
Expand Down
6 changes: 3 additions & 3 deletions src/scheduler/continuous_batching_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ void ContinuousBatchingScheduler::on_request_finish(Request* request) {
block_manager_->release_slots_for_request(request);
// take over the ownership of the request
std::unique_ptr<Request> finished_request(request);
response_executor_.schedule([tokenizer = tokenizer_.get(),
request = std::move(finished_request)]() {
response_threadpool_.schedule([tokenizer = tokenizer_.get(),
request = std::move(finished_request)]() {
if (request->stream) {
// just finish the request
request->on_finish("", FinishReason::NONE, Status());
Expand All @@ -83,7 +83,7 @@ void ContinuousBatchingScheduler::on_sequence_stream(Sequence* seq) {
num_tokens_to_output >= FLAGS_streaming_token_buffer_size) {
const auto finish_reason = seq->finish_reason();
// output the delta text til the end of the sequence to the client
response_executor_.schedule(
response_threadpool_.schedule(
[seq, tokenizer = tokenizer_.get(), end = num_tokens, finish_reason]() {
const auto detla = seq->decode_delta_text(end, *tokenizer);
if (!detla.empty() || finish_reason != FinishReason::NONE) {
Expand Down
4 changes: 2 additions & 2 deletions src/scheduler/continuous_batching_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class ContinuousBatchingScheduler final : public Scheduler {
// low.
std::deque<Request*> preemptable_candidates_;

// the executor to handle responses
Executor response_executor_;
// the threadpool to handle responses
ThreadPool response_threadpool_;
};

} // namespace llm
2 changes: 1 addition & 1 deletion src/server/handlers/chat_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ ChatHandler::ChatHandler(Scheduler* scheduler, const Engine* engine)
}

void ChatHandler::chat_async(ChatCallData* call_data) {
converter_executor_.schedule([this, call_data = call_data]() {
converter_threadpool_.schedule([this, call_data = call_data]() {
if (!verify_request_arguments(call_data)) {
// request is not valid, finish with error
return;
Expand Down
6 changes: 3 additions & 3 deletions src/server/handlers/chat_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "server/call_data.h"
#include "chat.grpc.pb.h"
#include "common/executor.h"
#include "common/threadpool.h"
#include "engine/engine.h"
#include "models/args.h"
#include "scheduler/scheduler.h"
Expand All @@ -31,8 +31,8 @@ class ChatHandler final {
// model args
ModelArgs model_args_;

// converter executor
Executor converter_executor_;
// converter threadpool
ThreadPool converter_threadpool_;
};

} // namespace llm
2 changes: 1 addition & 1 deletion src/server/handlers/completion_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ CompletionHandler::CompletionHandler(Scheduler* scheduler, const Engine* engine)
}

void CompletionHandler::complete_async(CompletionCallData* call_data) {
converter_executor_.schedule([this, call_data = call_data]() {
converter_threadpool_.schedule([this, call_data = call_data]() {
if (!verify_request_arguments(call_data)) {
// request is not valid, finish with error
return;
Expand Down
6 changes: 3 additions & 3 deletions src/server/handlers/completion_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <thread>

#include "server/call_data.h"
#include "common/executor.h"
#include "common/threadpool.h"
#include "completion.grpc.pb.h"
#include "engine/engine.h"
#include "models/args.h"
Expand Down Expand Up @@ -32,8 +32,8 @@ class CompletionHandler final {
// model args
ModelArgs model_args_;

// converter executor
Executor converter_executor_;
// converter threadpool
ThreadPool converter_threadpool_;
};

} // namespace llm

0 comments on commit 7735b24

Please sign in to comment.