feat: add support for some llama.cpp params on LlamaModel (#5)
giladgd authored Aug 17, 2023
1 parent adfce4c commit c76ec48
Showing 7 changed files with 163 additions and 43 deletions.
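In short, this commit lets LlamaModel forward most of llama.cpp's llama_context_params through its constructor. A minimal sketch of the new surface, assuming the package's public exports and a placeholder model path:

import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

const model = new LlamaModel({
    modelPath: "./models/vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin", // placeholder path
    seed: 1234,        // omit (or pass null) to let llama.cpp pick a random seed
    contextSize: 4096, // text context size
    batchSize: 512,    // prompt processing batch size
    gpuCores: 0        // number of layers to store in VRAM
});
const context = new LlamaContext({model});
const session = new LlamaChatSession({context});

Options that are left out are stripped before reaching the native binding, so the defaults set in the addon apply.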
8 changes: 5 additions & 3 deletions README.md
@@ -30,7 +30,8 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url));
const model = new LlamaModel({
modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
});
const session = new LlamaChatSession({context: model.createContext()});
const context = new LlamaContext({model});
const session = new LlamaChatSession({context});


const q1 = "Hi there, how are you?";
@@ -73,7 +74,8 @@ const model = new LlamaModel({
modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin"),
promptWrapper: new MyCustomChatPromptWrapper() // by default, LlamaChatPromptWrapper is used
})
const session = new LlamaChatSession({context: model.createContext()});
const context = new LlamaContext({model});
const session = new LlamaChatSession({context});


const q1 = "Hi there, how are you?";
@@ -102,7 +104,7 @@ const model = new LlamaModel({
modelPath: path.join(__dirname, "models", "vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin")
});

const context = model.createContext();
const context = new LlamaContext({model});

const q1 = "Hi there, how are you?";
console.log("AI: " + q1);
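The chat-session examples above are truncated by the diff view; as in the README, they continue by prompting the session. A hedged sketch of that continuation (method name as exposed by LlamaChatSession):

const q1 = "Hi there, how are you?";
console.log("User: " + q1);

const a1 = await session.prompt(q1);
console.log("AI: " + a1);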
87 changes: 73 additions & 14 deletions llama/addon.cpp
@@ -8,21 +8,80 @@

class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
public:
llama_context_params params;
llama_model* model;
LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
params = llama_context_default_params();
params.seed = -1;
params.n_ctx = 4096;
model = llama_load_model_from_file(info[0].As<Napi::String>().Utf8Value().c_str(), params);

if (model == NULL) {
Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
return;
llama_context_params params;
llama_model* model;

LLAMAModel(const Napi::CallbackInfo& info) : Napi::ObjectWrap<LLAMAModel>(info) {
params = llama_context_default_params();
params.seed = -1;
params.n_ctx = 4096;

// Get the model path
std::string modelPath = info[0].As<Napi::String>().Utf8Value();

if (info.Length() > 1 && info[1].IsObject()) {
Napi::Object options = info[1].As<Napi::Object>();

if (options.Has("seed")) {
params.seed = options.Get("seed").As<Napi::Number>().Int32Value();
}

if (options.Has("contextSize")) {
params.n_ctx = options.Get("contextSize").As<Napi::Number>().Int32Value();
}

if (options.Has("batchSize")) {
params.n_batch = options.Get("batchSize").As<Napi::Number>().Int32Value();
}

if (options.Has("gpuCores")) {
params.n_gpu_layers = options.Get("gpuCores").As<Napi::Number>().Int32Value();
}

if (options.Has("lowVram")) {
params.low_vram = options.Get("lowVram").As<Napi::Boolean>().Value();
}

if (options.Has("f16Kv")) {
params.f16_kv = options.Get("f16Kv").As<Napi::Boolean>().Value();
}

if (options.Has("logitsAll")) {
params.logits_all = options.Get("logitsAll").As<Napi::Boolean>().Value();
}

if (options.Has("vocabOnly")) {
params.vocab_only = options.Get("vocabOnly").As<Napi::Boolean>().Value();
}

if (options.Has("useMmap")) {
params.use_mmap = options.Get("useMmap").As<Napi::Boolean>().Value();
}

if (options.Has("useMlock")) {
params.use_mlock = options.Get("useMlock").As<Napi::Boolean>().Value();
}

if (options.Has("embedding")) {
params.embedding = options.Get("embedding").As<Napi::Boolean>().Value();
}
}

model = llama_load_model_from_file(modelPath.c_str(), params);

if (model == NULL) {
Napi::Error::New(info.Env(), "Failed to load model").ThrowAsJavaScriptException();
return;
}
}

~LLAMAModel() {
llama_free_model(model);
}

static void init(Napi::Object exports) {
exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {}));
}
}
~LLAMAModel() { llama_free_model(model); }
static void init(Napi::Object exports) { exports.Set("LLAMAModel", DefineClass(exports.Env(), "LLAMAModel", {})); }
};

class LLAMAContext : public Napi::ObjectWrap<LLAMAContext> {
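On the JavaScript side, the options object parsed above maps directly onto llama_context_params fields. A sketch of calling the native binding by hand (normally LlamaModel does this for you; the model path is a placeholder):

// LLAMAModel here is the class exported by the compiled addon (loaded via LlamaBins.js)
const nativeModel = new LLAMAModel("./models/model.bin", {
    seed: -1,          // params.seed
    contextSize: 4096, // params.n_ctx
    batchSize: 512,    // params.n_batch
    gpuCores: 0,       // params.n_gpu_layers
    lowVram: false,    // params.low_vram
    f16Kv: true,       // params.f16_kv
    useMmap: true,     // params.use_mmap
    useMlock: false,   // params.use_mlock
    embedding: false   // params.embedding
});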
22 changes: 16 additions & 6 deletions src/cli/commands/ChatCommand.ts
@@ -11,7 +11,8 @@ type ChatCommand = {
model: string,
systemInfo: boolean,
systemPrompt: string,
wrapper: string
wrapper: string,
contextSize: number
};

export const ChatCommand: CommandModule<object, ChatCommand> = {
@@ -46,11 +47,17 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
choices: ["general", "llama"],
description: "Chat wrapper to use",
group: "Optional:"
})
.option("contextSize", {
type: "number",
default: 1024 * 4,
description: "Context size to use for the model",
group: "Optional:"
});
},
async handler({model, systemInfo, systemPrompt, wrapper}) {
async handler({model, systemInfo, systemPrompt, wrapper, contextSize}) {
try {
await RunChat({model, systemInfo, systemPrompt, wrapper});
await RunChat({model, systemInfo, systemPrompt, wrapper, contextSize});
} catch (err) {
console.error(err);
process.exit(1);
@@ -59,15 +66,18 @@
};


async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper, contextSize}: ChatCommand) {
const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");
const {LlamaContext} = await import("../../llamaEvaluator/LlamaContext.js");

const model = new LlamaModel({
modelPath: modelArg
modelPath: modelArg,
contextSize
});
const context = new LlamaContext({model});
const session = new LlamaChatSession({
context: model.createContext(),
context,
printLLamaSystemInfo: systemInfo,
systemPrompt,
promptWrapper: createChatWrapper(wrapper)
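This also makes the context size configurable from the CLI. An illustrative invocation, assuming the package's npx entry point (the flags come from the options defined above):

npx node-llama-cpp chat --model ./models/vicuna-13b-v1.5-16k.ggmlv3.q5_1.bin --contextSize 8192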
2 changes: 1 addition & 1 deletion src/llamaEvaluator/LlamaChatSession.ts
@@ -6,7 +6,7 @@ import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper
import {LlamaModel} from "./LlamaModel.js";
import {LlamaContext} from "./LlamaContext.js";

const UNKNOWN_UNICODE_CHAR = "";
const UNKNOWN_UNICODE_CHAR = "\ufffd";

export class LlamaChatSession {
private readonly _systemPrompt: string;
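"\ufffd" is U+FFFD REPLACEMENT CHARACTER, which UTF-8 decoding produces for byte sequences it cannot decode; spelling it as an escape avoids relying on the source file's encoding. A small illustration (not taken from the session code itself):

console.log("\ufffd"); // prints "�"
console.log(Buffer.from([0xff]).toString("utf8") === "\ufffd"); // true - 0xff is never valid UTF-8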
7 changes: 3 additions & 4 deletions src/llamaEvaluator/LlamaContext.ts
@@ -1,13 +1,12 @@
import {LLAMAContext, llamaCppNode} from "./LlamaBins.js";
import {LlamaModel} from "./LlamaModel.js";

type LlamaContextConstructorParameters = {prependBos: boolean, ctx: LLAMAContext};
export class LlamaContext {
private readonly _ctx: LLAMAContext;
private _prependBos: boolean;

/** @internal */
public constructor( {ctx, prependBos}: LlamaContextConstructorParameters ) {
this._ctx = ctx;
public constructor({model, prependBos = true}: {model: LlamaModel, prependBos?: boolean}) {
this._ctx = new LLAMAContext(model._model);
this._prependBos = prependBos;
}

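LlamaContext is now constructed from a model directly, and the prependBos flag moved here from LlamaModel. A one-line sketch (the override value is only for illustration):

const context = new LlamaContext({model, prependBos: false}); // prependBos defaults to true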
66 changes: 52 additions & 14 deletions src/llamaEvaluator/LlamaModel.ts
@@ -1,24 +1,62 @@
import {LlamaContext} from "./LlamaContext.js";
import {LLAMAContext, llamaCppNode, LLAMAModel} from "./LlamaBins.js";
import {llamaCppNode, LLAMAModel} from "./LlamaBins.js";


export class LlamaModel {
private readonly _model: LLAMAModel;
private readonly _prependBos: boolean;
/** @internal */
public readonly _model: LLAMAModel;

public constructor({modelPath, prependBos = true}: { modelPath: string, prependBos?: boolean }) {
this._model = new LLAMAModel(modelPath);
this._prependBos = prependBos;
}

public createContext() {
return new LlamaContext({
ctx: new LLAMAContext(this._model),
prependBos: this._prependBos
});
/**
* options source:
* https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/llama.h#L102 (struct llama_context_params)
* @param {object} options
* @param {string} options.modelPath - path to the model on the filesystem
* @param {number | null} [options.seed] - If null, a random seed will be used
* @param {number} [options.contextSize] - text context size
* @param {number} [options.batchSize] - prompt processing batch size
* @param {number} [options.gpuCores] - number of layers to store in VRAM
* @param {boolean} [options.lowVram] - if true, reduce VRAM usage at the cost of performance
* @param {boolean} [options.f16Kv] - use fp16 for KV cache
* @param {boolean} [options.logitsAll] - the llama_eval() call computes all logits, not just the last one
* @param {boolean} [options.vocabOnly] - only load the vocabulary, no weights
* @param {boolean} [options.useMmap] - use mmap if possible
* @param {boolean} [options.useMlock] - force system to keep model in RAM
* @param {boolean} [options.embedding] - embedding mode only
*/
public constructor({
modelPath, seed = null, contextSize = 1024 * 4, batchSize, gpuCores,
lowVram, f16Kv, logitsAll, vocabOnly, useMmap, useMlock, embedding
}: {
modelPath: string, seed?: number | null, contextSize?: number, batchSize?: number, gpuCores?: number,
lowVram?: boolean, f16Kv?: boolean, logitsAll?: boolean, vocabOnly?: boolean, useMmap?: boolean, useMlock?: boolean,
embedding?: boolean
}) {
this._model = new LLAMAModel(modelPath, removeNullFields({
seed: seed != null ? Math.max(-1, seed) : undefined,
contextSize,
batchSize,
gpuCores,
lowVram,
f16Kv,
logitsAll,
vocabOnly,
useMmap,
useMlock,
embedding
}));
}

public static get systemInfo() {
return llamaCppNode.systemInfo();
}
}

function removeNullFields<T extends object>(obj: T): T {
const newObj: T = Object.assign({}, obj);

for (const key in obj) {
if (newObj[key] == null)
delete newObj[key];
}

return newObj;
}
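removeNullFields strips null and undefined entries so that only explicitly provided options reach the native binding, leaving the addon's defaults in place for the rest. A tiny illustration (the helper is module-internal, so this call is hypothetical):

removeNullFields({seed: undefined, contextSize: 4096, useMlock: null});
// -> {contextSize: 4096}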
14 changes: 13 additions & 1 deletion src/utils/getBin.ts
@@ -95,7 +95,19 @@ export type LlamaCppNodeModule = {
};

export type LLAMAModel = {
new (modelPath: string): LLAMAModel
new (modelPath: string, params: {
seed?: number,
contextSize?: number,
batchSize?: number,
gpuCores?: number,
lowVram?: boolean,
f16Kv?: boolean,
logitsAll?: boolean,
vocabOnly?: boolean,
useMmap?: boolean,
useMlock?: boolean,
embedding?: boolean
}): LLAMAModel
};

export type LLAMAContext = {
