Skip to content

Commit

Permalink
Improve AWS Bedrock embeddings and token usage visibility
Browse files Browse the repository at this point in the history
* Add the option to use Cohere embeddings in the AWS Bedrock embeddings module. Release version 1.x.

* Add more information about token usage to the backend controller.
  • Loading branch information
Xantier committed Oct 3, 2024
1 parent 74c738f commit d2b38ac
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 757 deletions.
8 changes: 8 additions & 0 deletions .changeset/clever-parrots-obey.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@roadiehq/rag-ai-backend-embeddings-aws': major
'@roadiehq/rag-ai-backend': minor
---

Add the option to use Cohere embeddings in the AWS Bedrock embeddings module. Release version 1.x.

Add more information about token usage to the backend controller.
2 changes: 1 addition & 1 deletion plugins/backend/rag-ai-backend-embeddings-aws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"postpack": "backstage-cli package postpack"
},
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.474.0",
"@aws-sdk/client-bedrock-runtime": "^3.602.0",
"@aws-sdk/types": "^3.468.0",
"@backstage/backend-common": "^0.24.0",
"@backstage/catalog-client": "^1.6.6",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2024 Larder Software Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import {
BedrockRuntimeClient,
InvokeModelCommand,
} from '@aws-sdk/client-bedrock-runtime';
import { Embeddings } from '@langchain/core/embeddings';
import { BedrockEmbeddingsParams } from '@langchain/aws';

/**
 * Embeddings implementation that invokes a Cohere embedding model through the
 * AWS Bedrock runtime (`InvokeModelCommand`).
 *
 * Documents are sent in batches; every request carries the batch texts and an
 * `input_type` that distinguishes queries (`search_query`) from indexed
 * documents (`search_document`).
 */
