From aa183c234207c9438c28ab8c11ac1d866c8571e5 Mon Sep 17 00:00:00 2001 From: Clemens Akens Date: Thu, 19 Oct 2023 17:54:30 +0200 Subject: [PATCH] Improve chat CLI --- src/chat.zig | 8 ++++---- src/chat_args.zig | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/chat.zig b/src/chat.zig index a8327e2..8b75a81 100644 --- a/src/chat.zig +++ b/src/chat.zig @@ -11,8 +11,8 @@ allocator: std.mem.Allocator, transformer: Transformer, tokenizer: Tokenizer, sampler: Sampler, -user_prompt: []const u8, system_prompt: []const u8, +user_prompt: []const u8, pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self { const transformer = try Transformer.init(allocator, args.model_path, args.n_steps); @@ -33,8 +33,8 @@ pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self { .transformer = transformer, .tokenizer = tokenizer, .sampler = sampler, - .user_prompt = args.prompt, .system_prompt = args.system_prompt, + .user_prompt = args.user_prompt, }; } @@ -44,10 +44,10 @@ pub fn deinit(self: *const Self) void { self.sampler.deinit(); } -const user_prompt_template_start = "[INST] "; -const user_prompt_template_close = " [/INST]"; const system_prompt_template_start = "<>\n"; const system_prompt_template_close = "\n<>\n\n"; +const user_prompt_template_start = "[INST] "; +const user_prompt_template_close = " [/INST]"; const bos_token = 1; // beginning of sequence const eos_token = 2; // end of sequence diff --git a/src/chat_args.zig b/src/chat_args.zig index 9a7a0d9..0f05cc2 100644 --- a/src/chat_args.zig +++ b/src/chat_args.zig @@ -8,10 +8,10 @@ temperature: f32, top_p: f32, random_seed: u64, n_steps: usize, -prompt: []const u8, system_prompt: []const u8, +user_prompt: []const u8, -const Option = enum { temperature, top_p, random_seed, n_steps, prompt, system_prompt }; +const Option = enum { temperature, top_p, random_seed, n_steps, system_prompt, user_prompt }; pub fn init(allocator: std.mem.Allocator) !Self { var arg_iterator = try std.process.argsWithAllocator(allocator); @@ -27,8 +27,8 @@ pub fn init(allocator: std.mem.Allocator) !Self { var top_p: ?f32 = null; var random_seed: ?u64 = null; var n_steps: ?usize = null; - var prompt: ?[]const u8 = null; var system_prompt: ?[]const u8 = null; + var user_prompt: ?[]const u8 = null; while (arg_iterator.next()) |arg| { if (current_option) |option| { @@ -40,10 +40,10 @@ pub fn init(allocator: std.mem.Allocator) !Self { random_seed = try std.fmt.parseInt(u64, arg, 10); } else if (option == .n_steps and n_steps == null) { n_steps = try std.fmt.parseInt(usize, arg, 10); - } else if (option == .prompt and prompt == null) { - prompt = arg; } else if (option == .system_prompt and system_prompt == null) { system_prompt = arg; + } else if (option == .user_prompt and user_prompt == null) { + user_prompt = arg; } else { try help(1); } @@ -57,10 +57,10 @@ pub fn init(allocator: std.mem.Allocator) !Self { current_option = .random_seed; } else if (std.mem.eql(u8, arg, "--n_steps")) { current_option = .n_steps; - } else if (std.mem.eql(u8, arg, "--prompt")) { - current_option = .prompt; } else if (std.mem.eql(u8, arg, "--system_prompt")) { current_option = .system_prompt; + } else if (std.mem.eql(u8, arg, "--user_prompt")) { + current_option = .user_prompt; } else { try help(if (std.mem.eql(u8, arg, "--help")) 0 else 1); } @@ -77,8 +77,8 @@ pub fn init(allocator: std.mem.Allocator) !Self { .top_p = @max(@min(top_p orelse 0.9, 1), 0), .random_seed = random_seed orelse @intCast(std.time.milliTimestamp()), .n_steps = n_steps orelse 0, - .prompt = prompt orelse "", .system_prompt = system_prompt orelse "", + .user_prompt = user_prompt orelse "", }; } @@ -99,8 +99,8 @@ fn help(exit_status: u8) !noreturn { try console.print(" --top_p = 0.9\n", .{}); try console.print(" --random_seed = \n", .{}); try console.print(" --n_steps = \n", .{}); - try console.print(" --prompt = \"\"\n", .{}); try console.print(" --system_prompt = \"\"\n", .{}); + try console.print(" --user_prompt = \"\"\n", .{}); try console.print(" --help\n", .{}); std.process.exit(exit_status);