diff --git a/src/LlamaChatSession.ts b/src/LlamaChatSession.ts index 382ca7c6..64ac7ed6 100644 --- a/src/LlamaChatSession.ts +++ b/src/LlamaChatSession.ts @@ -5,6 +5,7 @@ import {ChatPromptWrapper} from "./ChatPromptWrapper.js"; import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js"; import {AbortError} from "./AbortError.js"; +const UNKNOWN_UNICODE_CHAR = "�"; export class LlamaChatSession { private readonly _model: LlamaModel; @@ -65,46 +66,26 @@ export class LlamaChatSession { } private async _evalTokens(tokens: Uint32Array, onToken?: (tokens: number[]) => void, {signal}: { signal?: AbortSignal } = {}) { + const decodeTokens = (tokens: number[]) => this._model.decode(Uint32Array.from(tokens)); + const stopStrings = this._promptWrapper.getStopStrings(); const stopStringIndexes = Array(stopStrings.length).fill(0); const skippedChunksQueue: number[] = []; const res: number[] = []; - let skipNextTokensEmoji = 0; - const decodeRes = () => this._model.decode(Uint32Array.from(res)); for await (const chunk of this._model.evaluate(tokens)) { if (signal?.aborted) throw new AbortError(); - const tokenStr = this._model.decode(Uint32Array.from([chunk])); - let skipTokenEvent = false; - - for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) { - const stopString = stopStrings[stopStringIndex]; - - let localShouldSkipTokenEvent = false; - for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) { - if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) { - stopStringIndexes[stopStringIndex]++; - localShouldSkipTokenEvent = true; - } else { - stopStringIndexes[stopStringIndex] = 0; - localShouldSkipTokenEvent = false; - break; - } - } - - if (stopStringIndexes[stopStringIndex] === stopString.length) { - return decodeRes(); - } - - skipTokenEvent ||= localShouldSkipTokenEvent; - } + const tokenStr = decodeTokens([chunk]); + const {shouldReturn, skipTokenEvent} = this._checkStopString(tokenStr, stopStringIndexes); + + if (shouldReturn) + return decodeTokens(res); - skipNextTokensEmoji += LlamaChatSession._calculateEmojiNextLength(chunk); - if (skipTokenEvent || skipNextTokensEmoji > 0) { - skipNextTokensEmoji--; + // if the token is unknown, it means it's not complete character + if (tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent) { skippedChunksQueue.push(chunk); continue; } @@ -119,20 +100,35 @@ export class LlamaChatSession { onToken?.([chunk]); } - return decodeRes(); + return decodeTokens(res); } - private static _calculateEmojiNextLength(firstByte: number) { - const byteText = firstByte.toString(2); + private _checkStopString(tokenStr: string, stopStringIndexes: number[]){ + const stopStrings = this._promptWrapper.getStopStrings(); + let skipTokenEvent = false; + + for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) { + const stopString = stopStrings[stopStringIndex]; + + let localShouldSkipTokenEvent = false; + for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) { + if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) { + stopStringIndexes[stopStringIndex]++; + localShouldSkipTokenEvent = true; + } else { + stopStringIndexes[stopStringIndex] = 0; + localShouldSkipTokenEvent = false; + break; + } + } + + if (stopStringIndexes[stopStringIndex] === stopString.length) { + return {shouldReturn: true}; + } - if (byteText.startsWith("11110")) { - return 3; - } else if (byteText.startsWith("1110")) { - return 2; - } else if (byteText.startsWith("110")) { - return 1; + skipTokenEvent ||= localShouldSkipTokenEvent; } - return 0; + return {skipTokenEvent}; } }