export class BedrockCohereEmbeddings
  extends Embeddings
  implements BedrockEmbeddingsParams
{
  /**
   * Hard upper bound on the number of documents sent in a single request.
   * This is the limit the implementation has always applied.
   */
  private static readonly MAX_DOCUMENTS_PER_REQUEST = 66;

  /** Bedrock model identifier sent as `modelId` on every invocation. */
  model: string;

  /** Client used to send the `InvokeModelCommand` requests. */
  client: BedrockRuntimeClient;

  /**
   * Requested number of documents per request. The effective value is capped
   * at 66, so the default of 512 results in batches of 66 documents; setting
   * a smaller value reduces the request size accordingly.
   */
  batchSize = 512;

  /**
   * @param fields Optional settings. `model` defaults to
   *   `cohere.embed-english-v3`; when no `client` is supplied a new
   *   `BedrockRuntimeClient` is created from `region` and `credentials`.
   */
  constructor(fields?: BedrockEmbeddingsParams) {
    super(fields ?? {});

    this.model = fields?.model ?? 'cohere.embed-english-v3';

    this.client =
      fields?.client ??
      new BedrockRuntimeClient({
        region: fields?.region,
        credentials: fields?.credentials,
      });
  }

  /**
   * Resolves the number of documents to send per request: `batchSize` rounded
   * down and capped at the per-request maximum. Values that are not finite or
   * are below 1 fall back to the maximum, which also guarantees that the
   * batching loop always advances.
   */
  private effectiveBatchSize(): number {
    const limit = BedrockCohereEmbeddings.MAX_DOCUMENTS_PER_REQUEST;
    const requested = Math.floor(this.batchSize);
    if (!Number.isFinite(requested) || requested < 1) {
      return limit;
    }
    return Math.min(requested, limit);
  }

  /**
   * Validates a parsed response body and returns its embeddings.
   * @param payload The JSON-decoded response body.
   * @param expectedCount Number of documents that were sent in the request.
   * @returns The `embeddings` array of the response.
   * @throws If `embeddings` is missing, is not an array of arrays, or does not
   *   contain exactly one entry per document sent.
   */
  private static extractEmbeddings(
    payload: unknown,
    expectedCount: number,
  ): number[][] {
    const embeddings =
      typeof payload === 'object' && payload !== null
        ? (payload as { embeddings?: unknown }).embeddings
        : undefined;

    if (
      !Array.isArray(embeddings) ||
      !embeddings.every(item => Array.isArray(item))
    ) {
      throw new Error('response did not contain an "embeddings" array');
    }
    // A mismatch would silently misalign embeddings with their documents.
    if (embeddings.length !== expectedCount) {
      throw new Error(
        `received ${embeddings.length} embeddings for ${expectedCount} documents`,
      );
    }

    return embeddings as number[][];
  }

  /**
   * Embeds an array of documents using the Bedrock model.
   *
   * Runs of newline characters in each document are replaced by a single
   * space before sending. Batches are sent sequentially and their embeddings
   * are concatenated in document order. An empty `documents` array sends no
   * request and resolves to an empty array.
   *
   * @param documents The array of documents to be embedded.
   * @param inputType The input type for the embedding process.
   * @returns A promise that resolves to a 2D array of embeddings.
   * @throws If an error occurs while embedding documents with Bedrock, or if
   *   a response does not contain one embedding per document.
   */
  protected async embed(
    documents: string[],
    inputType: string,
  ): Promise<number[][]> {
    return this.caller.call(async () => {
      const batchSize = this.effectiveBatchSize();
      const results: number[][] = [];

      try {
        for (let start = 0; start < documents.length; start += batchSize) {
          const batch = documents.slice(start, start + batchSize);

          const res = await this.client.send(
            new InvokeModelCommand({
              modelId: this.model,
              body: JSON.stringify({
                texts: batch.map(doc => doc.replace(/\n+/g, ' ')),
                input_type: inputType,
              }),
              contentType: 'application/json',
              accept: 'application/json',
            }),
          );

          const payload: unknown = JSON.parse(
            new TextDecoder().decode(res.body),
          );
          results.push(
            ...BedrockCohereEmbeddings.extractEmbeddings(payload, batch.length),
          );
        }

        return results;
      } catch (e) {
        console.error({
          error: e,
        });
        if (e instanceof Error) {
          throw new Error(
            `An error occurred while embedding documents with Bedrock: ${e.message}`,
          );
        }

        throw new Error(
          'An error occurred while embedding documents with Bedrock',
        );
      }
    });
  }

  /**
   * Embeds a single query string using the `search_query` input type.
   * @param document The query text to embed.
   * @returns A promise that resolves to the embedding of the query.
   */
  async embedQuery(document: string): Promise<number[]> {
    const [embedding] = await this.embed([document], 'search_query');
    return embedding;
  }

  /**
   * Embeds documents using the `search_document` input type.
   * @param documents The texts to embed.
   * @returns A promise that resolves to one embedding per document, in order.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    return this.embed(documents, 'search_document');
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
DefaultVectorAugmentationIndexer,
RoadieEmbeddingsConfig,
} from '@roadiehq/rag-ai-backend-retrieval-augmenter';
import { BedrockCohereEmbeddings } from './BedrockCohereEmbeddings';

export type BedrockConfig = {
modelName: string;
Expand All @@ -36,11 +37,20 @@ export class RoadieBedrockAugmenter extends DefaultVectorAugmentationIndexer {
tokenManager: TokenManager;
},
) {
const embeddings = new BedrockEmbeddings({
region: config.options.region,
credentials: config.options.credentials,
model: config.bedrockConfig.modelName,
});
super({ ...config, embeddings });
if (config.bedrockConfig.modelName.includes('cohere')) {
const embeddings = new BedrockCohereEmbeddings({
region: config.options.region,
credentials: config.options.credentials,
model: config.bedrockConfig.modelName,
});
super({ ...config, embeddings });
} else {
const embeddings = new BedrockEmbeddings({
region: config.options.region,
credentials: config.options.credentials,
model: config.bedrockConfig.modelName,
});
super({ ...config, embeddings });
}
}
}
20 changes: 20 additions & 0 deletions plugins/backend/rag-ai-backend/src/service/RagAiController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ import {
// @ts-ignore
import type compression from 'compression';

/**
 * Token counters read from the `usage_metadata` property of a streamed
 * response chunk.
 */
type UsageMetadata = {
  /** Input token count reported on the chunk. */
  input_tokens: number;
  /** Output token count reported on the chunk. */
  output_tokens: number;
  /** Total token count reported on the chunk. */
  total_tokens: number;
};

export class RagAiController {
private static instance: RagAiController;
private readonly llmService: LlmService;
Expand Down Expand Up @@ -134,8 +140,18 @@ export class RagAiController {
res.write(embeddingsEvent + embeddingsData);

const stream = await this.llmService.query(embeddingDocs, query);
const usage = { input_tokens: 0, output_tokens: 0, total_tokens: 0 };

for await (const chunk of stream) {
if (typeof chunk !== 'string' && 'usage_metadata' in chunk) {
usage.input_tokens +=
(chunk.usage_metadata as UsageMetadata)?.input_tokens ?? 0;
usage.output_tokens +=
(chunk.usage_metadata as UsageMetadata)?.output_tokens ?? 0;
usage.total_tokens +=
(chunk.usage_metadata as UsageMetadata)?.total_tokens ?? 0;
}

const text =
typeof chunk === 'string' ? chunk : (chunk.content as string);
const event = `event: response\n`;
Expand All @@ -144,6 +160,10 @@ export class RagAiController {
res.flush?.();
}

this.logger.info(
`Produced response with token usage: ${JSON.stringify(usage)}`,
);
res.write(`event: usage\n` + `data: ${JSON.stringify(usage)}\n\n`);
res.end();
};

Expand Down
Loading

0 comments on commit d2b38ac

Please sign in to comment.