feature: Allow customizing the inference's context length
MohamedBassem committed Oct 12, 2024
1 parent c16173e commit 1b09682
Showing 9 changed files with 51 additions and 36 deletions.
1 change: 1 addition & 0 deletions apps/web/components/dashboard/settings/AISettings.tsx
@@ -291,6 +291,7 @@ export function PromptDemo() {
.filter((p) => p.appliesTo == "text" || p.appliesTo == "all")
.map((p) => p.text),
"\n<CONTENT_HERE>\n",
/* context length */ 1024 /* The value here doesn't matter */,
).trim()}
</code>
<p>Image Prompt</p>
3 changes: 3 additions & 0 deletions apps/workers/inference.ts
@@ -104,6 +104,9 @@ class OllamaInferenceClient implements InferenceClient {
format: "json",
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
num_ctx: serverConfig.inference.contextLength,
},
messages: [
{ role: "user", content: prompt, images: image ? [image] : undefined },
],
16 changes: 7 additions & 9 deletions apps/workers/openaiWorker.ts
@@ -23,7 +23,7 @@ import {

import type { InferenceClient } from "./inference";
import { InferenceClientFactory } from "./inference";
import { readPDFText, truncateContent } from "./utils";
import { readPDFText } from "./utils";

const openAIResponseSchema = z.object({
tags: z.array(z.string()),
@@ -102,27 +102,24 @@ async function buildPrompt(
);
}

let content = bookmark.link.content;
if (content) {
content = truncateContent(content);
}
const content = bookmark.link.content;
return buildTextPrompt(
serverConfig.inference.inferredTagLang,
prompts,
`URL: ${bookmark.link.url}
Title: ${bookmark.link.title ?? ""}
Description: ${bookmark.link.description ?? ""}
Content: ${content ?? ""}`,
serverConfig.inference.contextLength,
);
}

if (bookmark.text) {
const content = truncateContent(bookmark.text.text ?? "");
// TODO: Ensure that the content doesn't exceed the context length of openai
return buildTextPrompt(
serverConfig.inference.inferredTagLang,
prompts,
content,
bookmark.text.text ?? "",
serverConfig.inference.contextLength,
);
}

@@ -215,7 +212,8 @@ async function inferTagsFromPDF(
const prompt = buildTextPrompt(
serverConfig.inference.inferredTagLang,
await fetchCustomPrompts(bookmark.userId, "text"),
`Content: ${truncateContent(pdfParse.text)}`,
`Content: ${pdfParse.text}`,
serverConfig.inference.contextLength,
);
return inferenceClient.inferFromText(prompt);
}
2 changes: 1 addition & 1 deletion apps/workers/package.json
@@ -26,7 +26,7 @@
"metascraper-title": "^5.43.4",
"metascraper-twitter": "^5.43.4",
"metascraper-url": "^5.43.4",
"ollama": "^0.5.0",
"ollama": "^0.5.9",
"openai": "^4.67.1",
"pdf2json": "^3.0.5",
"pdfjs-dist": "^4.0.379",
9 changes: 0 additions & 9 deletions apps/workers/utils.ts
@@ -36,12 +36,3 @@ export async function readPDFText(buffer: Buffer): Promise<{
pdfParser.parseBuffer(buffer);
});
}

export function truncateContent(content: string, length = 1500) {
let words = content.split(" ");
if (words.length > length) {
words = words.slice(0, length);
content = words.join(" ");
}
return content;
}
21 changes: 11 additions & 10 deletions docs/docs/03-configuration.md
@@ -48,16 +48,17 @@ Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic tagging
- Running local models is a recent addition and not as battle tested as using OpenAI, so proceed with care (and potentially expect a bunch of inference failures).
:::

| Name | Required | Default | Description |
| ------------------------- | -------- | ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |
| Name | Required | Default | Description |
| ------------------------- | -------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
| INFERENCE_CONTEXT_LENGTH  | No       | 2048        | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on OpenAI and resource-wise on ollama). Check the model you're using for its max supported context size. |
| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |

## Crawler Configs

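For context, the new variable documented in the table above composes with the existing Ollama settings. A minimal illustrative `.env` sketch — the URL, the model name, and the 4096 value are placeholders chosen for the example, not values from this commit:

```
OLLAMA_BASE_URL=http://localhost:11434
INFERENCE_TEXT_MODEL=llama3.1
# Cap the prompt at 4096 tokens; when unset, the 2048 default from packages/shared/config.ts applies.
INFERENCE_CONTEXT_LENGTH=4096
```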
2 changes: 2 additions & 0 deletions packages/shared/config.ts
@@ -24,6 +24,7 @@ const allEnv = z.object({
INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
CRAWLER_HEADLESS_BROWSER: stringBool("true"),
BROWSER_WEB_URL: z.string().url().optional(),
BROWSER_WEBSOCKET_URL: z.string().url().optional(),
@@ -74,6 +75,7 @@ const serverConfigSchema = allEnv.transform((val) => {
textModel: val.INFERENCE_TEXT_MODEL,
imageModel: val.INFERENCE_IMAGE_MODEL,
inferredTagLang: val.INFERENCE_LANG,
contextLength: val.INFERENCE_CONTEXT_LENGTH,
},
crawler: {
numWorkers: val.CRAWLER_NUM_WORKERS,
23 changes: 21 additions & 2 deletions packages/shared/prompts.ts
@@ -1,3 +1,17 @@
// TODO: Use a proper tokenizer
function calculateNumTokens(text: string) {
return text.split(" ").length;
}

function truncateContent(content: string, length: number) {
let words = content.split(" ");
if (words.length > length) {
words = words.slice(0, length);
content = words.join(" ");
}
return content;
}

export function buildImagePrompt(lang: string, customPrompts: string[]) {
return `
You are a bot in a read-it-later app and your responsibility is to help with automatic tagging.
@@ -15,8 +29,9 @@ export function buildTextPrompt(
lang: string,
customPrompts: string[],
content: string,
contextLength: number,
) {
return `
const constructPrompt = (c: string) => `
You are a bot in a read-it-later app and your responsibility is to help with automatic tagging.
Please analyze the text between the sentences "CONTENT START HERE" and "CONTENT END HERE" and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
@@ -27,7 +42,11 @@ Please analyze the text between the sentences "CONTENT START HERE" and "CONTENT
- If there are no good tags, leave the array empty.
${customPrompts && customPrompts.map((p) => `- ${p}`).join("\n")}
CONTENT START HERE
${content}
${c}
CONTENT END HERE
You must respond in JSON with the key "tags" and the value is an array of string tags.`;

const promptSize = calculateNumTokens(constructPrompt(""));
const truncatedContent = truncateContent(content, contextLength - promptSize);
return constructPrompt(truncatedContent);
}
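The `calculateNumTokens` helper added above counts whitespace-separated words as a stand-in for real tokens, as its TODO acknowledges. A minimal sketch of what a model-aware count and truncation could look like, assuming a tokenizer library such as js-tiktoken were added as a dependency (it is not part of this commit) and that the `cl100k_base` encoding is an acceptable approximation for the model in use:

```ts
// Sketch only: js-tiktoken and the cl100k_base choice are assumptions, not part of this change.
import { getEncoding } from "js-tiktoken";

const encoding = getEncoding("cl100k_base");

// Count real model tokens instead of whitespace-separated words.
function calculateNumTokens(text: string): number {
  return encoding.encode(text).length;
}

// Truncate content to a token budget by encoding, slicing, and decoding back to text.
function truncateContent(content: string, maxTokens: number): string {
  const tokens = encoding.encode(content);
  return tokens.length <= maxTokens
    ? content
    : encoding.decode(tokens.slice(0, maxTokens));
}
```

Either way, the important property from this commit is preserved: the prompt scaffolding is measured first, and only the remaining budget (`contextLength - promptSize`) is spent on the bookmark content.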
10 changes: 5 additions & 5 deletions pnpm-lock.yaml
