Skip to content

Commit

Permalink
refactor(thread): Make spawnThreadTasks use a wait group & replace task params with `TaskParams` struct (#300)
Browse files Browse the repository at this point in the history

* Add `TaskParams` to simplify `spawnThreadTasks`

* Make `spawnThreadTasks` use a `WaitGroup` instead

* Use a config struct for `spawnThreadTasks`
The function now takes only two arguments: the task function, and
a config struct which is parameterized on the task function type info
in order to also contain the parameter tuple.
  • Loading branch information
InKryption authored Oct 8, 2024
1 parent f970d8b commit 534e83d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 85 deletions.
98 changes: 39 additions & 59 deletions src/accountsdb/db.zig
Original file line number Diff line number Diff line change
Expand Up @@ -440,24 +440,19 @@ pub const AccountsDB = struct {

self.logger.info().logf("[{d} threads]: reading and indexing accounts...", .{n_parse_threads});
{
var handles = std.ArrayList(std.Thread).init(self.allocator);
defer {
for (handles.items) |*h| h.join();
handles.deinit();
}

try spawnThreadTasks(
&handles,
loadAndVerifyAccountsFilesMultiThread,
.{
var wg: std.Thread.WaitGroup = .{};
defer wg.wait();
try spawnThreadTasks(loadAndVerifyAccountsFilesMultiThread, .{
.wg = &wg,
.data_len = n_account_files,
.max_threads = n_parse_threads,
.params = .{
loading_threads.items,
accounts_dir,
snapshot_manifest.file_map,
accounts_per_file_estimate,
},
n_account_files,
n_parse_threads,
);
});
}

// if geyser, send end of data signal
Expand All @@ -480,20 +475,16 @@ pub const AccountsDB = struct {
accounts_dir: std.fs.Dir,
file_info_map: AccountsDbFields.FileMap,
accounts_per_file_estimate: u64,
// task specific
start_index: usize,
end_index: usize,
thread_id: usize,
task: sig.utils.thread.TaskParams,
) !void {
const thread_db = &loading_threads[thread_id];

const thread_db = &loading_threads[task.thread_id];
try thread_db.loadAndVerifyAccountsFiles(
accounts_dir,
accounts_per_file_estimate,
file_info_map,
start_index,
end_index,
thread_id == 0,
task.start_index,
task.end_index,
task.thread_id == 0,
);
}

Expand Down Expand Up @@ -711,22 +702,18 @@ pub const AccountsDB = struct {
thread_dbs: []AccountsDB,
n_threads: usize,
) !void {
var handles = std.ArrayList(std.Thread).init(self.allocator);
defer {
for (handles.items) |*h| h.join();
handles.deinit();
}
try spawnThreadTasks(
&handles,
combineThreadIndexesMultiThread,
.{
var combine_indexes_wg: std.Thread.WaitGroup = .{};
defer combine_indexes_wg.wait();
try spawnThreadTasks(combineThreadIndexesMultiThread, .{
.wg = &combine_indexes_wg,
.data_len = self.account_index.numberOfBins(),
.max_threads = n_threads,
.params = .{
self.logger,
&self.account_index,
thread_dbs,
},
self.account_index.numberOfBins(),
n_threads,
);
});

// ensure enough capacity
var ref_mem_capacity: u32 = 0;
Expand Down Expand Up @@ -788,15 +775,15 @@ pub const AccountsDB = struct {
logger: Logger,
index: *AccountIndex,
thread_dbs: []const AccountsDB,
// task specific
bin_start_index: usize,
bin_end_index: usize,
thread_id: usize,
task: sig.utils.thread.TaskParams,
) !void {
const bin_start_index = task.start_index;
const bin_end_index = task.end_index;

const total_bins = bin_end_index - bin_start_index;
var timer = try sig.time.Timer.start();
var progress_timer = try std.time.Timer.start();
const print_progress = thread_id == 0;
const print_progress = task.thread_id == 0;

for (bin_start_index..bin_end_index, 1..) |bin_index, iteration_count| {
// sum size across threads
Expand Down Expand Up @@ -890,24 +877,20 @@ pub const AccountsDB = struct {
self.logger.info().logf("collecting hashes from accounts...", .{});

{
var handles = std.ArrayList(std.Thread).init(self.allocator);
defer {
for (handles.items) |*h| h.join();
handles.deinit();
}
try spawnThreadTasks(
&handles,
getHashesFromIndexMultiThread,
.{
var wg: std.Thread.WaitGroup = .{};
defer wg.wait();
try spawnThreadTasks(getHashesFromIndexMultiThread, .{
.wg = &wg,
.data_len = self.account_index.numberOfBins(),
.max_threads = n_threads,
.params = .{
self,
config,
self.allocator,
hashes,
lamports,
},
self.account_index.numberOfBins(),
n_threads,
);
});
}

self.logger.debug().logf("took: {s}", .{std.fmt.fmtDuration(timer.read())});
Expand Down Expand Up @@ -1008,19 +991,16 @@ pub const AccountsDB = struct {
hashes_allocator: std.mem.Allocator,
hashes: []ArrayListUnmanaged(Hash),
total_lamports: []u64,
// spawing thread specific params
bin_start_index: usize,
bin_end_index: usize,
thread_index: usize,
task: sig.utils.thread.TaskParams,
) !void {
try getHashesFromIndex(
self,
config,
self.account_index.bins[bin_start_index..bin_end_index],
self.account_index.bins[task.start_index..task.end_index],
hashes_allocator,
&hashes[thread_index],
&total_lamports[thread_index],
thread_index == 0,
&hashes[task.thread_id],
&total_lamports[task.thread_id],
task.thread_id == 0,
);
}

Expand Down
88 changes: 62 additions & 26 deletions src/utils/thread.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,77 @@ const Mutex = std.Thread.Mutex;
const ThreadPool = @import("../sync/thread_pool.zig").ThreadPool;
const Batch = ThreadPool.Batch;

/// Describes the slice of work assigned to one spawned thread: the half-open
/// range `[start_index, end_index)` over the data, plus the thread's id in
/// spawning order. Passed as the final argument to every task function run
/// by `spawnThreadTasks`.
pub const TaskParams = struct {
    /// Inclusive start of this thread's chunk of the data.
    start_index: usize,
    /// Exclusive end of this thread's chunk of the data.
    end_index: usize,
    /// Index of this thread in spawning order; 0 identifies the first thread.
    thread_id: usize,
};

/// Computes how to split `data_len` items of work across at most
/// `max_n_threads` threads. Returns `.{ chunk_size, n_threads }`: the number
/// of items per thread and the number of threads to actually use. When there
/// are fewer items than threads, a single thread takes the whole range.
/// `max_n_threads` must be nonzero.
fn chunkSizeAndThreadCount(data_len: usize, max_n_threads: usize) struct { usize, usize } {
    const per_thread = data_len / max_n_threads;
    // Not enough work to give every thread at least one item: run one thread
    // over everything instead.
    if (per_thread == 0) return .{ data_len, 1 };
    return .{ per_thread, max_n_threads };
}

/// Returns the config-struct type accepted by `spawnThreadTasks` for the
/// given task function type. The struct bundles the wait group, the
/// work-splitting parameters, and the leading arguments forwarded to every
/// spawned task.
pub fn SpawnThreadTasksConfig(comptime TaskFn: type) type {
    return struct {
        /// Wait group tracking the spawned threads; `spawnThreadTasks` calls
        /// `start` before each spawn and the task wrapper calls `finish` when
        /// the task returns, so callers can `wg.wait()` for completion.
        wg: *std.Thread.WaitGroup,
        /// Total number of items to split across the threads.
        data_len: usize,
        /// Upper bound on the number of threads to spawn; fewer are used when
        /// `data_len` is smaller than this.
        max_threads: usize,
        /// If non-null, set to the coverage over the data which was achieved.
        /// On a successful call, this will be equal to `data_len`.
        /// On a failed call, this will be less than `data_len`,
        /// representing the length of the data which was successfully
        /// handed off to spawned threads before the failing spawn.
        coverage: ?*usize = null,
        /// Leading arguments passed to the task function; the trailing
        /// `TaskParams` argument is supplied by `spawnThreadTasks` itself.
        params: Params,

        /// Argument tuple of `TaskFn` with its final parameter (the
        /// `TaskParams`) stripped off.
        pub const Params = std.meta.ArgsTuple(@Type(.{ .Fn = blk: {
            var info = @typeInfo(TaskFn).Fn;
            info.params = info.params[0 .. info.params.len - 1];
            break :blk info;
        } }));
    };
}

/// Splits `config.data_len` items of work into contiguous chunks and runs
/// `taskFn` on up to `config.max_threads` detached threads, one chunk each.
/// Each thread is registered on `config.wg` before it is spawned and
/// deregisters itself when its task returns; callers typically pair a call
/// to this function with `defer wg.wait()`.
///
/// `taskFn` is invoked as `taskFn(config.params ++ .{task_params})`, where
/// `task_params` is the `TaskParams` describing the thread's half-open
/// `[start_index, end_index)` chunk and its `thread_id`.
///
/// On spawn failure the error is returned, and `config.coverage` (if set)
/// records how much of the data was handed to successfully spawned threads.
pub fn spawnThreadTasks(
    comptime taskFn: anytype,
    config: SpawnThreadTasksConfig(@TypeOf(taskFn)),
) std.Thread.SpawnError!void {
    const Config = SpawnThreadTasksConfig(@TypeOf(taskFn));
    const chunk_size, const n_threads = chunkSizeAndThreadCount(config.data_len, config.max_threads);

    if (config.coverage) |coverage| coverage.* = 0;

    const S = struct {
        fn taskFnWg(wg: *std.Thread.WaitGroup, fn_params: Config.Params, task_params: TaskParams) @typeInfo(@TypeOf(taskFn)).Fn.return_type.? {
            // Decrement the wait group even if the task returns an error.
            defer wg.finish();
            return @call(.auto, taskFn, fn_params ++ .{task_params});
        }
    };

    var start_index: usize = 0;
    for (0..n_threads) |thread_id| {
        // The final thread absorbs any remainder left by the integer division
        // in `chunkSizeAndThreadCount`.
        const end_index = if (thread_id == n_threads - 1) config.data_len else (start_index + chunk_size);
        const task_params: TaskParams = .{
            .start_index = start_index,
            .end_index = end_index,
            .thread_id = thread_id,
        };

        // Register on the wait group *before* spawning, so the spawned
        // thread's `finish` cannot race ahead of our `start`.
        config.wg.start();
        const handle = std.Thread.spawn(.{}, S.taskFnWg, .{ config.wg, config.params, task_params }) catch |err| {
            // BUGFIX: undo the `start` for the thread that never launched;
            // otherwise the caller's `wg.wait()` would block forever on a
            // `finish` that no thread will ever deliver.
            config.wg.finish();
            if (config.coverage) |coverage| coverage.* = start_index;
            return err;
        };
        handle.detach();
        start_index = end_index;
    }

    if (config.coverage) |coverage| coverage.* = config.data_len;
}

pub fn ThreadPoolTask(comptime Entry: type) type {
Expand Down

0 comments on commit 534e83d

Please sign in to comment.