Skip to content

Commit

Permalink
Finish refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Aug 24, 2023
1 parent 094ec54 commit dd9cc90
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 31 deletions.
7 changes: 0 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@ Some deviations from the original include:
- For models of 4096+ dimensions, thread pools are utilized to parallelize independent matrix
multiplications

## Refactoring TODOs

- sampler
- printer?
- generator?
- main

## Papers

- Standard transformer architecture: [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
Expand Down
1 change: 1 addition & 0 deletions src/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub const dot = @import("lib/dot.zig").dot;
pub const matmul = @import("lib/matmul.zig").matmul;
pub const matmul2 = @import("lib/matmul.zig").matmul2;
pub const matmul3 = @import("lib/matmul.zig").matmul3;
pub const print = @import("lib/print.zig").print;
pub const random = @import("lib/random.zig").random;
pub const rmsnorm = @import("lib/rmsnorm.zig").rmsnorm;
pub const rope = @import("lib/rope.zig").rope;
Expand Down
20 changes: 20 additions & 0 deletions src/lib/print.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
const std = @import("std");

pub fn print(word: []const u8) !void {
const stdout = std.io.getStdOut().writer();

// https://github.com/karpathy/llama2.c/blob/c7a26264a233c32f396b1c67be4ac019d2d8a659/run.c#L427
if (word.len == 6 and std.mem.eql(u8, word[0..3], "<0x") and word[5] == '>') {
const byte: ?u8 = std.fmt.parseInt(u8, word[3..5], 16) catch null;

if (byte) |char| {
if (std.ascii.isPrint(char) or std.ascii.isWhitespace(char)) {
try stdout.print("{s}", .{[_]u8{char}});
}
} else {
try stdout.print("{s}", .{word});
}
} else {
try stdout.print("{s}", .{word});
}
}
38 changes: 14 additions & 24 deletions src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ pub fn main() !void {

defer arena.deinit();

const allocator = arena.allocator();
const stdout = std.io.getStdOut().writer();
try generate(arena.allocator());
}

fn generate(allocator: std.mem.Allocator) !void {
var cli: Cli = undefined;

try cli.init(allocator);
Expand All @@ -32,11 +33,9 @@ pub fn main() !void {
cli.n_steps = checkpoint.seq_len;
}

const vocab_size = checkpoint.vocab_size;

var tokenizer: Tokenizer = undefined;

try tokenizer.init(allocator, cli.tokenizer_path, vocab_size);
try tokenizer.init(allocator, cli.tokenizer_path, checkpoint.vocab_size);
defer tokenizer.deinit();

var prompt_tokens = try tokenizer.encode(allocator, cli.input_prompt, true, false);
Expand All @@ -48,20 +47,22 @@ pub fn main() !void {
try transformer.init(allocator, &checkpoint);
defer transformer.deinit();

std.debug.assert(prompt_tokens.len > 0);

var current_token: usize = prompt_tokens[0];

prompt_tokens = prompt_tokens[1..];

var next_token: usize = 0; // TODO: null
var rng_state = cli.random_seed;

var probability_index_pairs_buffer: []lib.ProbabilityIndexPair =
try allocator.alloc(lib.ProbabilityIndexPair, vocab_size);
try allocator.alloc(lib.ProbabilityIndexPair, checkpoint.vocab_size);

var step: usize = 0;
defer allocator.free(probability_index_pairs_buffer);

var start_time: i64 = 0;
var total_time: i64 = 0;
var next_token: usize = 1;
var rng_state = cli.random_seed;
var step: usize = 0;

for (0..@min(cli.n_steps, checkpoint.seq_len)) |pos| {
if (pos > 0) {
Expand Down Expand Up @@ -107,24 +108,13 @@ pub fn main() !void {

const word = tokenizer.decode(current_token, next_token);

// https://github.com/karpathy/llama2.c/blob/c7a26264a233c32f396b1c67be4ac019d2d8a659/run.c#L427
if (word.len == 6 and std.mem.eql(u8, word[0..3], "<0x") and word[5] == '>') {
const byte: ?u8 = std.fmt.parseInt(u8, word[3..5], 16) catch null;

if (byte) |char| {
if (std.ascii.isPrint(char) or std.ascii.isWhitespace(char)) {
try stdout.print("{s}", .{[_]u8{char}});
}
} else {
try stdout.print("{s}", .{word});
}
} else {
try stdout.print("{s}", .{word});
}
try lib.print(word);

current_token = next_token;
}

const stdout = std.io.getStdOut().writer();

if (total_time > 0 and !cli.test_mode) {
const average_time = @as(f32, @floatFromInt(total_time)) / @as(f32, @floatFromInt(step));

Expand Down

0 comments on commit dd9cc90

Please sign in to comment.