Skip to content

Commit

Permalink
Improve chat CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
clebert committed Oct 19, 2023
1 parent ee33631 commit aa183c2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/chat.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ allocator: std.mem.Allocator,
transformer: Transformer,
tokenizer: Tokenizer,
sampler: Sampler,
user_prompt: []const u8,
system_prompt: []const u8,
user_prompt: []const u8,

pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
const transformer = try Transformer.init(allocator, args.model_path, args.n_steps);
Expand All @@ -33,8 +33,8 @@ pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
.transformer = transformer,
.tokenizer = tokenizer,
.sampler = sampler,
.user_prompt = args.prompt,
.system_prompt = args.system_prompt,
.user_prompt = args.user_prompt,
};
}

Expand All @@ -44,10 +44,10 @@ pub fn deinit(self: *const Self) void {
self.sampler.deinit();
}

const user_prompt_template_start = "[INST] ";
const user_prompt_template_close = " [/INST]";
const system_prompt_template_start = "<<SYS>>\n";
const system_prompt_template_close = "\n<</SYS>>\n\n";
const user_prompt_template_start = "[INST] ";
const user_prompt_template_close = " [/INST]";

const bos_token = 1; // beginning of sequence
const eos_token = 2; // end of sequence
Expand Down
18 changes: 9 additions & 9 deletions src/chat_args.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ temperature: f32,
top_p: f32,
random_seed: u64,
n_steps: usize,
prompt: []const u8,
system_prompt: []const u8,
user_prompt: []const u8,

const Option = enum { temperature, top_p, random_seed, n_steps, prompt, system_prompt };
const Option = enum { temperature, top_p, random_seed, n_steps, system_prompt, user_prompt };

pub fn init(allocator: std.mem.Allocator) !Self {
var arg_iterator = try std.process.argsWithAllocator(allocator);
Expand All @@ -27,8 +27,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var top_p: ?f32 = null;
var random_seed: ?u64 = null;
var n_steps: ?usize = null;
var prompt: ?[]const u8 = null;
var system_prompt: ?[]const u8 = null;
var user_prompt: ?[]const u8 = null;

while (arg_iterator.next()) |arg| {
if (current_option) |option| {
Expand All @@ -40,10 +40,10 @@ pub fn init(allocator: std.mem.Allocator) !Self {
random_seed = try std.fmt.parseInt(u64, arg, 10);
} else if (option == .n_steps and n_steps == null) {
n_steps = try std.fmt.parseInt(usize, arg, 10);
} else if (option == .prompt and prompt == null) {
prompt = arg;
} else if (option == .system_prompt and system_prompt == null) {
system_prompt = arg;
} else if (option == .user_prompt and user_prompt == null) {
user_prompt = arg;
} else {
try help(1);
}
Expand All @@ -57,10 +57,10 @@ pub fn init(allocator: std.mem.Allocator) !Self {
current_option = .random_seed;
} else if (std.mem.eql(u8, arg, "--n_steps")) {
current_option = .n_steps;
} else if (std.mem.eql(u8, arg, "--prompt")) {
current_option = .prompt;
} else if (std.mem.eql(u8, arg, "--system_prompt")) {
current_option = .system_prompt;
} else if (std.mem.eql(u8, arg, "--user_prompt")) {
current_option = .user_prompt;
} else {
try help(if (std.mem.eql(u8, arg, "--help")) 0 else 1);
}
Expand All @@ -77,8 +77,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
.top_p = @max(@min(top_p orelse 0.9, 1), 0),
.random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
.n_steps = n_steps orelse 0,
.prompt = prompt orelse "",
.system_prompt = system_prompt orelse "",
.user_prompt = user_prompt orelse "",
};
}

Expand All @@ -99,8 +99,8 @@ fn help(exit_status: u8) !noreturn {
try console.print(" --top_p <float> = 0.9\n", .{});
try console.print(" --random_seed <int> = <milli_timestamp>\n", .{});
try console.print(" --n_steps <int> = <max_sequence_length>\n", .{});
try console.print(" --prompt <string> = \"\"\n", .{});
try console.print(" --system_prompt <string> = \"\"\n", .{});
try console.print(" --user_prompt <string> = \"\"\n", .{});
try console.print(" --help\n", .{});

std.process.exit(exit_status);
Expand Down

0 comments on commit aa183c2

Please sign in to comment.