Skip to content

Commit

Permalink
fix(detect-emoji): only detect incomplete characters
Browse files Browse the repository at this point in the history
  • Loading branch information
ido-pluto committed Aug 16, 2023
1 parent 36a2e77 commit e626093
Showing 1 changed file with 35 additions and 39 deletions.
74 changes: 35 additions & 39 deletions src/LlamaChatSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js";
import {AbortError} from "./AbortError.js";

const UNKNOWN_UNICODE_CHAR = "�";

export class LlamaChatSession {
private readonly _model: LlamaModel;
Expand Down Expand Up @@ -65,46 +66,26 @@ export class LlamaChatSession {
}

private async _evalTokens(tokens: Uint32Array, onToken?: (tokens: number[]) => void, {signal}: { signal?: AbortSignal } = {}) {
const decodeTokens = (tokens: number[]) => this._model.decode(Uint32Array.from(tokens));

const stopStrings = this._promptWrapper.getStopStrings();
const stopStringIndexes = Array(stopStrings.length).fill(0);
const skippedChunksQueue: number[] = [];
const res: number[] = [];

let skipNextTokensEmoji = 0;
const decodeRes = () => this._model.decode(Uint32Array.from(res));

for await (const chunk of this._model.evaluate(tokens)) {
if (signal?.aborted)
throw new AbortError();

const tokenStr = this._model.decode(Uint32Array.from([chunk]));
let skipTokenEvent = false;

for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
const stopString = stopStrings[stopStringIndex];

let localShouldSkipTokenEvent = false;
for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
stopStringIndexes[stopStringIndex]++;
localShouldSkipTokenEvent = true;
} else {
stopStringIndexes[stopStringIndex] = 0;
localShouldSkipTokenEvent = false;
break;
}
}

if (stopStringIndexes[stopStringIndex] === stopString.length) {
return decodeRes();
}

skipTokenEvent ||= localShouldSkipTokenEvent;
}
const tokenStr = decodeTokens([chunk]);
const {shouldReturn, skipTokenEvent} = this._checkStopString(tokenStr, stopStringIndexes);

if (shouldReturn)
return decodeTokens(res);

skipNextTokensEmoji += LlamaChatSession._calculateEmojiNextLength(chunk);
if (skipTokenEvent || skipNextTokensEmoji > 0) {
skipNextTokensEmoji--;
// if the token is unknown, it means it's not complete character
if (tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent) {
skippedChunksQueue.push(chunk);
continue;
}
Expand All @@ -119,20 +100,35 @@ export class LlamaChatSession {
onToken?.([chunk]);
}

return decodeRes();
return decodeTokens(res);
}

private static _calculateEmojiNextLength(firstByte: number) {
const byteText = firstByte.toString(2);
private _checkStopString(tokenStr: string, stopStringIndexes: number[]){
const stopStrings = this._promptWrapper.getStopStrings();
let skipTokenEvent = false;

for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
const stopString = stopStrings[stopStringIndex];

let localShouldSkipTokenEvent = false;
for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
stopStringIndexes[stopStringIndex]++;
localShouldSkipTokenEvent = true;
} else {
stopStringIndexes[stopStringIndex] = 0;
localShouldSkipTokenEvent = false;
break;
}
}

if (stopStringIndexes[stopStringIndex] === stopString.length) {
return {shouldReturn: true};
}

if (byteText.startsWith("11110")) {
return 3;
} else if (byteText.startsWith("1110")) {
return 2;
} else if (byteText.startsWith("110")) {
return 1;
skipTokenEvent ||= localShouldSkipTokenEvent;
}

return 0;
return {skipTokenEvent};
}
}

0 comments on commit e626093

Please sign in to comment.