Mistral embedding engine support #2667

Merged
@@ -0,0 +1,46 @@
export default function MistralAiOptions({ settings }) {
return (
<div className="w-full flex flex-col gap-y-4">
<div className="w-full flex items-center gap-[36px] mt-1.5">
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
API Key
</label>
<input
type="password"
name="MistralAiApiKey"
className="bg-theme-settings-input-bg text-white placeholder:text-theme-settings-input-placeholder text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Mistral AI API Key"
defaultValue={settings?.MistralApiKey ? "*".repeat(20) : ""}
required={true}
autoComplete="off"
spellCheck={false}
/>
</div>
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Model Preference
</label>
<select
name="EmbeddingModelPref"
required={true}
defaultValue={settings?.EmbeddingModelPref}
className="bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Available embedding models">
{[
"mistral-embed",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
</div>
</div>
);
}
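
A minimal rendering sketch for this options panel (illustrative only, not part of the PR; the import alias matches the one used later in this diff, and the `settings` shape mirrors the keys the component reads):

```jsx
// Hypothetical sketch: render the panel with persisted settings. A stored key
// is masked with asterisks; the model select defaults to the saved preference.
import MistralAiOptions from "@/components/EmbeddingSelection/MistralAiOptions";

export default function Example({ systemSettings }) {
  // systemSettings is assumed to carry MistralApiKey / EmbeddingModelPref.
  return <MistralAiOptions settings={systemSettings} />;
}
```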
@@ -13,6 +13,7 @@ import CohereLogo from "@/media/llmprovider/cohere.png";
import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png";
import LiteLLMLogo from "@/media/llmprovider/litellm.png";
import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png";
import MistralAiLogo from "@/media/llmprovider/mistral.jpeg";

import PreLoader from "@/components/Preloader";
import ChangeWarningModal from "@/components/ChangeWarning";
@@ -33,6 +34,7 @@ import { useModal } from "@/hooks/useModal";
import ModalWrapper from "@/components/ModalWrapper";
import CTAButton from "@/components/lib/CTAButton";
import { useTranslation } from "react-i18next";
import MistralAiOptions from "@/components/EmbeddingSelection/MistralAiOptions";

const EMBEDDERS = [
{
@@ -100,6 +102,13 @@ const EMBEDDERS = [
options: (settings) => <LiteLLMOptions settings={settings} />,
description: "Run powerful embedding models from LiteLLM.",
},
{
name: "Mistral AI",
value: "mistral",
logo: MistralAiLogo,
options: (settings) => <MistralAiOptions settings={settings} />,
description: "Run powerful embedding models from Mistral AI.",
},
{
name: "Generic OpenAI",
value: "generic-openai",
@@ -349,6 +349,13 @@ export const EMBEDDING_ENGINE_PRIVACY = {
],
logo: VoyageAiLogo,
},
mistral: {
name: "Mistral AI",
description: [
"Data sent to Mistral AI's servers is shared according to the terms of service of https://mistral.ai.",
],
logo: MistralLogo,
},
litellm: {
name: "LiteLLM",
description: [
43 changes: 43 additions & 0 deletions server/utils/EmbeddingEngines/mistral/index.js
@@ -0,0 +1,43 @@
class MistralEmbedder {
constructor() {
if (!process.env.MISTRAL_API_KEY)
throw new Error("No Mistral API key was set.");

const { OpenAI: OpenAIApi } = require("openai");
this.openai = new OpenAIApi({
baseURL: "https://api.mistral.ai/v1",
apiKey: process.env.MISTRAL_API_KEY ?? null,
});
this.model = process.env.EMBEDDING_MODEL_PREF || "mistral-embed";
}

async embedTextInput(textInput) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textInput,
});
return response?.data?.[0]?.embedding || [];
} catch (error) {
console.error("Failed to get embedding from Mistral.", error.message);
return [];
}
}

async embedChunks(textChunks = []) {
try {
const response = await this.openai.embeddings.create({
model: this.model,
input: textChunks,
});
return response?.data?.map((emb) => emb.embedding) || [];
} catch (error) {
console.error("Failed to get embeddings from Mistral.", error.message);
return new Array(textChunks.length).fill([]);
}
}
}

module.exports = {
MistralEmbedder,
};
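
For review context, a minimal usage sketch of the class above (not part of this PR; it assumes `MISTRAL_API_KEY` is set in the server environment and that the require path resolves to the file added here):

```js
// Hypothetical usage sketch of the new embedder. Assumes MISTRAL_API_KEY is
// exported in the environment; the require path is relative to the server root.
const { MistralEmbedder } = require("./utils/EmbeddingEngines/mistral");

(async () => {
  const embedder = new MistralEmbedder(); // throws if MISTRAL_API_KEY is unset

  // One input string -> one embedding vector (array of floats), or [] on failure.
  const vector = await embedder.embedTextInput("example text to embed");
  console.log("vector length:", vector.length);

  // Many chunks -> one vector per chunk, in order; empty vectors on failure.
  const vectors = await embedder.embedChunks(["chunk one", "chunk two"]);
  console.log("chunks embedded:", vectors.length);
})();
```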
3 changes: 3 additions & 0 deletions server/utils/helpers/index.js
@@ -214,6 +214,9 @@ function getEmbeddingEngineSelection() {
case "litellm":
const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM");
return new LiteLLMEmbedder();
case "mistral":
const { MistralEmbedder } = require("../EmbeddingEngines/mistral");
return new MistralEmbedder();
case "generic-openai":
const {
GenericOpenAiEmbedder,
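A sketch of how the new `case "mistral"` branch would be reached (the switch expression itself is outside this hunk; `EMBEDDING_ENGINE` driving the switch and `getEmbeddingEngineSelection` being exported from this helpers module are assumptions):

```js
// Hypothetical selection sketch — the key and model env var names come from the
// embedder class above; EMBEDDING_ENGINE driving the switch is an assumption.
process.env.EMBEDDING_ENGINE = "mistral";
process.env.MISTRAL_API_KEY = "your-mistral-api-key";
process.env.EMBEDDING_MODEL_PREF = "mistral-embed"; // optional; this is the default

const { getEmbeddingEngineSelection } = require("./utils/helpers");
const embedder = getEmbeddingEngineSelection();
// embedder is now a MistralEmbedder targeting https://api.mistral.ai/v1
```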
1 change: 1 addition & 0 deletions server/utils/helpers/updateENV.js
@@ -753,6 +753,7 @@ function supportedEmbeddingModel(input = "") {
"voyageai",
"litellm",
"generic-openai",
"mistral",
];
return supported.includes(input)
? null
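Finally, a small, self-contained restatement of the validator contract this last hunk extends (abridged to the provider values visible in this diff; the real function's full list and its rejection message are outside the hunk):

```js
// Hypothetical, abridged sketch: supportedEmbeddingModel() returns null for
// accepted provider values, so "mistral" now passes ENV validation.
function supportedEmbeddingModelSketch(input = "") {
  const supported = ["voyageai", "litellm", "generic-openai", "mistral"];
  return supported.includes(input)
    ? null // accepted -> no validation error
    : `Unsupported embedding model provider: ${input}`; // placeholder message
}

console.log(supportedEmbeddingModelSketch("mistral")); // null -> accepted
```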