diff --git a/README.md b/README.md
index b6d7a8b1..2bc2cdba 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,8 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url));
 const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
 });
-const session = new LlamaChatSession({context: model.createContext()});
+const context = new LlamaContext({model});
+const session = new LlamaChatSession({context});
 
 const q1 = "Hi there, how are you?";
 
@@ -73,7 +74,8 @@ const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin"),
     promptWrapper: new MyCustomChatPromptWrapper() // by default, LlamaChatPromptWrapper is used
 })
-const session = new LlamaChatSession({context: model.createContext()});
+const context = new LlamaContext({model});
+const session = new LlamaChatSession({context});
 
 const q1 = "Hi there, how are you?";
 
@@ -102,7 +104,7 @@ const model = new LlamaModel({
     modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
 });
 
-const context = model.createContext();
+const context = new LlamaContext({model});
 
 const q1 = "Hi there, how are you?";
 console.log("AI: " + q1);
diff --git a/llama/addon.cpp b/llama/addon.cpp
index 72358e0c..b6f5fb78 100644
--- a/llama/addon.cpp
+++ b/llama/addon.cpp
@@ -8,21 +8,80 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
   public:
-  llama_context_params params;
-  llama_model* model;
-  LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
-    params = llama_context_default_params();
-    params.seed = -1;
-    params.n_ctx = 4096;
-    model = llama_load_model_from_file(info[0].As<Napi::String>().Utf8Value().c_str(), params);
-
-    if (model == NULL) {
-      Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
-      return;
+    llama_context_params params;
+    llama_model* model;
+
+    LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
+        params = llama_context_default_params();
+        params.seed = -1;
+        params.n_ctx = 4096;
+
+        // Get the model path
+        std::string modelPath = info[0].As<Napi::String>().Utf8Value();
+
+        if (info.Length() > 1 && info[1].IsObject()) {
+            Napi::Object options = info[1].As<Napi::Object>();
+
+            if (options.Has("seed")) {
+                params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("contextSize")) {
+                params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("batchSize")) {
+                params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("gpuCores")) {
+                params.n_gpu_layers = options.Get("gpuCores").As<Napi::Number>().Int32Value();
+            }
+
+            if (options.Has("lowVram")) {
+                params.low_vram = options.Get("lowVram").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("f16Kv")) {
+                params.f16_kv = options.Get("f16Kv").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("logitsAll")) {
+                params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("vocabOnly")) {
+                params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("useMmap")) {
+                params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("useMlock")) {
+                params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
+            }
+
+            if (options.Has("embedding")) {
+                params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
+            }
+        }
+
+        model = llama_load_model_from_file(modelPath.c_str(), params);
+
+        if (model == NULL) {
+            Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
+            return;
+        }
+    }
+
+    ~LLAMAModel() {
+        llama_free_model(model);
+    }
+
+    static void init(Napi::Object exports) {
+        exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {}));
     }
-  }
-  ~LLAMAModel() { llama_free_model(model); }
-  static void init(Napi::Object exports) { exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {})); }
 };
 
 class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
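Note on the parsing above: the camelCase option keys (seed, contextSize, batchSize, gpuCores, lowVram, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding) map one-to-one onto the corresponding llama_context_params fields (seed, n_ctx, n_batch, n_gpu_layers, low_vram, f16_kv, logits_all, vocab_only, use_mmap, use_mlock, embedding). A minimal sketch of what the JS side now passes to the native constructor (values are illustrative; the public LlamaModel class further below is the supported entry point):

    // Sketch only: direct use of the native LLAMAModel binding, as typed in src/utils/getBin.ts below.
    const nativeModel = new LLAMAModel("models/vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin", {
        seed: -1,          // -1 keeps llama.cpp's "random seed" behavior
        contextSize: 4096, // -> params.n_ctx
        gpuCores: 0        // -> params.n_gpu_layers
    });
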
diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts
index 973b4211..198ef6a6 100644
--- a/src/cli/commands/ChatCommand.ts
+++ b/src/cli/commands/ChatCommand.ts
@@ -11,7 +11,8 @@ type ChatCommand = {
     model: string,
     systemInfo: boolean,
     systemPrompt: string,
-    wrapper: string
+    wrapper: string,
+    contextSize: number
 };
 
 export const ChatCommand: CommandModule<object, ChatCommand> = {
@@ -46,11 +47,17 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
                 choices: ["general", "llama"],
                 description: "Chat wrapper to use",
                 group: "Optional:"
+            })
+            .option("contextSize", {
+                type: "number",
+                default: 1024 * 4,
+                description: "Context size to use for the model",
+                group: "Optional:"
             });
     },
-    async handler({model, systemInfo, systemPrompt, wrapper}) {
+    async handler({model, systemInfo, systemPrompt, wrapper, contextSize}) {
         try {
-            await RunChat({model, systemInfo, systemPrompt, wrapper});
+            await RunChat({model, systemInfo, systemPrompt, wrapper, contextSize});
         } catch (err) {
             console.error(err);
             process.exit(1);
@@ -59,15 +66,18 @@
 };
 
 
-async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
+async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper, contextSize}: ChatCommand) {
     const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
     const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");
+    const {LlamaContext} = await import("../../llamaEvaluator/LlamaContext.js");
 
     const model = new LlamaModel({
-        modelPath: modelArg
+        modelPath: modelArg,
+        contextSize
     });
+    const context = new LlamaContext({model});
     const session = new LlamaChatSession({
-        context: model.createContext(),
+        context,
         printLLamaSystemInfo: systemInfo,
         systemPrompt,
         promptWrapper: createChatWrapper(wrapper)
diff --git a/src/llamaEvaluator/LlamaChatSession.ts b/src/llamaEvaluator/LlamaChatSession.ts
index a1a6cbd1..b3be58a6 100644
--- a/src/llamaEvaluator/LlamaChatSession.ts
+++ b/src/llamaEvaluator/LlamaChatSession.ts
@@ -6,7 +6,7 @@ import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper
 import {LlamaModel} from "./LlamaModel.js";
 import {LlamaContext} from "./LlamaContext.js";
 
-const UNKNOWN_UNICODE_CHAR = "�";
+const UNKNOWN_UNICODE_CHAR = "\ufffd";
 
 export class LlamaChatSession {
     private readonly _systemPrompt: string;
diff --git a/src/llamaEvaluator/LlamaContext.ts b/src/llamaEvaluator/LlamaContext.ts
index 15606f83..c3c46837 100644
--- a/src/llamaEvaluator/LlamaContext.ts
+++ b/src/llamaEvaluator/LlamaContext.ts
@@ -1,13 +1,12 @@
 import {LLAMAContext, llamaCppNode} from "./LlamaBins.js";
+import {LlamaModel} from "./LlamaModel.js";
 
-type LlamaContextConstructorParameters = {prependBos: boolean, ctx: LLAMAContext};
 export class LlamaContext {
     private readonly _ctx: LLAMAContext;
     private _prependBos: boolean;
 
-    /** @internal */
-    public constructor( {ctx, prependBos}: LlamaContextConstructorParameters ) {
-        this._ctx = ctx;
+    public constructor({model, prependBos = true}: {model: LlamaModel, prependBos?: boolean}) {
+        this._ctx = new LLAMAContext(model._model);
         this._prependBos = prependBos;
     }
 
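Since the LlamaContext constructor is now public (and LlamaModel.createContext() is removed in the LlamaModel.ts hunk below), existing callers migrate roughly as in this sketch, which uses only classes shown in this diff:

    // Before (removed in this diff):
    // const session = new LlamaChatSession({context: model.createContext()});

    // After:
    const context = new LlamaContext({model});       // prependBos still defaults to true
    const session = new LlamaChatSession({context});
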
diff --git a/src/llamaEvaluator/LlamaModel.ts b/src/llamaEvaluator/LlamaModel.ts
index 545b3eae..5685a34a 100644
--- a/src/llamaEvaluator/LlamaModel.ts
+++ b/src/llamaEvaluator/LlamaModel.ts
@@ -1,24 +1,62 @@
-import {LlamaContext} from "./LlamaContext.js";
-import {LLAMAContext, llamaCppNode, LLAMAModel} from "./LlamaBins.js";
+import {llamaCppNode, LLAMAModel} from "./LlamaBins.js";
 
 export class LlamaModel {
-    private readonly _model: LLAMAModel;
-    private readonly _prependBos: boolean;
+    /** @internal */
+    public readonly _model: LLAMAModel;
 
-    public constructor({modelPath, prependBos = true}: { modelPath: string, prependBos?: boolean }) {
-        this._model = new LLAMAModel(modelPath);
-        this._prependBos = prependBos;
-    }
-
-    public createContext() {
-        return new LlamaContext({
-            ctx: new LLAMAContext(this._model),
-            prependBos: this._prependBos
-        });
+    /**
+     * options source:
+     * https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/llama.h#L102 (struct llama_context_params)
+     * @param {object} options
+     * @param {string} options.modelPath - path to the model on the filesystem
+     * @param {number | null} [options.seed] - If null, a random seed will be used
+     * @param {number} [options.contextSize] - text context size
+     * @param {number} [options.batchSize] - prompt processing batch size
+     * @param {number} [options.gpuCores] - number of layers to store in VRAM
+     * @param {boolean} [options.lowVram] - if true, reduce VRAM usage at the cost of performance
+     * @param {boolean} [options.f16Kv] - use fp16 for KV cache
+     * @param {boolean} [options.logitsAll] - the llama_eval() call computes all logits, not just the last one
+     * @param {boolean} [options.vocabOnly] - only load the vocabulary, no weights
+     * @param {boolean} [options.useMmap] - use mmap if possible
+     * @param {boolean} [options.useMlock] - force system to keep model in RAM
+     * @param {boolean} [options.embedding] - embedding mode only
+     */
+    public constructor({
+        modelPath, seed = null, contextSize = 1024 * 4, batchSize, gpuCores,
+        lowVram, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding
+    }: {
+        modelPath: string, seed?: number | null, contextSize?: number, batchSize?: number, gpuCores?: number,
+        lowVram?: boolean, f16Kv?: boolean, logitsAll?: boolean, vocabOnly?: boolean, useMmap?: boolean, useMlock?: boolean,
+        embedding?: boolean
+    }) {
+        this._model = new LLAMAModel(modelPath, removeNullFields({
+            seed: seed != null ? Math.max(-1, seed) : undefined,
+            contextSize,
+            batchSize,
+            gpuCores,
+            lowVram,
+            f16Kv,
+            logitsAll,
+            vocabOnly,
+            useMmap,
+            useMlock,
+            embedding
+        }));
     }
 
     public static get systemInfo() {
         return llamaCppNode.systemInfo();
     }
 }
+
+function removeNullFields<T extends object>(obj: T): T {
+    const newObj: T = Object.assign({}, obj);
+
+    for (const key in obj) {
+        if (newObj[key] == null)
+            delete newObj[key];
+    }
+
+    return newObj;
+}
diff --git a/src/utils/getBin.ts b/src/utils/getBin.ts
index 1371626c..94cf4982 100644
--- a/src/utils/getBin.ts
+++ b/src/utils/getBin.ts
@@ -95,7 +95,19 @@ export type LlamaCppNodeModule = {
 };
 
 export type LLAMAModel = {
-    new (modelPath: string): LLAMAModel
+    new (modelPath: string, params: {
+        seed?: number,
+        contextSize?: number,
+        batchSize?: number,
+        gpuCores?: number,
+        lowVram?: boolean,
+        f16Kv?: boolean,
+        logitsAll?: boolean,
+        vocabOnly?: boolean,
+        useMmap?: boolean,
+        useMlock?: boolean,
+        embedding?: boolean
+    }): LLAMAModel
 };
 
 export type LLAMAContext = {
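Taken together, a usage sketch of the new model options and context creation (the import specifier, the session.prompt() call, and the option values are assumptions based on the README examples above, not part of this diff):

    import path from "path";
    import {fileURLToPath} from "url";
    import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp"; // assumed package entry point

    const __dirname = path.dirname(fileURLToPath(import.meta.url));

    const model = new LlamaModel({
        modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin"),
        contextSize: 1024 * 4, // matches the constructor's default shown above
        gpuCores: 0,           // number of layers to store in VRAM
        seed: null             // null -> random seed
    });
    const context = new LlamaContext({model});
    const session = new LlamaChatSession({context});

    const answer = await session.prompt("Hi there, how are you?"); // assumed API, mirroring the README chat flow
    console.log("AI: " + answer);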