Skip to content

Commit

Permalink
support 16k model
Browse files Browse the repository at this point in the history
  • Loading branch information
joyqi committed Jun 30, 2023
1 parent b49b17b commit f7e5e22
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 61 deletions.
110 changes: 53 additions & 57 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ const argv = yargs(hideBin(process.argv))
type as any,
format as any,
key,
model as any,
model,
prompt,
_[0] as string,
output,
Expand Down
18 changes: 15 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { LLMChain } from 'langchain/chains';
import { existsSync, readFileSync, writeFileSync } from 'fs';
import { StructureType, detectStructureType, getStructure } from './structure';
import { FormatterType, detectFormatterType, getFormatter } from './formatter';
import { TiktokenModel, encoding_for_model } from '@dqbd/tiktoken';
import { Tiktoken, TiktokenModel, encoding_for_model } from '@dqbd/tiktoken';
import ora from 'ora';
import languageEncoding from 'detect-file-encoding-and-language';
import detectIndent from 'detect-indent';
Expand Down Expand Up @@ -91,11 +91,23 @@ async function detectFile(file: string): Promise<[BufferEncoding, string]> {
return [enc, lang];
}

function getEncModel(modelName: string): Tiktoken {
if (modelName.match(/^gpt\-3\.5\-turbo\-[0-9]{4}$/)) {
modelName = 'gpt-3.5-turbo-0301';
} else if (modelName.match(/^gpt\-3\.5\-turbo\-16k(\-[0-9]{4})?$/)) {
modelName = 'gpt-3.5-turbo';
} else if (modelName.match(/^gpt\-4(\-[0-9]{4})?$/)) {
modelName = 'gpt-4';
}

return encoding_for_model(modelName as TiktokenModel);
}

export async function translate<T extends StructureType, F extends FormatterType>(
type: T['type'] | 'auto',
format: F['type'] | 'auto',
openAIApiKey: string,
model: TiktokenModel,
model: string,
prompt: string | null,
srcFile: string,
dstFile: string,
Expand All @@ -110,7 +122,7 @@ export async function translate<T extends StructureType, F extends FormatterType
});

const chain = buildChain(chat, 'You are a helpful assistant that translates json formatted data from {input_language} to {output_language}. {prompt}');
const enc = encoding_for_model(model);
const enc = getEncModel(model);

if (format === 'auto') {
format = detectFormatterType(srcFile);
Expand Down

0 comments on commit f7e5e22

Please sign in to comment.