From 1b09682685f54f29957163be9b9f9fc2de3b49cc Mon Sep 17 00:00:00 2001 From: MohamedBassem Date: Sat, 12 Oct 2024 17:25:01 +0000 Subject: [PATCH] feature: Allow customizing the inference's context length --- .../dashboard/settings/AISettings.tsx | 1 + apps/workers/inference.ts | 3 +++ apps/workers/openaiWorker.ts | 16 ++++++------- apps/workers/package.json | 2 +- apps/workers/utils.ts | 9 -------- docs/docs/03-configuration.md | 21 +++++++++-------- packages/shared/config.ts | 2 ++ packages/shared/prompts.ts | 23 +++++++++++++++++-- pnpm-lock.yaml | 10 ++++---- 9 files changed, 51 insertions(+), 36 deletions(-) diff --git a/apps/web/components/dashboard/settings/AISettings.tsx b/apps/web/components/dashboard/settings/AISettings.tsx index 12f656ba..0a8db147 100644 --- a/apps/web/components/dashboard/settings/AISettings.tsx +++ b/apps/web/components/dashboard/settings/AISettings.tsx @@ -291,6 +291,7 @@ export function PromptDemo() { .filter((p) => p.appliesTo == "text" || p.appliesTo == "all") .map((p) => p.text), "\n\n", + /* context length */ 1024 /* The value here doesn't matter */, ).trim()}

Image Prompt

diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts index 071f4742..41ceffd6 100644 --- a/apps/workers/inference.ts +++ b/apps/workers/inference.ts @@ -104,6 +104,9 @@ class OllamaInferenceClient implements InferenceClient { format: "json", stream: true, keep_alive: serverConfig.inference.ollamaKeepAlive, + options: { + num_ctx: serverConfig.inference.contextLength, + }, messages: [ { role: "user", content: prompt, images: image ? [image] : undefined }, ], diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 6c6104f3..d51771b2 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -23,7 +23,7 @@ import { import type { InferenceClient } from "./inference"; import { InferenceClientFactory } from "./inference"; -import { readPDFText, truncateContent } from "./utils"; +import { readPDFText } from "./utils"; const openAIResponseSchema = z.object({ tags: z.array(z.string()), @@ -102,10 +102,7 @@ async function buildPrompt( ); } - let content = bookmark.link.content; - if (content) { - content = truncateContent(content); - } + const content = bookmark.link.content; return buildTextPrompt( serverConfig.inference.inferredTagLang, prompts, @@ -113,16 +110,16 @@ async function buildPrompt( Title: ${bookmark.link.title ?? ""} Description: ${bookmark.link.description ?? ""} Content: ${content ?? ""}`, + serverConfig.inference.contextLength, ); } if (bookmark.text) { - const content = truncateContent(bookmark.text.text ?? ""); - // TODO: Ensure that the content doesn't exceed the context length of openai return buildTextPrompt( serverConfig.inference.inferredTagLang, prompts, - content, + bookmark.text.text ?? "", + serverConfig.inference.contextLength, ); } @@ -215,7 +212,8 @@ async function inferTagsFromPDF( const prompt = buildTextPrompt( serverConfig.inference.inferredTagLang, await fetchCustomPrompts(bookmark.userId, "text"), - `Content: ${truncateContent(pdfParse.text)}`, + `Content: ${pdfParse.text}`, + serverConfig.inference.contextLength, ); return inferenceClient.inferFromText(prompt); } diff --git a/apps/workers/package.json b/apps/workers/package.json index 35217c96..b8077954 100644 --- a/apps/workers/package.json +++ b/apps/workers/package.json @@ -26,7 +26,7 @@ "metascraper-title": "^5.43.4", "metascraper-twitter": "^5.43.4", "metascraper-url": "^5.43.4", - "ollama": "^0.5.0", + "ollama": "^0.5.9", "openai": "^4.67.1", "pdf2json": "^3.0.5", "pdfjs-dist": "^4.0.379", diff --git a/apps/workers/utils.ts b/apps/workers/utils.ts index 2372684e..8d297e05 100644 --- a/apps/workers/utils.ts +++ b/apps/workers/utils.ts @@ -36,12 +36,3 @@ export async function readPDFText(buffer: Buffer): Promise<{ pdfParser.parseBuffer(buffer); }); } - -export function truncateContent(content: string, length = 1500) { - let words = content.split(" "); - if (words.length > length) { - words = words.slice(0, length); - content = words.join(" "); - } - return content; -} diff --git a/docs/docs/03-configuration.md b/docs/docs/03-configuration.md index 3d674d63..98fa7a1a 100644 --- a/docs/docs/03-configuration.md +++ b/docs/docs/03-configuration.md @@ -48,16 +48,17 @@ Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic taggin - Running local models is a recent addition and not as battle tested as using OpenAI, so proceed with care (and potentially expect a bunch of inference failures). ::: -| Name | Required | Default | Description | -| ------------------------- | -------- | ----------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). | -| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. | -| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. | -| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). | -| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. | -| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). | -| INFERENCE_LANG | No | english | The language in which the tags will be generated. | -| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. | +| Name | Required | Default | Description | +| ------------------------- | -------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). | +| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. | +| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. | +| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). | +| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. | +| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). | +| INFERENCE_CONTEXT_LENGTH | No | 2048 | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on openAI and resource-wise on ollama). Check the model you're using for its max supported content size. | +| INFERENCE_LANG | No | english | The language in which the tags will be generated. | +| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. | ## Crawler Configs diff --git a/packages/shared/config.ts b/packages/shared/config.ts index 44b7e26d..325d9ffa 100644 --- a/packages/shared/config.ts +++ b/packages/shared/config.ts @@ -24,6 +24,7 @@ const allEnv = z.object({ INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30), INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"), INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"), + INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048), CRAWLER_HEADLESS_BROWSER: stringBool("true"), BROWSER_WEB_URL: z.string().url().optional(), BROWSER_WEBSOCKET_URL: z.string().url().optional(), @@ -74,6 +75,7 @@ const serverConfigSchema = allEnv.transform((val) => { textModel: val.INFERENCE_TEXT_MODEL, imageModel: val.INFERENCE_IMAGE_MODEL, inferredTagLang: val.INFERENCE_LANG, + contextLength: val.INFERENCE_CONTEXT_LENGTH, }, crawler: { numWorkers: val.CRAWLER_NUM_WORKERS, diff --git a/packages/shared/prompts.ts b/packages/shared/prompts.ts index cf6d48b6..91bfba3f 100644 --- a/packages/shared/prompts.ts +++ b/packages/shared/prompts.ts @@ -1,3 +1,17 @@ +// TODO: Use a proper tokenizer +function calculateNumTokens(text: string) { + return text.split(" ").length; +} + +function truncateContent(content: string, length: number) { + let words = content.split(" "); + if (words.length > length) { + words = words.slice(0, length); + content = words.join(" "); + } + return content; +} + export function buildImagePrompt(lang: string, customPrompts: string[]) { return ` You are a bot in a read-it-later app and your responsibility is to help with automatic tagging. @@ -15,8 +29,9 @@ export function buildTextPrompt( lang: string, customPrompts: string[], content: string, + contextLength: number, ) { - return ` + const constructPrompt = (c: string) => ` You are a bot in a read-it-later app and your responsibility is to help with automatic tagging. Please analyze the text between the sentences "CONTENT START HERE" and "CONTENT END HERE" and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are: - Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. @@ -27,7 +42,11 @@ Please analyze the text between the sentences "CONTENT START HERE" and "CONTENT - If there are no good tags, leave the array empty. ${customPrompts && customPrompts.map((p) => `- ${p}`).join("\n")} CONTENT START HERE -${content} +${c} CONTENT END HERE You must respond in JSON with the key "tags" and the value is an array of string tags.`; + + const promptSize = calculateNumTokens(constructPrompt("")); + const truncatedContent = truncateContent(content, contextLength - promptSize); + return constructPrompt(truncatedContent); } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cb4e0106..eade6d67 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -738,8 +738,8 @@ importers: specifier: ^5.43.4 version: 5.45.0 ollama: - specifier: ^0.5.0 - version: 0.5.0 + specifier: ^0.5.9 + version: 0.5.9 openai: specifier: ^4.67.1 version: 4.67.1(zod@3.22.4) @@ -9412,8 +9412,8 @@ packages: resolution: {integrity: sha512-IF4PcGgzAr6XXSff26Sk/+P4KZFJVuHAJZj3wgO3vX2bMdNVp/QXTP3P7CEm9V1IdG8lDLY3HhiqpsE/nOwpPw==} engines: {node: ^10.13.0 || >=12.0.0} - ollama@0.5.0: - resolution: {integrity: sha512-CRtRzsho210EGdK52GrUMohA2pU+7NbgEaBG3DcYeRmvQthDO7E2LHOkLlUUeaYUlNmEd8icbjC02ug9meSYnw==} + ollama@0.5.9: + resolution: {integrity: sha512-F/KZuDRC+ZsVCuMvcOYuQ6zj42/idzCkkuknGyyGVmNStMZ/sU3jQpvhnl4SyC0+zBzLiKNZJnJeuPFuieWZvQ==} on-finished@2.3.0: resolution: {integrity: sha512-ikqdkGAAyf/X/gPhXGvfgAytDZtDbr+bkNUJ0N9h5MI/dmdgCs3l6hoHrcUv41sRKew3jIwrp4qQDXiK99Utww==} @@ -25611,7 +25611,7 @@ snapshots: oidc-token-hash@5.0.3: dev: false - ollama@0.5.0: + ollama@0.5.9: dependencies: whatwg-fetch: 3.6.20 dev: false