feat: load conversation history into a LlamaChatSession #51

Merged · 1 commit · Sep 23, 2023
28 changes: 28 additions & 0 deletions README.md
@@ -104,6 +104,34 @@ const a2 = await session.prompt(q2);
console.log("AI: " + a2);
```

##### Load existing conversation history
```typescript
import {fileURLToPath} from "url";
import path from "path";
import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

const model = new LlamaModel({
    modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
});
const context = new LlamaContext({model});
const session = new LlamaChatSession({
    context,
    conversationHistory: [{
        prompt: `Remember the number 6 as "The number"`,
        response: "OK. I'll remember it"
    }]
});

const q2 = 'What is "The number"?';
console.log("User: " + q2);

const a2 = await session.prompt(q2);
console.log("AI: " + a2);
```
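
Because `conversationHistory` is just an array of `{prompt, response}` pairs, you can persist a conversation and resume it in a later process. A minimal sketch, assuming you collect the pairs yourself as you prompt (the `history.json` file name and the saving code are illustrative, not part of the library):
```typescript
import fs from "fs/promises";

// Illustrative: save the pairs you collected while chatting
const history = [{
    prompt: `Remember the number 6 as "The number"`,
    response: "OK. I'll remember it"
}];
await fs.writeFile("history.json", JSON.stringify(history), "utf8");

// Later, in a new process, load them into a fresh context and session
const loadedHistory = JSON.parse(await fs.readFile("history.json", "utf8"));
const restoredSession = new LlamaChatSession({
    context: new LlamaContext({model}),
    conversationHistory: loadedHistory
});
```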

#### Raw
```typescript
import {fileURLToPath} from "url";
9 changes: 9 additions & 0 deletions src/ChatPromptWrapper.ts
@@ -14,4 +14,13 @@ export abstract class ChatPromptWrapper {
    public getStopStrings(): string[] {
        return [];
    }

    public getDefaultStopString(): string {
        const stopString = this.getStopStrings()[0];

        if (stopString == null || stopString.length === 0)
            throw new Error(`Prompt wrapper "${this.wrapperName}" has no stop strings`);

        return stopString;
    }
}
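
A note on the new base-class method: `getDefaultStopString()` simply returns the first entry of `getStopStrings()` and throws when there is none, so a custom wrapper only needs a non-empty stop-string list to participate in history loading. A hypothetical subclass (the class name and `<END>` token are invented for illustration; depending on the other abstract members of `ChatPromptWrapper`, a real subclass may need to override more):
```typescript
class MyCustomPromptWrapper extends ChatPromptWrapper {
    public readonly wrapperName: string = "MyCustom";

    public override getStopStrings(): string[] {
        return ["<END>"]; // the first entry becomes the default stop string
    }
}

// new MyCustomPromptWrapper().getDefaultStopString() evaluates to "<END>"
```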
4 changes: 4 additions & 0 deletions src/chatWrappers/ChatMLPromptWrapper.ts
@@ -21,4 +21,8 @@ export class ChatMLPromptWrapper extends ChatPromptWrapper {
    public override getStopStrings(): string[] {
        return ["<|im_end|>"];
    }

    public override getDefaultStopString(): string {
        return "<|im_end|>";
    }
}
6 changes: 5 additions & 1 deletion src/chatWrappers/GeneralChatPromptWrapper.ts
@@ -32,11 +32,15 @@ export class GeneralChatPromptWrapper extends ChatPromptWrapper {
        ];
    }

+    public override getDefaultStopString(): string {
+        return `\n\n### ${this._instructionName}`;
+    }

    private _getPromptPrefix(lastStopString: string | null, lastStopStringSuffix: string | null) {
        return getTextCompletion(
            lastStopString === "<end>"
                ? lastStopStringSuffix
-                : (lastStopString + (lastStopStringSuffix ?? "")),
+                : ((lastStopString ?? "") + (lastStopStringSuffix ?? "")),
            [
                `\n\n### ${this._instructionName}:\n\n`,
                `### ${this._instructionName}:\n\n`
4 changes: 4 additions & 0 deletions src/chatWrappers/LlamaChatPromptWrapper.ts
@@ -21,4 +21,8 @@ export class LlamaChatPromptWrapper extends ChatPromptWrapper {
    public override getStopStrings(): string[] {
        return ["</s>"];
    }

    public override getDefaultStopString(): string {
        return "</s>";
    }
}
71 changes: 71 additions & 0 deletions src/chatWrappers/generateContextTextFromConversationHistory.ts
@@ -0,0 +1,71 @@
import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
import {defaultChatSystemPrompt} from "../config.js";
import {ConversationInteraction} from "../types.js";


/**
 * Generate context text to load into a model context from a conversation history.
 * @param {ChatPromptWrapper} chatPromptWrapper
 * @param {ConversationInteraction[]} conversationHistory
 * @param {object} [options]
 * @param {string} [options.systemPrompt]
 * @param {number} [options.currentPromptIndex]
 * @param {string | null} [options.lastStopString]
 * @param {string | null} [options.lastStopStringSuffix]
 * @returns {{text: string, stopString: (string | null), stopStringSuffix: (string | null)}}
 */
export function generateContextTextFromConversationHistory(
    chatPromptWrapper: ChatPromptWrapper,
    conversationHistory: readonly ConversationInteraction[],
    {
        systemPrompt = defaultChatSystemPrompt, currentPromptIndex = 0, lastStopString = null, lastStopStringSuffix = null
    }: {
        systemPrompt?: string, currentPromptIndex?: number, lastStopString?: string | null, lastStopStringSuffix?: string | null
    } = {}
): {
    text: string;
    stopString: string | null;
    stopStringSuffix: string | null;
} {
    let res = "";

    for (let i = 0; i < conversationHistory.length; i++) {
        const interaction = conversationHistory[i];
        const wrappedPrompt = chatPromptWrapper.wrapPrompt(interaction.prompt, {
            systemPrompt,
            promptIndex: currentPromptIndex,
            lastStopString,
            lastStopStringSuffix
        });
        const stopStrings = chatPromptWrapper.getStopStrings();
        const defaultStopString = chatPromptWrapper.getDefaultStopString();
        const stopStringsToCheckInResponse = new Set([...stopStrings, defaultStopString]);

        currentPromptIndex++;
        lastStopString = null;
        lastStopStringSuffix = null;

        res += wrappedPrompt;

        for (const stopString of stopStringsToCheckInResponse) {
            if (interaction.response.includes(stopString)) {
                console.error(
                    `Stop string "${stopString}" was found in model response of conversation interaction index ${i}`,
                    {interaction, stopString}
                );
                throw new Error("A stop string cannot be in a conversation history interaction model response");
            }
        }

        res += interaction.response;
        res += defaultStopString;
        lastStopString = defaultStopString;
        lastStopStringSuffix = "";
    }

    return {
        text: res,
        stopString: lastStopString,
        stopStringSuffix: lastStopStringSuffix
    };
}
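
For a sense of what this function produces, here is a hypothetical direct call (it is an internal module, so the import paths are illustrative): each history entry is appended as its wrapped prompt, followed by the verbatim response and the wrapper's default stop string, and the returned values seed the wrapper state for the next real prompt.
```typescript
import {GeneralChatPromptWrapper} from "./GeneralChatPromptWrapper.js";
import {generateContextTextFromConversationHistory} from "./generateContextTextFromConversationHistory.js";

const {text, stopString, stopStringSuffix} = generateContextTextFromConversationHistory(
    new GeneralChatPromptWrapper(),
    [{prompt: "Hi", response: "Hello! How can I help you today?"}]
);

// `text` is the fully wrapped history, ready to evaluate as context;
// `stopString` and `stopStringSuffix` carry the wrapper state forward.
console.log({text, stopString, stopStringSuffix});
```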
3 changes: 2 additions & 1 deletion src/index.ts
@@ -10,7 +10,7 @@ import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
import {ChatMLPromptWrapper} from "./chatWrappers/ChatMLPromptWrapper.js";
import {getChatWrapperByBos} from "./chatWrappers/createChatWrapperByBos.js";

-import {type Token} from "./types.js";
+import {type ConversationInteraction, type Token} from "./types.js";


export {
@@ -22,6 +22,7 @@
    type LlamaContextOptions,
    LlamaChatSession,
    type LlamaChatSessionOptions,
+    type ConversationInteraction,
    AbortError,
    ChatPromptWrapper,
    EmptyChatPromptWrapper,
42 changes: 38 additions & 4 deletions src/llamaEvaluator/LlamaChatSession.ts
@@ -4,7 +4,8 @@ import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
import {AbortError} from "../AbortError.js";
import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper.js";
import {getChatWrapperByBos} from "../chatWrappers/createChatWrapperByBos.js";
-import {Token} from "../types.js";
+import {ConversationInteraction, Token} from "../types.js";
+import {generateContextTextFromConversationHistory} from "../chatWrappers/generateContextTextFromConversationHistory.js";
import {LlamaModel} from "./LlamaModel.js";
import {LlamaContext} from "./LlamaContext.js";

@@ -15,7 +16,10 @@ export type LlamaChatSessionOptions = {
    context: LlamaContext,
    printLLamaSystemInfo?: boolean,
    promptWrapper?: ChatPromptWrapper | "auto",
-    systemPrompt?: string
+    systemPrompt?: string,
+
+    /** Conversation history to load into the context to continue an existing conversation */
+    conversationHistory?: readonly ConversationInteraction[]
};

export class LlamaChatSession {
@@ -26,17 +30,22 @@
    private _initialized: boolean = false;
    private _lastStopString: string | null = null;
    private _lastStopStringSuffix: string | null = null;
+    private _conversationHistoryToLoad: readonly ConversationInteraction[] | null = null;
    private readonly _ctx: LlamaContext;

    public constructor({
        context,
        printLLamaSystemInfo = false,
        promptWrapper = new GeneralChatPromptWrapper(),
-        systemPrompt = defaultChatSystemPrompt
+        systemPrompt = defaultChatSystemPrompt,
+        conversationHistory
    }: LlamaChatSessionOptions) {
        this._ctx = context;
        this._printLLamaSystemInfo = printLLamaSystemInfo;
        this._systemPrompt = systemPrompt;
+        this._conversationHistoryToLoad = (conversationHistory != null && conversationHistory.length > 0)
+            ? conversationHistory
+            : null;

        if (promptWrapper === "auto") {
            const chatWrapper = getChatWrapperByBos(context.getBosString());
@@ -76,7 +85,32 @@
        await this.init();

        return await withLock(this, "prompt", async () => {
-            const promptText = this._promptWrapper.wrapPrompt(prompt, {
+            let promptText = "";
+
+            if (this._promptIndex == 0 && this._conversationHistoryToLoad != null) {
+                const {text, stopString, stopStringSuffix} =
+                    generateContextTextFromConversationHistory(this._promptWrapper, this._conversationHistoryToLoad, {
+                        systemPrompt: this._systemPrompt,
+                        currentPromptIndex: this._promptIndex,
+                        lastStopString: this._lastStopString,
+                        lastStopStringSuffix: this._promptIndex == 0
+                            ? (
+                                this._ctx.prependBos
+                                    ? this._ctx.getBosString()
+                                    : null
+                            )
+                            : this._lastStopStringSuffix
+                    });
+
+                promptText += text;
+                this._lastStopString = stopString;
+                this._lastStopStringSuffix = stopStringSuffix;
+                this._promptIndex += this._conversationHistoryToLoad.length;
+
+                this._conversationHistoryToLoad = null;
+            }
+
+            promptText += this._promptWrapper.wrapPrompt(prompt, {
                systemPrompt: this._systemPrompt,
                promptIndex: this._promptIndex,
                lastStopString: this._lastStopString,
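
Worth noting in this hunk: the history is loaded lazily. The constructor only stores the interactions; the first `prompt()` call generates the history text via `generateContextTextFromConversationHistory()`, prepends it to the wrapped prompt, and advances `_promptIndex` past the loaded interactions. A sketch of the observable behavior (assuming `context` is a fresh `LlamaContext`):
```typescript
const session = new LlamaChatSession({
    context,
    conversationHistory: [
        {prompt: "My name is Alice", response: "Nice to meet you, Alice!"}
    ]
});

// No model evaluation has happened yet; the history is turned into context
// text and evaluated together with this first prompt:
const answer = await session.prompt("What is my name?");
```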
18 changes: 11 additions & 7 deletions src/llamaEvaluator/LlamaContext.ts
@@ -13,13 +13,19 @@ export type LlamaContextOptions = {

export class LlamaContext {
    private readonly _ctx: LLAMAContext;
-    private _prependBos: boolean;
+    private readonly _prependBos: boolean;
+    private _prependTokens: Token[];

    public constructor({model, grammar, prependBos = true}: LlamaContextOptions) {
        this._ctx = new LLAMAContext(model._model, removeNullFields({
            grammar: grammar?._grammar
        }));
        this._prependBos = prependBos;
+        this._prependTokens = [];
+
+        if (prependBos) {
+            this._prependTokens.unshift(this._ctx.tokenBos());
+        }
    }

    public encode(text: string): Uint32Array {
@@ -115,19 +121,18 @@
        return this._ctx.getTokenString(nlToken);
    }

-    public getContextSize() {
+    public getContextSize(): number {
        return this._ctx.getContextSize();
    }

    public async *evaluate(tokens: Uint32Array): AsyncGenerator<Token, void> {
        let evalTokens = tokens;

-        if (this._prependBos) {
-            const tokenArray: Token[] = Array.from(tokens);
-            tokenArray.unshift(this._ctx.tokenBos());
+        if (this._prependTokens.length > 0) {
+            const tokenArray: Token[] = this._prependTokens.concat(Array.from(tokens));

            evalTokens = Uint32Array.from(tokenArray);
-            this._prependBos = false;
+            this._prependTokens = [];
        }

        // eslint-disable-next-line no-constant-condition
@@ -145,5 +150,4 @@
            evalTokens = Uint32Array.from([nextToken]);
        }
    }
-
}
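
The one-shot BOS prepend is now a general `_prependTokens` queue that is drained on the first `evaluate()` call, and `prependBos` still defaults to `true`. Opting out remains a constructor option (a sketch; assumes a `model` as in the README examples):
```typescript
// A context that does not prepend the BOS token, e.g. for raw completion:
const rawContext = new LlamaContext({model, prependBos: false});
```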
5 changes: 5 additions & 0 deletions src/types.ts
@@ -1 +1,6 @@
export type Token = number;

export type ConversationInteraction = {
    prompt: string,
    response: string
};
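
Since `ConversationInteraction` is now exported from the package index, histories can be type-checked in user code (the history content here is illustrative):
```typescript
import {type ConversationInteraction} from "node-llama-cpp";

const history: ConversationInteraction[] = [
    {prompt: "What's the capital of France?", response: "Paris."},
    {prompt: "Thanks!", response: "You're welcome!"}
];
```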