From dd9cc90092979ceb3f569d91c84d6463870caec8 Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Thu, 24 Aug 2023 11:50:06 +0200 Subject: [PATCH] Finish refactoring --- README.md | 7 ------- src/lib.zig | 1 + src/lib/print.zig | 20 ++++++++++++++++++++ src/main.zig | 38 ++++++++++++++------------------------ 4 files changed, 35 insertions(+), 31 deletions(-) create mode 100644 src/lib/print.zig diff --git a/README.md b/README.md index 09bc575..4195718 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,6 @@ Some deviations from the original include: - For models of 4096+ dimensions, thread pools are utilized to parallelize independent matrix multiplications -## Refactoring TODOs - -- sampler -- printer? -- generator? -- main - ## Papers - Standard transformer architecture: [Attention Is All You Need](https://arxiv.org/abs/1706.03762) diff --git a/src/lib.zig b/src/lib.zig index 84d8a95..6541760 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -4,6 +4,7 @@ pub const dot = @import("lib/dot.zig").dot; pub const matmul = @import("lib/matmul.zig").matmul; pub const matmul2 = @import("lib/matmul.zig").matmul2; pub const matmul3 = @import("lib/matmul.zig").matmul3; +pub const print = @import("lib/print.zig").print; pub const random = @import("lib/random.zig").random; pub const rmsnorm = @import("lib/rmsnorm.zig").rmsnorm; pub const rope = @import("lib/rope.zig").rope; diff --git a/src/lib/print.zig b/src/lib/print.zig new file mode 100644 index 0000000..ecd0e84 --- /dev/null +++ b/src/lib/print.zig @@ -0,0 +1,20 @@ +const std = @import("std"); + +pub fn print(word: []const u8) !void { + const stdout = std.io.getStdOut().writer(); + + // https://github.com/karpathy/llama2.c/blob/c7a26264a233c32f396b1c67be4ac019d2d8a659/run.c#L427 + if (word.len == 6 and std.mem.eql(u8, word[0..3], "<0x") and word[5] == '>') { + const byte: ?u8 = std.fmt.parseInt(u8, word[3..5], 16) catch null; + + if (byte) |char| { + if (std.ascii.isPrint(char) or std.ascii.isWhitespace(char)) { + try stdout.print("{s}", .{[_]u8{char}}); + } + } else { + try stdout.print("{s}", .{word}); + } + } else { + try stdout.print("{s}", .{word}); + } +} diff --git a/src/main.zig b/src/main.zig index 2a9062c..c5f7a7b 100644 --- a/src/main.zig +++ b/src/main.zig @@ -10,9 +10,10 @@ pub fn main() !void { defer arena.deinit(); - const allocator = arena.allocator(); - const stdout = std.io.getStdOut().writer(); + try generate(arena.allocator()); +} +fn generate(allocator: std.mem.Allocator) !void { var cli: Cli = undefined; try cli.init(allocator); @@ -32,11 +33,9 @@ pub fn main() !void { cli.n_steps = checkpoint.seq_len; } - const vocab_size = checkpoint.vocab_size; - var tokenizer: Tokenizer = undefined; - try tokenizer.init(allocator, cli.tokenizer_path, vocab_size); + try tokenizer.init(allocator, cli.tokenizer_path, checkpoint.vocab_size); defer tokenizer.deinit(); var prompt_tokens = try tokenizer.encode(allocator, cli.input_prompt, true, false); @@ -48,20 +47,22 @@ pub fn main() !void { try transformer.init(allocator, &checkpoint); defer transformer.deinit(); + std.debug.assert(prompt_tokens.len > 0); + var current_token: usize = prompt_tokens[0]; prompt_tokens = prompt_tokens[1..]; - var next_token: usize = 0; // TODO: null - var rng_state = cli.random_seed; - var probability_index_pairs_buffer: []lib.ProbabilityIndexPair = - try allocator.alloc(lib.ProbabilityIndexPair, vocab_size); + try allocator.alloc(lib.ProbabilityIndexPair, checkpoint.vocab_size); - var step: usize = 0; + defer allocator.free(probability_index_pairs_buffer); var start_time: i64 = 0; var total_time: i64 = 0; + var next_token: usize = 1; + var rng_state = cli.random_seed; + var step: usize = 0; for (0..@min(cli.n_steps, checkpoint.seq_len)) |pos| { if (pos > 0) { @@ -107,24 +108,13 @@ pub fn main() !void { const word = tokenizer.decode(current_token, next_token); - // https://github.com/karpathy/llama2.c/blob/c7a26264a233c32f396b1c67be4ac019d2d8a659/run.c#L427 - if (word.len == 6 and std.mem.eql(u8, word[0..3], "<0x") and word[5] == '>') { - const byte: ?u8 = std.fmt.parseInt(u8, word[3..5], 16) catch null; - - if (byte) |char| { - if (std.ascii.isPrint(char) or std.ascii.isWhitespace(char)) { - try stdout.print("{s}", .{[_]u8{char}}); - } - } else { - try stdout.print("{s}", .{word}); - } - } else { - try stdout.print("{s}", .{word}); - } + try lib.print(word); current_token = next_token; } + const stdout = std.io.getStdOut().writer(); + if (total_time > 0 and !cli.test_mode) { const average_time = @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(step));