From 89b947dac6cf30d5869ba82ad9a83e187c6cdfb4 Mon Sep 17 00:00:00 2001 From: Diego Ramp Date: Fri, 20 Dec 2024 17:27:33 +0100 Subject: [PATCH] config: add keys to override config for chat-, embedding-, and image-model: api-key, ad-token, api-version Closes #1154 --- .gitignore | 3 ++ .../openai/runtime/AzureOpenAiRecorder.java | 25 +++++----- .../runtime/config/ChatModelConfig.java | 17 +++++++ .../runtime/config/EmbeddingModelConfig.java | 17 +++++++ .../runtime/config/ImageModelConfig.java | 17 +++++++ .../AzureOpenAiRecorderEndpointTests.java | 46 +++++++++++++++++++ 6 files changed, 111 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index babd74774..fe6c136d7 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,9 @@ release.properties # Quarkus CLI .quarkus +# dotenv +.env + #Dolphin .directory /samples/chatbot/dev.sh diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index 40cc8e252..8e80d2ddd 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -30,7 +30,6 @@ import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiImageModel; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel; import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig; -import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig.AzureAiConfig.EndpointType; import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient; @@ -58,17 +57,16 @@ public Function, ChatLanguageModel LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName); if (azureAiConfig.enableIntegration()) { - ChatModelConfig chatModelConfig = azureAiConfig.chatModel(); - String apiKey = azureAiConfig.apiKey().orElse(null); - String adToken = azureAiConfig.adToken().orElse(null); - + var chatModelConfig = azureAiConfig.chatModel(); + var apiKey = firstOrDefault(null, chatModelConfig.apiKey(), azureAiConfig.apiKey()); + var adToken = firstOrDefault(null, chatModelConfig.adToken(), azureAiConfig.adToken()); var builder = AzureOpenAiChatModel.builder() .endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT)) .configName(NamedConfigUtil.isDefault(configName) ? null : configName) .apiKey(apiKey) .adToken(adToken) // .tokenizer(new OpenAiTokenizer("")) TODO: Set the tokenizer, it is always null!! - .apiVersion(azureAiConfig.apiVersion()) + .apiVersion(chatModelConfig.apiVersion().orElse(azureAiConfig.apiVersion())) .timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10))) .maxRetries(azureAiConfig.maxRetries()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests())) @@ -158,15 +156,15 @@ public Function, EmbeddingModel> embe LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName); if (azureAiConfig.enableIntegration()) { - EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel(); - String apiKey = azureAiConfig.apiKey().orElse(null); - String adToken = azureAiConfig.adToken().orElse(null); + var embeddingModelConfig = azureAiConfig.embeddingModel(); + var apiKey = firstOrDefault(null, embeddingModelConfig.apiKey(), azureAiConfig.apiKey()); + var adToken = firstOrDefault(null, embeddingModelConfig.adToken(), azureAiConfig.adToken()); var builder = AzureOpenAiEmbeddingModel.builder() .endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING)) .apiKey(apiKey) .adToken(adToken) .configName(NamedConfigUtil.isDefault(configName) ? null : configName) - .apiVersion(azureAiConfig.apiVersion()) + .apiVersion(embeddingModelConfig.apiVersion().orElse(azureAiConfig.apiVersion())) .timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10))) .maxRetries(azureAiConfig.maxRetries()) .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests())) @@ -195,15 +193,14 @@ public Function, ImageModel> imageModel(L LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName); if (azureAiConfig.enableIntegration()) { - var apiKey = azureAiConfig.apiKey().orElse(null); - String adToken = azureAiConfig.adToken().orElse(null); - var imageModelConfig = azureAiConfig.imageModel(); + var apiKey = firstOrDefault(null, imageModelConfig.apiKey(), azureAiConfig.apiKey()); + var adToken = firstOrDefault(null, imageModelConfig.adToken(), azureAiConfig.adToken()); var builder = AzureOpenAiImageModel.builder() .endpoint(getEndpoint(azureAiConfig, configName, EndpointType.IMAGE)) .apiKey(apiKey) .adToken(adToken) - .apiVersion(azureAiConfig.apiVersion()) + .apiVersion(imageModelConfig.apiVersion().orElse(azureAiConfig.apiVersion())) .timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10))) .maxRetries(azureAiConfig.maxRetries()) .logRequests(firstOrDefault(false, imageModelConfig.logRequests(), azureAiConfig.logRequests())) diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java index b7ad8266f..99debe03d 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java @@ -36,6 +36,23 @@ public interface ChatModelConfig { @WithDefault(ConfigConstants.DUMMY_VALUE) Optional endpoint(); + /** + * The Azure AD token to use for this operation. + * If present, then the requests towards OpenAI will include this in the Authorization header. + * Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}. + */ + Optional adToken(); + + /** + * The API version to use for this operation. This follows the YYYY-MM-DD format + */ + Optional apiVersion(); + + /** + * Azure OpenAI API key + */ + Optional apiKey(); + /** * What sampling temperature to use, with values between 0 and 2. * Higher values means the model will take more risks. diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java index c0251045b..f3c925f14 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java @@ -31,6 +31,23 @@ public interface EmbeddingModelConfig { */ Optional endpoint(); + /** + * The Azure AD token to use for this operation. + * If present, then the requests towards OpenAI will include this in the Authorization header. + * Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}. + */ + Optional adToken(); + + /** + * The API version to use for this operation. This follows the YYYY-MM-DD format + */ + Optional apiVersion(); + + /** + * Azure OpenAI API key + */ + Optional apiKey(); + /** * Whether embedding model requests should be logged */ diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ImageModelConfig.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ImageModelConfig.java index a7f92bd1d..716d1aff4 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ImageModelConfig.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ImageModelConfig.java @@ -33,6 +33,23 @@ public interface ImageModelConfig { */ Optional endpoint(); + /** + * The Azure AD token to use for this operation. + * If present, then the requests towards OpenAI will include this in the Authorization header. + * Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}. + */ + Optional adToken(); + + /** + * The API version to use for this operation. This follows the YYYY-MM-DD format + */ + Optional apiVersion(); + + /** + * Azure OpenAI API key + */ + Optional apiKey(); + /** * Model name to use */ diff --git a/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java b/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java index faa3a4907..a5dbf4d3a 100644 --- a/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java +++ b/model-providers/openai/azure-openai/runtime/src/test/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorderEndpointTests.java @@ -219,6 +219,21 @@ public Optional endpoint() { return Optional.empty(); } + @Override + public Optional adToken() { + return Optional.empty(); + } + + @Override + public Optional apiVersion() { + return Optional.empty(); + } + + @Override + public Optional apiKey() { + return Optional.empty(); + } + @Override public Double temperature() { return null; @@ -285,6 +300,21 @@ public Optional endpoint() { return Optional.empty(); } + @Override + public Optional adToken() { + return Optional.empty(); + } + + @Override + public Optional apiVersion() { + return Optional.empty(); + } + + @Override + public Optional apiKey() { + return Optional.empty(); + } + @Override public Optional logRequests() { return Optional.empty(); @@ -294,6 +324,7 @@ public Optional logRequests() { public Optional logResponses() { return Optional.empty(); } + }; } @@ -320,6 +351,21 @@ public Optional endpoint() { return Optional.empty(); } + @Override + public Optional adToken() { + return Optional.empty(); + } + + @Override + public Optional apiVersion() { + return Optional.empty(); + } + + @Override + public Optional apiKey() { + return Optional.empty(); + } + @Override public String modelName() { return null;