Skip to content

Commit

Permalink
feat: Add support for embeddings in the inference interface (#403)
Browse files Browse the repository at this point in the history
* support embeddings generation in inference.ts

(cherry picked from commit 9ae8773)

* make AI worker generate embeddings for text bookmark

* make AI worker generate embeddings for text bookmark

* fix unintentional change -- inference image model

* support embeddings for PDF bookmarks

* Upgrade drizzle-kit

Existing version is not working with the upgraded version of drizzle-orm.

I removed the "driver" to the match the new schema of the Config.

Quoting from their Config:
* `driver` - optional param that is responsible for explicitly providing a driver to use when accessing a database
 * *Possible values*: `aws-data-api`, `d1-http`, `expo`, `turso`, `pglite`
 * If you don't use AWS Data API, D1, Turso or Expo - ypu don't need this driver. You can check a driver strategy choice here: https://orm.

* fix formatting and lint

* add comments about truncate content

* Revert "Upgrade drizzle-kit"

This reverts commit 08a02c8.

* revert keep alive field in Ollama

* change the interface to accept multiple inputs

* docs

---------

Co-authored-by: Mohamed Bassem <me@mbassem.com>
  • Loading branch information
medo and MohamedBassem committed Dec 29, 2024
1 parent 225d855 commit c89b0c5
Showing 3 changed files with 48 additions and 11 deletions.
23 changes: 12 additions & 11 deletions docs/docs/03-configuration.md
Original file line number Diff line number Diff line change
@@ -48,17 +48,18 @@ Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic taggin
- You might want to tune the `INFERENCE_CONTEXT_LENGTH` as the default is quite small. The larger the value, the better the quality of the tags, but the more expensive the inference will be (money-wise on OpenAI and resource-wise on ollama).
:::

| Name | Required | Default | Description |
| ------------------------- | -------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
| INFERENCE_CONTEXT_LENGTH | No | 2048 | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on openAI and resource-wise on ollama). Check the model you're using for its max supported content size. |
| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |
| Name | Required | Default | Description |
| ------------------------- | -------- | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
| EMBEDDING_TEXT_MODEL | No | text-embedding-3-small | The model to be used for generating embeddings for the text. |
| INFERENCE_CONTEXT_LENGTH | No | 2048 | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on openAI and resource-wise on ollama). Check the model you're using for its max supported content size. |
| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |

:::info

4 changes: 4 additions & 0 deletions packages/shared/config.ts
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ const allEnv = z.object({
INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
OCR_CACHE_DIR: z.string().optional(),
OCR_LANGS: z
@@ -90,6 +91,9 @@ const serverConfigSchema = allEnv.transform((val) => {
inferredTagLang: val.INFERENCE_LANG,
contextLength: val.INFERENCE_CONTEXT_LENGTH,
},
embedding: {
textModel: val.EMBEDDING_TEXT_MODEL,
},
crawler: {
numWorkers: val.CRAWLER_NUM_WORKERS,
headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
32 changes: 32 additions & 0 deletions packages/shared/inference.ts
Original file line number Diff line number Diff line change
@@ -9,6 +9,10 @@ export interface InferenceResponse {
totalTokens: number | undefined;
}

export interface EmbeddingResponse {
embeddings: number[][];
}

export interface InferenceOptions {
json: boolean;
}
@@ -28,6 +32,7 @@ export interface InferenceClient {
image: string,
opts: InferenceOptions,
): Promise<InferenceResponse>;
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}

export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
}
return { response, totalTokens: chatCompletion.usage?.total_tokens };
}

async generateEmbeddingFromText(
inputs: string[],
): Promise<EmbeddingResponse> {
const model = serverConfig.embedding.textModel;
const embedResponse = await this.openAI.embeddings.create({
model: model,
input: inputs,
});
const embedding2D: number[][] = embedResponse.data.map(
(embedding: OpenAI.Embedding) => embedding.embedding,
);
return { embeddings: embedding2D };
}
}

class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
opts,
);
}

async generateEmbeddingFromText(
inputs: string[],
): Promise<EmbeddingResponse> {
const embedding = await this.ollama.embed({
model: serverConfig.embedding.textModel,
input: inputs,
// Truncate the input to fit into the model's max token limit,
// in the future we want to add a way to split the input into multiple parts.
truncate: true,
});
return { embeddings: embedding.embeddings };
}
}

0 comments on commit c89b0c5

Please sign in to comment.