feat: add gpuLayers option to the chat command
giladgd committed Oct 9, 2023
1 parent 6b4b89f commit 0b4ae99
Showing 1 changed file with 12 additions and 4 deletions.
src/cli/commands/ChatCommand.ts: 12 additions & 4 deletions
@@ -28,6 +28,7 @@ type ChatCommand = {
     temperature: number,
     topK: number,
     topP: number,
+    gpuLayers?: number,
     repeatPenalty: number,
     lastTokensRepeatPenalty: number,
     penalizeRepeatingNewLine: boolean,
@@ -122,6 +123,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
             description: "Dynamically selects the smallest set of tokens whose cumulative probability exceeds the threshold P, and samples the next token only from this set. A float number between `0` and `1`. Set to `1` to disable. Only relevant when `temperature` is set to a value greater than `0`.",
             group: "Optional:"
         })
+        .option("gpuLayers", {
+            alias: "gl",
+            type: "number",
+            description: "number of layers to store in VRAM",
+            group: "Optional:"
+        })
         .option("repeatPenalty", {
             alias: "rp",
             type: "number",
@@ -165,12 +172,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
     },
     async handler({
         model, systemInfo, systemPrompt, prompt, wrapper, contextSize,
-        grammar, threads, temperature, topK, topP, repeatPenalty,
+        grammar, threads, temperature, topK, topP, gpuLayers, repeatPenalty,
         lastTokensRepeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens
     }) {
         try {
             await RunChat({
-                model, systemInfo, systemPrompt, prompt, wrapper, contextSize, grammar, threads, temperature, topK, topP,
+                model, systemInfo, systemPrompt, prompt, wrapper, contextSize, grammar, threads, temperature, topK, topP, gpuLayers,
                 lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens
             });
         } catch (err) {
@@ -183,7 +190,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
 
 async function RunChat({
     model: modelArg, systemInfo, systemPrompt, prompt, wrapper, contextSize, grammar: grammarArg, threads, temperature, topK, topP,
-    lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens
+    gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens
 }: ChatCommand) {
     const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
     const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");
@@ -192,7 +199,8 @@ async function RunChat({
 
     let initialPrompt = prompt ?? null;
     const model = new LlamaModel({
-        modelPath: path.resolve(process.cwd(), modelArg)
+        modelPath: path.resolve(process.cwd(), modelArg),
+        gpuLayers: gpuLayers != null ? gpuLayers : undefined
     });
     const context = new LlamaContext({
         model,
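In effect, the new --gpuLayers flag (alias -gl) is parsed by the chat command and forwarded through RunChat into the LlamaModel constructor; the `gpuLayers != null ? gpuLayers : undefined` guard lets the model fall back to its default layer count when the flag is omitted. Below is a minimal sketch of the equivalent programmatic usage, assuming LlamaModel, LlamaContext, and LlamaChatSession are exported from the package root as in this era of the library; the model path and layer count are placeholders.

import path from "path";
import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

// Placeholder model path; point this at a real model file on disk.
const modelPath = path.resolve(process.cwd(), "models/llama-2-7b.Q4_K_M.gguf");

// Offload 32 layers to VRAM; leave gpuLayers undefined to keep the library default,
// mirroring the `gpuLayers != null ? gpuLayers : undefined` guard in the diff above.
const model = new LlamaModel({modelPath, gpuLayers: 32});
const context = new LlamaContext({model});
const session = new LlamaChatSession({context});

const answer = await session.prompt("Hi there, how are you?");
console.log(answer);

From the command line, the same effect would come from something like `npx node-llama-cpp chat --model <path-to-model> --gpuLayers 32`, though the exact invocation depends on how the CLI is installed.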
