From d0499663a03448511eb7a5e5b369d6e9fdff9e2e Mon Sep 17 00:00:00 2001 From: "Mark J. Hoy" Date: Fri, 5 Jul 2024 13:01:29 -0400 Subject: [PATCH] [Inference API] Add Amazon Bedrock support to Inference API (#110248) * Initial commit; setup Gradle; start service * initial commit * minor cleanups, builds green; needs tests * bug fixes; tested working embeddings & completion * use custom json builder for embeddings request * Ensure auto-close; fix forbidden API * start of adding unit tests; abstraction layers * adding additional tests; cleanups * add requests unit tests * all tests created * fix cohere embeddings response * fix cohere embeddings response * fix lint * better test coverage for secrets; inference client * update thread-safe syncs; make dims/tokens + int * add tests for dims and max tokens positive integer * use requireNonNull;override settings type;cleanups * use r/w lock for client cache * remove client reference counting * update locking in cache; client errors; noop doc * remove extra block in internalGetOrCreateClient * remove duplicate dependencies; cleanup * add fxn to get default embeddings similarity * use async calls to Amazon Bedrock; cleanups * use Clock in cache; simplify locking; cleanups * cleanups around executor; remove some instanceof * cleanups; use EmbeddingRequestChunker * move max chunk size to constants * oof - swapped transport vers w/ master node req * use XContent instead of Jackson JsonFactory * remove gradle versions; do not allow dimensions --- gradle/verification-metadata.xml | 5 + .../org/elasticsearch/TransportVersions.java | 1 + x-pack/plugin/inference/build.gradle | 29 +- .../licenses/aws-java-sdk-LICENSE.txt | 63 + .../licenses/aws-java-sdk-NOTICE.txt | 15 + .../inference/licenses/jaxb-LICENSE.txt | 274 ++++ .../plugin/inference/licenses/jaxb-NOTICE.txt | 1 + .../inference/licenses/joda-time-LICENSE.txt | 202 +++ .../inference/licenses/joda-time-NOTICE.txt | 5 + .../inference/src/main/java/module-info.java | 5 + .../InferenceNamedWriteablesProvider.java | 40 + .../xpack/inference/InferencePlugin.java | 7 + .../AmazonBedrockActionCreator.java | 56 + .../AmazonBedrockActionVisitor.java | 20 + .../AmazonBedrockChatCompletionAction.java | 47 + .../AmazonBedrockEmbeddingsAction.java | 48 + .../AmazonBedrockBaseClient.java | 37 + .../AmazonBedrockChatCompletionExecutor.java | 43 + .../amazonbedrock/AmazonBedrockClient.java | 29 + .../AmazonBedrockClientCache.java | 19 + .../AmazonBedrockEmbeddingsExecutor.java | 44 + ...AmazonBedrockExecuteOnlyRequestSender.java | 124 ++ .../amazonbedrock/AmazonBedrockExecutor.java | 68 + .../AmazonBedrockInferenceClient.java | 166 +++ .../AmazonBedrockInferenceClientCache.java | 137 ++ .../AmazonBedrockRequestSender.java | 126 ++ ...onBedrockChatCompletionRequestManager.java | 65 + ...AmazonBedrockEmbeddingsRequestManager.java | 74 ++ .../AmazonBedrockRequestExecutorService.java | 42 + .../sender/AmazonBedrockRequestManager.java | 54 + .../AmazonBedrockJsonBuilder.java | 30 + .../AmazonBedrockJsonWriter.java | 20 + .../amazonbedrock/AmazonBedrockRequest.java | 85 ++ .../amazonbedrock/NoOpHttpRequest.java | 20 + ...edrockAI21LabsCompletionRequestEntity.java | 63 + ...drockAnthropicCompletionRequestEntity.java | 70 + ...zonBedrockChatCompletionEntityFactory.java | 78 ++ .../AmazonBedrockChatCompletionRequest.java | 69 + ...nBedrockCohereCompletionRequestEntity.java | 70 + .../AmazonBedrockConverseRequestEntity.java | 18 + .../AmazonBedrockConverseUtils.java | 29 + ...zonBedrockMetaCompletionRequestEntity.java | 63 + ...BedrockMistralCompletionRequestEntity.java | 70 + ...onBedrockTitanCompletionRequestEntity.java | 63 + ...nBedrockCohereEmbeddingsRequestEntity.java | 35 + .../AmazonBedrockEmbeddingsEntityFactory.java | 45 + .../AmazonBedrockEmbeddingsRequest.java | 99 ++ ...onBedrockTitanEmbeddingsRequestEntity.java | 31 + .../amazonbedrock/AmazonBedrockResponse.java | 15 + .../AmazonBedrockResponseHandler.java | 23 + .../AmazonBedrockResponseListener.java | 30 + .../AmazonBedrockChatCompletionResponse.java | 49 + ...nBedrockChatCompletionResponseHandler.java | 39 + ...BedrockChatCompletionResponseListener.java | 40 + .../AmazonBedrockEmbeddingsResponse.java | 132 ++ ...mazonBedrockEmbeddingsResponseHandler.java | 37 + ...azonBedrockEmbeddingsResponseListener.java | 38 + .../amazonbedrock/AmazonBedrockConstants.java | 27 + .../amazonbedrock/AmazonBedrockModel.java | 88 ++ .../amazonbedrock/AmazonBedrockProvider.java | 30 + .../AmazonBedrockProviderCapabilities.java | 102 ++ .../AmazonBedrockSecretSettings.java | 110 ++ .../amazonbedrock/AmazonBedrockService.java | 350 +++++ .../AmazonBedrockServiceSettings.java | 141 ++ .../AmazonBedrockChatCompletionModel.java | 83 ++ ...rockChatCompletionRequestTaskSettings.java | 90 ++ ...nBedrockChatCompletionServiceSettings.java | 93 ++ ...azonBedrockChatCompletionTaskSettings.java | 190 +++ .../AmazonBedrockEmbeddingsModel.java | 85 ++ ...mazonBedrockEmbeddingsServiceSettings.java | 220 ++++ .../plugin-metadata/plugin-security.policy | 8 +- .../AmazonBedrockActionCreatorTests.java | 175 +++ .../AmazonBedrockExecutorTests.java | 172 +++ ...mazonBedrockInferenceClientCacheTests.java | 108 ++ .../AmazonBedrockMockClientCache.java | 62 + ...AmazonBedrockMockExecuteRequestSender.java | 80 ++ .../AmazonBedrockMockInferenceClient.java | 133 ++ .../AmazonBedrockMockRequestSender.java | 91 ++ .../AmazonBedrockRequestSenderTests.java | 127 ++ ...kAI21LabsCompletionRequestEntityTests.java | 70 + ...AnthropicCompletionRequestEntityTests.java | 82 ++ ...ockCohereCompletionRequestEntityTests.java | 82 ++ .../AmazonBedrockConverseRequestUtils.java | 94 ++ ...drockMetaCompletionRequestEntityTests.java | 70 + ...ckMistralCompletionRequestEntityTests.java | 82 ++ ...rockTitanCompletionRequestEntityTests.java | 70 + ...ockCohereEmbeddingsRequestEntityTests.java | 25 + ...rockTitanEmbeddingsRequestEntityTests.java | 24 + .../AmazonBedrockSecretSettingsTests.java | 120 ++ .../AmazonBedrockServiceTests.java | 1131 +++++++++++++++++ ...AmazonBedrockChatCompletionModelTests.java | 221 ++++ ...hatCompletionRequestTaskSettingsTests.java | 107 ++ ...ockChatCompletionServiceSettingsTests.java | 131 ++ ...edrockChatCompletionTaskSettingsTests.java | 226 ++++ .../AmazonBedrockEmbeddingsModelTests.java | 81 ++ ...BedrockEmbeddingsServiceSettingsTests.java | 404 ++++++ 96 files changed, 8790 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/jaxb-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/jaxb-NOTICE.txt create mode 100644 x-pack/plugin/inference/licenses/joda-time-LICENSE.txt create mode 100644 x-pack/plugin/inference/licenses/joda-time-NOTICE.txt create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index cd408ba75aa10..02313c5ed82a2 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -84,6 +84,11 @@ + + + + + diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2004c6fda8ce5..ff50d1513d28a 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -208,6 +208,7 @@ static TransportVersion def(int id) { public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0); public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0); public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0); + public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 41ca9966c1336..beeec94f21ebf 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -27,6 +27,10 @@ base { archivesName = 'x-pack-inference' } +versions << [ + 'awsbedrockruntime': '1.12.740' +] + dependencies { implementation project(path: ':libs:elasticsearch-logging') compileOnly project(":server") @@ -53,10 +57,19 @@ dependencies { implementation 'com.google.http-client:google-http-client-appengine:1.42.3' implementation 'com.google.http-client:google-http-client-jackson2:1.42.3' implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" + implementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" + implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" + implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" + implementation "com.fasterxml.jackson:jackson-bom:${versions.jackson}" implementation 'com.google.api:gax-httpjson:0.105.1' implementation 'io.grpc:grpc-context:1.49.2' implementation 'io.opencensus:opencensus-api:0.31.1' implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' + implementation "com.amazonaws:aws-java-sdk-bedrockruntime:${versions.awsbedrockruntime}" + implementation "com.amazonaws:aws-java-sdk-core:${versions.aws}" + implementation "com.amazonaws:jmespath-java:${versions.aws}" + implementation "joda-time:joda-time:2.10.10" + implementation 'javax.xml.bind:jaxb-api:2.2.2' } tasks.named("dependencyLicenses").configure { @@ -66,6 +79,9 @@ tasks.named("dependencyLicenses").configure { mapping from: /protobuf.*/, to: 'protobuf' mapping from: /proto-google.*/, to: 'proto-google' mapping from: /jackson.*/, to: 'jackson' + mapping from: /aws-java-sdk-.*/, to: 'aws-java-sdk' + mapping from: /jmespath-java.*/, to: 'aws-java-sdk' + mapping from: /jaxb-.*/, to: 'jaxb' } tasks.named("thirdPartyAudit").configure { @@ -199,10 +215,21 @@ tasks.named("thirdPartyAudit").configure { 'com.google.appengine.api.urlfetch.HTTPRequest', 'com.google.appengine.api.urlfetch.HTTPResponse', 'com.google.appengine.api.urlfetch.URLFetchService', - 'com.google.appengine.api.urlfetch.URLFetchServiceFactory' + 'com.google.appengine.api.urlfetch.URLFetchServiceFactory', + 'software.amazon.ion.IonReader', + 'software.amazon.ion.IonSystem', + 'software.amazon.ion.IonType', + 'software.amazon.ion.IonWriter', + 'software.amazon.ion.Timestamp', + 'software.amazon.ion.system.IonBinaryWriterBuilder', + 'software.amazon.ion.system.IonSystemBuilder', + 'software.amazon.ion.system.IonTextWriterBuilder', + 'software.amazon.ion.system.IonWriterBuilder', + 'javax.activation.DataHandler' ) } tasks.named('yamlRestTest') { usesDefaultDistribution() } + diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt new file mode 100644 index 0000000000000..98d1f9319f374 --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt @@ -0,0 +1,63 @@ +Apache License +Version 2.0, January 2004 + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + 1. You must give any other recipients of the Work or Derivative Works a copy of this License; and + 2. You must cause any modified files to carry prominent notices stating that You changed the files; and + 3. You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + 4. If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +Note: Other license terms may apply to certain, identified software files contained within or distributed with the accompanying software if such terms are included in the directory containing the accompanying software. Such other license terms will then apply in lieu of the terms of the software license above. + +JSON processing code subject to the JSON License from JSON.org: + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +The Software shall be used for Good, not Evil. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt new file mode 100644 index 0000000000000..565bd6085c71a --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt @@ -0,0 +1,15 @@ +AWS SDK for Java +Copyright 2010-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +This product includes software developed by +Amazon Technologies, Inc (http://www.amazon.com/). + +********************** +THIRD PARTY COMPONENTS +********************** +This software includes third party software subject to the following copyrights: +- XML parsing and utility functions from JetS3t - Copyright 2006-2009 James Murty. +- JSON parsing and utility functions from JSON.org - Copyright 2002 JSON.org. +- PKCS#1 PEM encoded private key parsing and utility functions from oauth.googlecode.com - Copyright 1998-2010 AOL Inc. + +The licenses for these third party components are included in LICENSE.txt diff --git a/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt b/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt new file mode 100644 index 0000000000000..833a843cfeee1 --- /dev/null +++ b/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt @@ -0,0 +1,274 @@ +COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL)Version 1.1 + +1. Definitions. + + 1.1. "Contributor" means each individual or entity that creates or contributes to the creation of Modifications. + + 1.2. "Contributor Version" means the combination of the Original Software, prior Modifications used by a Contributor (if any), and the Modifications made by that particular Contributor. + + 1.3. "Covered Software" means (a) the Original Software, or (b) Modifications, or (c) the combination of files containing Original Software with files containing Modifications, in each case including portions thereof. + + 1.4. "Executable" means the Covered Software in any form other than Source Code. + + 1.5. "Initial Developer" means the individual or entity that first makes Original Software available under this License. + + 1.6. "Larger Work" means a work which combines Covered Software or portions thereof with code not governed by the terms of this License. + + 1.7. "License" means this document. + + 1.8. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently acquired, any and all of the rights conveyed herein. + + 1.9. "Modifications" means the Source Code and Executable form of any of the following: + + A. Any file that results from an addition to, deletion from or modification of the contents of a file containing Original Software or previous Modifications; + + B. Any new file that contains any part of the Original Software or previous Modification; or + + C. Any new file that is contributed or otherwise made available under the terms of this License. + + 1.10. "Original Software" means the Source Code and Executable form of computer software code that is originally released under this License. + + 1.11. "Patent Claims" means any patent claim(s), now owned or hereafter acquired, including without limitation, method, process, and apparatus claims, in any patent Licensable by grantor. + + 1.12. "Source Code" means (a) the common form of computer software code in which modifications are made and (b) associated documentation included in or with such code. + + 1.13. "You" (or "Your") means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, "You" includes any entity which controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. + +2. License Grants. + + 2.1. The Initial Developer Grant. + + Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, the Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or trademark) Licensable by Initial Developer, to use, reproduce, modify, display, perform, sublicense and distribute the Original Software (or portions thereof), with or without Modifications, and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using or selling of Original Software, to make, have made, use, practice, sell, and offer for sale, and/or otherwise dispose of the Original Software (or portions thereof). + + (c) The licenses granted in Sections 2.1(a) and (b) are effective on the date Initial Developer first distributes or otherwise makes the Original Software available to a third party under the terms of this License. + + (d) Notwithstanding Section 2.1(b) above, no patent license is granted: (1) for code that You delete from the Original Software, or (2) for infringements caused by: (i) the modification of the Original Software, or (ii) the combination of the Original Software with other software or devices. + + 2.2. Contributor Grant. + + Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or trademark) Licensable by Contributor to use, reproduce, modify, display, perform, sublicense and distribute the Modifications created by such Contributor (or portions thereof), either on an unmodified basis, with other Modifications, as Covered Software and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using, or selling of Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such combination), to make, use, sell, offer for sale, have made, and/or otherwise dispose of: (1) Modifications made by that Contributor (or portions thereof); and (2) the combination of Modifications made by that Contributor with its Contributor Version (or portions of such combination). + + (c) The licenses granted in Sections 2.2(a) and 2.2(b) are effective on the date Contributor first distributes or otherwise makes the Modifications available to a third party. + + (d) Notwithstanding Section 2.2(b) above, no patent license is granted: (1) for any code that Contributor has deleted from the Contributor Version; (2) for infringements caused by: (i) third party modifications of Contributor Version, or (ii) the combination of Modifications made by that Contributor with other software (except as part of the Contributor Version) or other devices; or (3) under Patent Claims infringed by Covered Software in the absence of Modifications made by that Contributor. + +3. Distribution Obligations. + + 3.1. Availability of Source Code. + + Any Covered Software that You distribute or otherwise make available in Executable form must also be made available in Source Code form and that Source Code form must be distributed only under the terms of this License. You must include a copy of this License with every copy of the Source Code form of the Covered Software You distribute or otherwise make available. You must inform recipients of any such Covered Software in Executable form as to how they can obtain such Covered Software in Source Code form in a reasonable manner on or through a medium customarily used for software exchange. + + 3.2. Modifications. + + The Modifications that You create or to which You contribute are governed by the terms of this License. You represent that You believe Your Modifications are Your original creation(s) and/or You have sufficient rights to grant the rights conveyed by this License. + + 3.3. Required Notices. + + You must include a notice in each of Your Modifications that identifies You as the Contributor of the Modification. You may not remove or alter any copyright, patent or trademark notices contained within the Covered Software, or any notices of licensing or any descriptive text giving attribution to any Contributor or the Initial Developer. + + 3.4. Application of Additional Terms. + + You may not offer or impose any terms on any Covered Software in Source Code form that alters or restricts the applicable version of this License or the recipients' rights hereunder. You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, you may do so only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You must make it absolutely clear that any such warranty, support, indemnity or liability obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of warranty, support, indemnity or liability terms You offer. + + 3.5. Distribution of Executable Versions. + + You may distribute the Executable form of the Covered Software under the terms of this License or under the terms of a license of Your choice, which may contain terms different from this License, provided that You are in compliance with the terms of this License and that the license for the Executable form does not attempt to limit or alter the recipient's rights in the Source Code form from the rights set forth in this License. If You distribute the Covered Software in Executable form under a different license, You must make it absolutely clear that any terms which differ from this License are offered by You alone, not by the Initial Developer or Contributor. You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of any such terms You offer. + + 3.6. Larger Works. + + You may create a Larger Work by combining Covered Software with other code not governed by the terms of this License and distribute the Larger Work as a single product. In such a case, You must make sure the requirements of this License are fulfilled for the Covered Software. + +4. Versions of the License. + + 4.1. New Versions. + + Oracle is the initial license steward and may publish revised and/or new versions of this License from time to time. Each version will be given a distinguishing version number. Except as provided in Section 4.3, no one other than the license steward has the right to modify this License. + + 4.2. Effect of New Versions. + + You may always continue to use, distribute or otherwise make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. If the Initial Developer includes a notice in the Original Software prohibiting it from being distributed or otherwise made available under any subsequent version of the License, You must distribute and make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. Otherwise, You may also choose to use, distribute or otherwise make the Covered Software available under the terms of any subsequent version of the License published by the license steward. + + 4.3. Modified Versions. + + When You are an Initial Developer and You want to create a new license for Your Original Software, You may create and use a modified version of this License if You: (a) rename the license and remove any references to the name of the license steward (except to note that the license differs from this License); and (b) otherwise make it clear that the license contains terms which differ from this License. + +5. DISCLAIMER OF WARRANTY. + + COVERED SOFTWARE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED SOFTWARE IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE COVERED SOFTWARE IS WITH YOU. SHOULD ANY COVERED SOFTWARE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE OF ANY COVERED SOFTWARE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS DISCLAIMER. + +6. TERMINATION. + + 6.1. This License and the rights granted hereunder will terminate automatically if You fail to comply with terms herein and fail to cure such breach within 30 days of becoming aware of the breach. Provisions which, by their nature, must remain in effect beyond the termination of this License shall survive. + + 6.2. If You assert a patent infringement claim (excluding declaratory judgment actions) against Initial Developer or a Contributor (the Initial Developer or Contributor against whom You assert such claim is referred to as "Participant") alleging that the Participant Software (meaning the Contributor Version where the Participant is a Contributor or the Original Software where the Participant is the Initial Developer) directly or indirectly infringes any patent, then any and all rights granted directly or indirectly to You by such Participant, the Initial Developer (if the Initial Developer is not the Participant) and all Contributors under Sections 2.1 and/or 2.2 of this License shall, upon 60 days notice from Participant terminate prospectively and automatically at the expiration of such 60 day notice period, unless if within such 60 day period You withdraw Your claim with respect to the Participant Software against such Participant either unilaterally or pursuant to a written agreement with Participant. + + 6.3. If You assert a patent infringement claim against Participant alleging that the Participant Software directly or indirectly infringes any patent where such claim is resolved (such as by license or settlement) prior to the initiation of patent infringement litigation, then the reasonable value of the licenses granted by such Participant under Sections 2.1 or 2.2 shall be taken into account in determining the amount or value of any payment or license. + + 6.4. In the event of termination under Sections 6.1 or 6.2 above, all end user licenses that have been validly granted by You or any distributor hereunder prior to termination (excluding licenses granted to You by any distributor) shall survive termination. + +7. LIMITATION OF LIABILITY. + + UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED SOFTWARE, OR ANY SUPPLIER OF ANY OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND LIMITATION MAY NOT APPLY TO YOU. + +8. U.S. GOVERNMENT END USERS. + + The Covered Software is a "commercial item," as that term is defined in 48 C.F.R. 2.101 (Oct. 1995), consisting of "commercial computer software" (as that term is defined at 48 C.F.R. ? 252.227-7014(a)(1)) and "commercial computer software documentation" as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users acquire Covered Software with only those rights set forth herein. This U.S. Government Rights clause is in lieu of, and supersedes, any other FAR, DFAR, or other clause or provision that addresses Government rights in computer software under this License. + +9. MISCELLANEOUS. + + This License represents the complete agreement concerning subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. This License shall be governed by the law of the jurisdiction specified in a notice contained within the Original Software (except to the extent applicable law, if any, provides otherwise), excluding such jurisdiction's conflict-of-law provisions. Any litigation relating to this License shall be subject to the jurisdiction of the courts located in the jurisdiction and venue specified in a notice contained within the Original Software, with the losing party responsible for costs, including, without limitation, court costs and reasonable attorneys' fees and expenses. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not apply to this License. You agree that You alone are responsible for compliance with the United States export administration regulations (and the export control laws and regulation of any other countries) when You use, distribute or otherwise make available any Covered Software. + +10. RESPONSIBILITY FOR CLAIMS. + + As between Initial Developer and the Contributors, each party is responsible for claims and damages arising, directly or indirectly, out of its utilization of rights under this License and You agree to work with Initial Developer and Contributors to distribute such responsibility on an equitable basis. Nothing herein is intended or shall be deemed to constitute any admission of liability. + +---------- +NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) +The code released under the CDDL shall be governed by the laws of the State of California (excluding conflict-of-law provisions). Any litigation relating to this License shall be subject to the jurisdiction of the Federal Courts of the Northern District of California and the state courts of the State of California, with venue lying in Santa Clara County, California. + + + + +The GNU General Public License (GPL) Version 2, June 1991 + + +Copyright (C) 1989, 1991 Free Software Foundation, Inc. 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + +Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. + +Preamble + +The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. This General Public License applies to most of the Free Software Foundation's software and to any other program whose authors commit to using it. (Some other Free Software Foundation software is covered by the GNU Library General Public License instead.) You can apply it to your programs, too. + +When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs; and that you know you can do these things. + +To protect your rights, we need to make restrictions that forbid anyone to deny you these rights or to ask you to surrender the rights. These restrictions translate to certain responsibilities for you if you distribute copies of the software, or if you modify it. + +For example, if you distribute copies of such a program, whether gratis or for a fee, you must give the recipients all the rights that you have. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. + +We protect your rights with two steps: (1) copyright the software, and (2) offer you this license which gives you legal permission to copy, distribute and/or modify the software. + +Also, for each author's protection and ours, we want to make certain that everyone understands that there is no warranty for this free software. If the software is modified by someone else and passed on, we want its recipients to know that what they have is not the original, so that any problems introduced by others will not reflect on the original authors' reputations. + +Finally, any free program is threatened constantly by software patents. We wish to avoid the danger that redistributors of a free program will individually obtain patent licenses, in effect making the program proprietary. To prevent this, we have made it clear that any patent must be licensed for everyone's free use or not licensed at all. + +The precise terms and conditions for copying, distribution and modification follow. + + +TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + +0. This License applies to any program or other work which contains a notice placed by the copyright holder saying it may be distributed under the terms of this General Public License. The "Program", below, refers to any such program or work, and a "work based on the Program" means either the Program or any derivative work under copyright law: that is to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or translated into another language. (Hereinafter, translation is included without limitation in the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running the Program is not restricted, and the output from the Program is covered only if its contents constitute a work based on the Program (independent of having been made by running the Program). Whether that is true depends on what the Program does. + +1. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and give any other recipients of the Program a copy of this License along with the Program. + +You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. + +2. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on the Program, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in whole or in part contains or is derived from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under the terms of this License. + + c) If the modified program normally reads commands interactively when run, you must cause it, when started running for such interactive use in the most ordinary way, to print or display an announcement including an appropriate copyright notice and a notice that there is no warranty (or else, saying that you provide a warranty) and that users may redistribute the program under these conditions, and telling the user how to view a copy of this License. (Exception: if the Program itself is interactive but does not normally print such an announcement, your work based on the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Program, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Program, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program with the Program (or with a work based on the Program) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. + +3. You may copy and distribute the Program (or a work based on it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three years, to give any third party, for a charge no more than your cost of physically performing source distribution, a complete machine-readable copy of the corresponding source code, to be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer to distribute corresponding source code. (This alternative is allowed only for noncommercial distribution and only if you received the program in object code or executable form with such an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for making modifications to it. For an executable work, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the executable. However, as a special exception, the source code distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. + +If distribution of executable or object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place counts as distribution of the source code, even though third parties are not compelled to copy the source along with the object code. + +4. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. + +5. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Program or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the Program), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Program or works based on it. + +6. Each time you redistribute the Program (or any work based on the Program), the recipient automatically receives a license from the original licensor to copy, distribute or modify the Program subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties to this License. + +7. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a patent license would not permit royalty-free redistribution of the Program by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply and the section as a whole is intended to apply in other circumstances. + +It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system, which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. + +This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. + +8. If the distribution and/or use of the Program is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Program under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. + +9. The Free Software Foundation may publish revised and/or new versions of the General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of this License, you may choose any version ever published by the Free Software Foundation. + +10. If you wish to incorporate parts of the Program into other free programs whose distribution conditions are different, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. + +NO WARRANTY + +11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + +12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +END OF TERMS AND CONDITIONS + + +How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. + +To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. + + One line to give the program's name and a brief idea of what it does. + + Copyright (C) + + This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, the commands you use may be called something other than `show w' and `show c'; they could even be mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the program, if necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. + + signature of Ty Coon, 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Library General Public License instead of this License. + + +"CLASSPATH" EXCEPTION TO THE GPL VERSION 2 + +Certain source files distributed by Oracle are subject to the following clarification and special exception to the GPL Version 2, but only where Oracle has expressly included in the particular source file's header the words "Oracle designates this particular file as subject to the "Classpath" exception as provided by Oracle in the License file that accompanied this code." + +Linking this library statically or dynamically with other modules is making a combined work based on this library. Thus, the terms and conditions of the GNU General Public License Version 2 cover the whole combination. + +As a special exception, the copyright holders of this library give you permission to link this library with independent modules to produce an executable, regardless of the license terms of these independent modules, and to copy and distribute the resulting executable under terms of your choice, provided that you also meet, for each linked independent module, the terms and conditions of the license of that module. An independent module is a module which is not derived from or based on this library. If you modify this library, you may extend this exception to your version of the library, but you are not obligated to do so. If you do not wish to do so, delete this exception statement from your version. diff --git a/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt b/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt new file mode 100644 index 0000000000000..8d1c8b69c3fce --- /dev/null +++ b/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt @@ -0,0 +1 @@ + diff --git a/x-pack/plugin/inference/licenses/joda-time-LICENSE.txt b/x-pack/plugin/inference/licenses/joda-time-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/inference/licenses/joda-time-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt b/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt new file mode 100644 index 0000000000000..dffbcf31cacf6 --- /dev/null +++ b/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt @@ -0,0 +1,5 @@ +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index aa907a236884a..a7e5718a0920e 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -20,8 +20,13 @@ requires org.apache.lucene.join; requires com.ibm.icu; requires com.google.auth.oauth2; + requires com.google.auth; requires com.google.api.client; requires com.google.gson; + requires aws.java.sdk.bedrockruntime; + requires aws.java.sdk.core; + requires com.fasterxml.jackson.databind; + requires org.joda.time; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f3799b824fc0e..f8ce9df1fb194 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -24,6 +24,10 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; @@ -122,10 +126,46 @@ public static List getNamedWriteables() { addMistralNamedWriteables(namedWriteables); addCustomElandWriteables(namedWriteables); addAnthropicNamedWritables(namedWriteables); + addAmazonBedrockNamedWriteables(namedWriteables); return namedWriteables; } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + AmazonBedrockSecretSettings.class, + AmazonBedrockSecretSettings.NAME, + AmazonBedrockSecretSettings::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AmazonBedrockEmbeddingsServiceSettings.NAME, + AmazonBedrockEmbeddingsServiceSettings::new + ) + ); + + // no task settings for Amazon Bedrock Embeddings + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AmazonBedrockChatCompletionServiceSettings.NAME, + AmazonBedrockChatCompletionServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AmazonBedrockChatCompletionTaskSettings.NAME, + AmazonBedrockChatCompletionTaskSettings::new + ) + ); + } + private static void addMistralNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1db5b4135ee94..1c388f7399260 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; @@ -70,6 +71,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService; import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; @@ -117,6 +119,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final Settings settings; private final SetOnce httpFactory = new SetOnce<>(); + private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); @@ -170,6 +173,9 @@ public Collection createComponents(PluginServices services) { var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService()); httpFactory.set(httpRequestSenderFactory); + var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); + amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { @@ -209,6 +215,7 @@ public List getInferenceServiceFactories() { context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), context -> new MistralService(httpFactory.get(), serviceComponents.get()), context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java new file mode 100644 index 0000000000000..5f9fc532e33b2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +public class AmazonBedrockActionCreator implements AmazonBedrockActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + private final TimeValue timeout; + + public AmazonBedrockActionCreator(Sender sender, ServiceComponents serviceComponents, @Nullable TimeValue timeout) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.timeout = timeout; + } + + @Override + public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { + var overriddenModel = AmazonBedrockEmbeddingsModel.of(embeddingsModel, taskSettings); + var requestManager = new AmazonBedrockEmbeddingsRequestManager( + overriddenModel, + serviceComponents.truncator(), + serviceComponents.threadPool(), + timeout + ); + var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock embeddings"); + return new AmazonBedrockEmbeddingsAction(sender, requestManager, errorMessage); + } + + @Override + public ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map taskSettings) { + var overriddenModel = AmazonBedrockChatCompletionModel.of(completionModel, taskSettings); + var requestManager = new AmazonBedrockChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool(), timeout); + var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock completion"); + return new AmazonBedrockChatCompletionAction(sender, requestManager, errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java new file mode 100644 index 0000000000000..b540d030eb3f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Map; + +public interface AmazonBedrockActionVisitor { + ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings); + + ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java new file mode 100644 index 0000000000000..9d3c39d3ac4d9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AmazonBedrockChatCompletionAction implements ExecutableAction { + private final Sender sender; + private final RequestManager requestManager; + private final String errorMessage; + + public AmazonBedrockChatCompletionAction(Sender sender, RequestManager requestManager, String errorMessage) { + this.sender = Objects.requireNonNull(sender); + this.requestManager = Objects.requireNonNull(requestManager); + this.errorMessage = Objects.requireNonNull(errorMessage); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java new file mode 100644 index 0000000000000..3f8be0c3cccbe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AmazonBedrockEmbeddingsAction implements ExecutableAction { + + private final Sender sender; + private final RequestManager requestManager; + private final String errorMessage; + + public AmazonBedrockEmbeddingsAction(Sender sender, RequestManager requestManager, String errorMessage) { + this.sender = Objects.requireNonNull(sender); + this.requestManager = Objects.requireNonNull(requestManager); + this.errorMessage = Objects.requireNonNull(errorMessage); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java new file mode 100644 index 0000000000000..f9e403582a0ec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.time.Clock; +import java.util.Objects; + +public abstract class AmazonBedrockBaseClient implements AmazonBedrockClient { + protected final Integer modelKeysAndRegionHashcode; + protected Clock clock = Clock.systemUTC(); + + protected AmazonBedrockBaseClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + Objects.requireNonNull(model); + this.modelKeysAndRegionHashcode = getModelKeysAndRegionHashcode(model, timeout); + } + + public static Integer getModelKeysAndRegionHashcode(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var secretSettings = model.getSecretSettings(); + var serviceSettings = model.getServiceSettings(); + return Objects.hash(secretSettings.accessKey, secretSettings.secretKey, serviceSettings.region(), timeout); + } + + public final void setClock(Clock clock) { + this.clock = clock; + } + + abstract void close(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java new file mode 100644 index 0000000000000..a4e0c399517c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener; + +import java.util.function.Supplier; + +public class AmazonBedrockChatCompletionExecutor extends AmazonBedrockExecutor { + private final AmazonBedrockChatCompletionRequest chatCompletionRequest; + + protected AmazonBedrockChatCompletionExecutor( + AmazonBedrockChatCompletionRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache); + this.chatCompletionRequest = request; + } + + @Override + protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { + var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener( + chatCompletionRequest, + responseHandler, + inferenceResultsListener + ); + chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java new file mode 100644 index 0000000000000..812e76129c420 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; + +import java.time.Instant; + +public interface AmazonBedrockClient { + void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; + + void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + throws ElasticsearchException; + + boolean isExpired(Instant currentTimestampMs); + + void resetExpiration(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java new file mode 100644 index 0000000000000..e6bb99620b581 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.Closeable; +import java.io.IOException; + +public interface AmazonBedrockClientCache extends Closeable { + AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) throws IOException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java new file mode 100644 index 0000000000000..6da3f86e0909a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseListener; + +import java.util.function.Supplier; + +public class AmazonBedrockEmbeddingsExecutor extends AmazonBedrockExecutor { + + private final AmazonBedrockEmbeddingsRequest embeddingsRequest; + + protected AmazonBedrockEmbeddingsExecutor( + AmazonBedrockEmbeddingsRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache); + this.embeddingsRequest = request; + } + + @Override + protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { + var embeddingsResponseListener = new AmazonBedrockEmbeddingsResponseListener( + embeddingsRequest, + responseHandler, + inferenceResultsListener + ); + embeddingsRequest.executeEmbeddingsRequest(awsBedrockClient, embeddingsResponseListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java new file mode 100644 index 0000000000000..a08acab655936 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import java.io.IOException; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.core.Strings.format; + +/** + * The AWS SDK uses its own internal retrier and timeout values on the client + */ +public class AmazonBedrockExecuteOnlyRequestSender implements RequestSender { + + protected final AmazonBedrockClientCache clientCache; + private final ThrottlerManager throttleManager; + + public AmazonBedrockExecuteOnlyRequestSender(AmazonBedrockClientCache clientCache, ThrottlerManager throttlerManager) { + this.clientCache = Objects.requireNonNull(clientCache); + this.throttleManager = Objects.requireNonNull(throttlerManager); + } + + @Override + public void send( + Logger logger, + Request request, + HttpClientContext context, + Supplier hasRequestTimedOutFunction, + ResponseHandler responseHandler, + ActionListener listener + ) { + if (request instanceof AmazonBedrockRequest awsRequest && responseHandler instanceof AmazonBedrockResponseHandler awsResponse) { + try { + var executor = createExecutor(awsRequest, awsResponse, logger, hasRequestTimedOutFunction, listener); + + // the run method will call the listener to return the proper value + executor.run(); + return; + } catch (Exception e) { + logException(logger, request, e); + listener.onFailure(wrapWithElasticsearchException(e, request.getInferenceEntityId())); + } + } + + listener.onFailure(new ElasticsearchException("Amazon Bedrock request was not the correct type")); + } + + // allow this to be overridden for testing + protected AmazonBedrockExecutor createExecutor( + AmazonBedrockRequest awsRequest, + AmazonBedrockResponseHandler awsResponse, + Logger logger, + Supplier hasRequestTimedOutFunction, + ActionListener listener + ) { + switch (awsRequest.taskType()) { + case COMPLETION -> { + return new AmazonBedrockChatCompletionExecutor( + (AmazonBedrockChatCompletionRequest) awsRequest, + awsResponse, + logger, + hasRequestTimedOutFunction, + listener, + clientCache + ); + } + case TEXT_EMBEDDING -> { + return new AmazonBedrockEmbeddingsExecutor( + (AmazonBedrockEmbeddingsRequest) awsRequest, + awsResponse, + logger, + hasRequestTimedOutFunction, + listener, + clientCache + ); + } + default -> { + throw new UnsupportedOperationException("Unsupported task type [" + awsRequest.taskType() + "] for Amazon Bedrock request"); + } + } + } + + private void logException(Logger logger, Request request, Exception exception) { + var causeException = ExceptionsHelper.unwrapCause(exception); + + throttleManager.warn( + logger, + format("Failed while sending request from inference entity id [%s] of type [amazonbedrock]", request.getInferenceEntityId()), + causeException + ); + } + + private Exception wrapWithElasticsearchException(Exception e, String inferenceEntityId) { + return new ElasticsearchException( + format("Amazon Bedrock client failed to send request from inference entity id [%s]", inferenceEntityId), + e + ); + } + + public void shutdown() throws IOException { + this.clientCache.close(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java new file mode 100644 index 0000000000000..fa220ee5d2831 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public abstract class AmazonBedrockExecutor implements Runnable { + protected final AmazonBedrockModel baseModel; + protected final AmazonBedrockResponseHandler responseHandler; + protected final Logger logger; + protected final AmazonBedrockRequest request; + protected final Supplier hasRequestCompletedFunction; + protected final ActionListener inferenceResultsListener; + protected final AmazonBedrockClientCache clientCache; + + protected AmazonBedrockExecutor( + AmazonBedrockRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + this.request = Objects.requireNonNull(request); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.logger = Objects.requireNonNull(logger); + this.hasRequestCompletedFunction = Objects.requireNonNull(hasRequestCompletedFunction); + this.inferenceResultsListener = Objects.requireNonNull(inferenceResultsListener); + this.clientCache = Objects.requireNonNull(clientCache); + this.baseModel = request.model(); + } + + @Override + public void run() { + if (hasRequestCompletedFunction.get()) { + // has already been run + return; + } + + var inferenceEntityId = baseModel.getInferenceEntityId(); + + try { + var awsBedrockClient = clientCache.getOrCreateClient(baseModel, request.timeout()); + executeClientRequest(awsBedrockClient); + } catch (Exception e) { + var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); + logger.warn(errorMessage, e); + inferenceResultsListener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } + + protected abstract void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java new file mode 100644 index 0000000000000..c3d458925268c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsyncClientBuilder; +import com.amazonaws.services.bedrockruntime.model.AmazonBedrockRuntimeException; +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.SpecialPermission; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.time.Duration; +import java.time.Instant; +import java.util.Objects; + +/** + * Not marking this as "final" so we can subclass it for mocking + */ +public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient { + + // package-private for testing + static final int CLIENT_CACHE_EXPIRY_MINUTES = 5; + private static final int DEFAULT_CLIENT_TIMEOUT_MS = 10000; + + private final AmazonBedrockRuntimeAsync internalClient; + private volatile Instant expiryTimestamp; + + public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + try { + return new AmazonBedrockInferenceClient(model, timeout); + } catch (Exception e) { + throw new ElasticsearchException("Failed to create Amazon Bedrock Client", e); + } + } + + protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + super(model, timeout); + this.internalClient = createAmazonBedrockClient(model, timeout); + setExpiryTimestamp(); + } + + @Override + public void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException { + try { + var responseFuture = internalClient.converseAsync(converseRequest); + responseListener.onResponse(responseFuture.get()); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + responseListener.onFailure( + new ElasticsearchException( + Strings.format("AmazonBedrock converse failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + } catch (ElasticsearchException elasticsearchException) { + // just throw the exception if we have one + responseListener.onFailure(elasticsearchException); + } catch (Exception e) { + responseListener.onFailure(new ElasticsearchException("Amazon Bedrock client converse call failed", e)); + } + } + + @Override + public void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + throws ElasticsearchException { + try { + var responseFuture = internalClient.invokeModelAsync(invokeModelRequest); + responseListener.onResponse(responseFuture.get()); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + responseListener.onFailure( + new ElasticsearchException( + Strings.format("AmazonBedrock invoke model failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + } catch (ElasticsearchException elasticsearchException) { + // just throw the exception if we have one + responseListener.onFailure(elasticsearchException); + } catch (Exception e) { + responseListener.onFailure(new ElasticsearchException(e)); + } + } + + // allow this to be overridden for test mocks + protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var secretSettings = model.getSecretSettings(); + var credentials = new BasicAWSCredentials(secretSettings.accessKey.toString(), secretSettings.secretKey.toString()); + var credentialsProvider = new AWSStaticCredentialsProvider(credentials); + var clientConfig = timeout == null + ? new ClientConfiguration().withConnectionTimeout(DEFAULT_CLIENT_TIMEOUT_MS) + : new ClientConfiguration().withConnectionTimeout((int) timeout.millis()); + + var serviceSettings = model.getServiceSettings(); + + try { + SpecialPermission.check(); + AmazonBedrockRuntimeAsyncClientBuilder builder = AccessController.doPrivileged( + (PrivilegedExceptionAction) () -> AmazonBedrockRuntimeAsyncClientBuilder.standard() + .withCredentials(credentialsProvider) + .withRegion(serviceSettings.region()) + .withClientConfiguration(clientConfig) + ); + + return SocketAccess.doPrivileged(builder::build); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + throw new ElasticsearchException( + Strings.format("failed to create AmazonBedrockRuntime client: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ); + } catch (Exception e) { + throw new ElasticsearchException("failed to create AmazonBedrockRuntime client", e); + } + } + + private void setExpiryTimestamp() { + this.expiryTimestamp = clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES)); + } + + @Override + public boolean isExpired(Instant currentTimestampMs) { + Objects.requireNonNull(currentTimestampMs); + return currentTimestampMs.isAfter(expiryTimestamp); + } + + public void resetExpiration() { + setExpiryTimestamp(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockInferenceClient that = (AmazonBedrockInferenceClient) o; + return Objects.equals(modelKeysAndRegionHashcode, that.modelKeysAndRegionHashcode); + } + + @Override + public int hashCode() { + return this.modelKeysAndRegionHashcode; + } + + // make this package-private so only the cache can close it + @Override + void close() { + internalClient.shutdown(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java new file mode 100644 index 0000000000000..e245365c214af --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.http.IdleConnectionReaper; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.IOException; +import java.time.Clock; +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.BiFunction; + +public final class AmazonBedrockInferenceClientCache implements AmazonBedrockClientCache { + + private final BiFunction creator; + private final Map clientsCache = new ConcurrentHashMap<>(); + private final ReentrantReadWriteLock cacheLock = new ReentrantReadWriteLock(); + + // not final for testing + private Clock clock; + + public AmazonBedrockInferenceClientCache( + BiFunction creator, + @Nullable Clock clock + ) { + this.creator = Objects.requireNonNull(creator); + this.clock = Objects.requireNonNullElse(clock, Clock.systemUTC()); + } + + public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var returnClient = internalGetOrCreateClient(model, timeout); + flushExpiredClients(); + return returnClient; + } + + private AmazonBedrockBaseClient internalGetOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + final Integer modelHash = AmazonBedrockInferenceClient.getModelKeysAndRegionHashcode(model, timeout); + cacheLock.readLock().lock(); + try { + return clientsCache.computeIfAbsent(modelHash, hashKey -> { + final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout); + builtClient.setClock(clock); + builtClient.resetExpiration(); + return builtClient; + }); + } finally { + cacheLock.readLock().unlock(); + } + } + + private void flushExpiredClients() { + var currentTimestampMs = clock.instant(); + var expiredClients = new ArrayList>(); + + cacheLock.readLock().lock(); + try { + for (final Map.Entry client : clientsCache.entrySet()) { + if (client.getValue().isExpired(currentTimestampMs)) { + expiredClients.add(client); + } + } + + if (expiredClients.isEmpty()) { + return; + } + + cacheLock.readLock().unlock(); + cacheLock.writeLock().lock(); + try { + for (final Map.Entry client : expiredClients) { + var removed = clientsCache.remove(client.getKey()); + if (removed != null) { + removed.close(); + } + } + } finally { + cacheLock.readLock().lock(); + cacheLock.writeLock().unlock(); + } + } finally { + cacheLock.readLock().unlock(); + } + } + + @Override + public void close() throws IOException { + releaseCachedClients(); + } + + private void releaseCachedClients() { + // as we're closing and flushing all of these - we'll use a write lock + // across the whole operation to ensure this stays in sync + cacheLock.writeLock().lock(); + try { + // ensure all the clients are closed before we clear + for (final AmazonBedrockBaseClient client : clientsCache.values()) { + client.close(); + } + + // clear previously cached clients, they will be build lazily + clientsCache.clear(); + } finally { + cacheLock.writeLock().unlock(); + } + + // shutdown IdleConnectionReaper background thread + // it will be restarted on new client usage + IdleConnectionReaper.shutdown(); + } + + // used for testing + int clientCount() { + cacheLock.readLock().lock(); + try { + return clientsCache.size(); + } finally { + cacheLock.readLock().unlock(); + } + } + + // used for testing + void setClock(Clock newClock) { + this.clock = Objects.requireNonNull(newClock); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java new file mode 100644 index 0000000000000..e23b0274ede26 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +public class AmazonBedrockRequestSender implements Sender { + + public static class Factory { + private final ServiceComponents serviceComponents; + private final ClusterService clusterService; + + public Factory(ServiceComponents serviceComponents, ClusterService clusterService) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(clusterService); + } + + public Sender createSender() { + var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockInferenceClient::create, null); + return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager())); + } + + Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) { + var sender = new AmazonBedrockRequestSender( + serviceComponents.threadPool(), + clusterService, + serviceComponents.settings(), + Objects.requireNonNull(requestSender) + ); + // ensure this is started + sender.start(); + return sender; + } + } + + private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5); + + private final ThreadPool threadPool; + private final AmazonBedrockRequestExecutorService executorService; + private final AtomicBoolean started = new AtomicBoolean(false); + private final CountDownLatch startCompleted = new CountDownLatch(1); + + protected AmazonBedrockRequestSender( + ThreadPool threadPool, + ClusterService clusterService, + Settings settings, + AmazonBedrockExecuteOnlyRequestSender requestSender + ) { + this.threadPool = Objects.requireNonNull(threadPool); + executorService = new AmazonBedrockRequestExecutorService( + threadPool, + startCompleted, + new RequestExecutorServiceSettings(settings, clusterService), + requestSender + ); + } + + @Override + public void start() { + if (started.compareAndSet(false, true)) { + // The manager must be started before the executor service. That way we guarantee that the http client + // is ready prior to the service attempting to use the http client to send a request + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(executorService::start); + waitForStartToComplete(); + } + } + + private void waitForStartToComplete() { + try { + if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("Amazon Bedrock sender startup did not complete in time"); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Amazon Bedrock sender interrupted while waiting for startup to complete"); + } + } + + @Override + public void send( + RequestManager requestCreator, + InferenceInputs inferenceInputs, + TimeValue timeout, + ActionListener listener + ) { + assert started.get() : "Amazon Bedrock request sender: call start() before sending a request"; + waitForStartToComplete(); + + if (requestCreator instanceof AmazonBedrockRequestManager amazonBedrockRequestManager) { + executorService.execute(amazonBedrockRequestManager, inferenceInputs, timeout, listener); + return; + } + + listener.onFailure(new ElasticsearchException("Amazon Bedrock request sender did not receive a valid request request manager")); + } + + @Override + public void close() throws IOException { + executorService.shutdown(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java new file mode 100644 index 0000000000000..1d8226664979c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionEntityFactory; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.util.List; +import java.util.function.Supplier; + +public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager { + private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionRequestManager.class); + private final AmazonBedrockChatCompletionModel model; + + public AmazonBedrockChatCompletionRequestManager( + AmazonBedrockChatCompletionModel model, + ThreadPool threadPool, + @Nullable TimeValue timeout + ) { + super(model, threadPool, timeout); + this.model = model; + } + + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, input); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + try { + requestSender.send(logger, request, HttpClientContext.create(), hasRequestCompletedFunction, responseHandler, listener); + } catch (Exception e) { + var errorMessage = Strings.format( + "Failed to send [completion] request from inference entity id [%s]", + request.getInferenceEntityId() + ); + logger.warn(errorMessage, e); + listener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..e9bc6b574865c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsEntityFactory; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class AmazonBedrockEmbeddingsRequestManager extends AmazonBedrockRequestManager { + private static final Logger logger = LogManager.getLogger(AmazonBedrockEmbeddingsRequestManager.class); + + private final AmazonBedrockEmbeddingsModel embeddingsModel; + private final Truncator truncator; + + public AmazonBedrockEmbeddingsRequestManager( + AmazonBedrockEmbeddingsModel model, + Truncator truncator, + ThreadPool threadPool, + @Nullable TimeValue timeout + ) { + super(model, threadPool, timeout); + this.embeddingsModel = model; + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public void execute( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var truncatedInput = truncate(input, serviceSettings.maxInputTokens()); + var requestEntity = AmazonBedrockEmbeddingsEntityFactory.createEntity(embeddingsModel, truncatedInput); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); + try { + requestSender.send(logger, request, HttpClientContext.create(), hasRequestCompletedFunction, responseHandler, listener); + } catch (Exception e) { + var errorMessage = Strings.format( + "Failed to send [text_embedding] request from inference entity id [%s]", + request.getInferenceEntityId() + ); + logger.warn(errorMessage, e); + listener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java new file mode 100644 index 0000000000000..8b4672d45c250 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecuteOnlyRequestSender; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; + +/** + * Allows this to have a public interface for Amazon Bedrock support + */ +public class AmazonBedrockRequestExecutorService extends RequestExecutorService { + + private final AmazonBedrockExecuteOnlyRequestSender requestSender; + + public AmazonBedrockRequestExecutorService( + ThreadPool threadPool, + CountDownLatch startupLatch, + RequestExecutorServiceSettings settings, + AmazonBedrockExecuteOnlyRequestSender requestSender + ) { + super(threadPool, startupLatch, settings, requestSender); + this.requestSender = requestSender; + } + + @Override + public void shutdown() { + super.shutdown(); + try { + requestSender.shutdown(); + } catch (IOException e) { + // swallow the exception + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java new file mode 100644 index 0000000000000..f75343b038368 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Objects; + +public abstract class AmazonBedrockRequestManager implements RequestManager { + + protected final ThreadPool threadPool; + protected final TimeValue timeout; + private final AmazonBedrockModel baseModel; + + protected AmazonBedrockRequestManager(AmazonBedrockModel baseModel, ThreadPool threadPool, @Nullable TimeValue timeout) { + this.baseModel = Objects.requireNonNull(baseModel); + this.threadPool = Objects.requireNonNull(threadPool); + this.timeout = timeout; + } + + @Override + public String inferenceEntityId() { + return baseModel.getInferenceEntityId(); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return baseModel.rateLimitSettings(); + } + + record RateLimitGrouping(int keyHash) { + public static AmazonBedrockRequestManager.RateLimitGrouping of(AmazonBedrockModel model) { + Objects.requireNonNull(model); + + var awsSecretSettings = model.getSecretSettings(); + + return new RateLimitGrouping(Objects.hash(awsSecretSettings.accessKey, awsSecretSettings.secretKey)); + } + } + + @Override + public Object rateLimitGrouping() { + return RateLimitGrouping.of(this.baseModel); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java new file mode 100644 index 0000000000000..829e899beba5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; + +public class AmazonBedrockJsonBuilder { + + private final ToXContent jsonWriter; + + public AmazonBedrockJsonBuilder(ToXContent jsonWriter) { + this.jsonWriter = jsonWriter; + } + + public String getStringContent() throws IOException { + try (var builder = jsonBuilder()) { + return Strings.toString(jsonWriter.toXContent(builder, null)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java new file mode 100644 index 0000000000000..83ebcb4563a8c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import com.fasterxml.jackson.core.JsonGenerator; + +import java.io.IOException; + +/** + * This is needed as the input for the Amazon Bedrock SDK does not like + * the formatting of XContent JSON output + */ +public interface AmazonBedrockJsonWriter { + JsonGenerator writeJson(JsonGenerator generator) throws IOException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java new file mode 100644 index 0000000000000..e356212ed07fb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.net.URI; + +public abstract class AmazonBedrockRequest implements Request { + + protected final AmazonBedrockModel amazonBedrockModel; + protected final String inferenceId; + protected final TimeValue timeout; + + protected AmazonBedrockRequest(AmazonBedrockModel model, @Nullable TimeValue timeout) { + this.amazonBedrockModel = model; + this.inferenceId = model.getInferenceEntityId(); + this.timeout = timeout; + } + + protected abstract void executeRequest(AmazonBedrockBaseClient client); + + public AmazonBedrockModel model() { + return amazonBedrockModel; + } + + /** + * Amazon Bedrock uses the AWS SDK, and will not create its own Http Request + * But, this is needed for the ExecutableInferenceRequest to get the inferenceEntityId + * @return NoOp request + */ + @Override + public final HttpRequest createHttpRequest() { + return new HttpRequest(new NoOpHttpRequest(), inferenceId); + } + + /** + * Amazon Bedrock uses the AWS SDK, and will not create its own URI + * @return null + */ + @Override + public final URI getURI() { + throw new UnsupportedOperationException(); + } + + /** + * Should be overridden for text embeddings requests + * @return null + */ + @Override + public Request truncate() { + return this; + } + + /** + * Should be overridden for text embeddings requests + * @return boolean[0] + */ + @Override + public boolean[] getTruncationInfo() { + return new boolean[0]; + } + + @Override + public String getInferenceEntityId() { + return amazonBedrockModel.getInferenceEntityId(); + } + + public TimeValue timeout() { + return timeout; + } + + public abstract TaskType taskType(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java new file mode 100644 index 0000000000000..7087bb03bca5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.apache.http.client.methods.HttpRequestBase; + +/** + * Needed for compatibility with RequestSender + */ +public class NoOpHttpRequest extends HttpRequestBase { + @Override + public String getMethod() { + return "NOOP"; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java new file mode 100644 index 0000000000000..6e2f2f6702005 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockAI21LabsCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockAI21LabsCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java new file mode 100644 index 0000000000000..a8b0032af09c5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockAnthropicCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockAnthropicCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java new file mode 100644 index 0000000000000..f86d2229d42ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.util.List; +import java.util.Objects; + +public final class AmazonBedrockChatCompletionEntityFactory { + public static AmazonBedrockConverseRequestEntity createEntity(AmazonBedrockChatCompletionModel model, List messages) { + Objects.requireNonNull(model); + Objects.requireNonNull(messages); + var serviceSettings = model.getServiceSettings(); + var taskSettings = model.getTaskSettings(); + switch (serviceSettings.provider()) { + case AI21LABS -> { + return new AmazonBedrockAI21LabsCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case AMAZONTITAN -> { + return new AmazonBedrockTitanCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case ANTHROPIC -> { + return new AmazonBedrockAnthropicCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + case COHERE -> { + return new AmazonBedrockCohereCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + case META -> { + return new AmazonBedrockMetaCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case MISTRAL -> { + return new AmazonBedrockMistralCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + default -> { + return null; + } + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java new file mode 100644 index 0000000000000..f02f05f2d3b17 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class AmazonBedrockChatCompletionRequest extends AmazonBedrockRequest { + public static final String USER_ROLE = "user"; + private final AmazonBedrockConverseRequestEntity requestEntity; + private AmazonBedrockChatCompletionResponseListener listener; + + public AmazonBedrockChatCompletionRequest( + AmazonBedrockChatCompletionModel model, + AmazonBedrockConverseRequestEntity requestEntity, + @Nullable TimeValue timeout + ) { + super(model, timeout); + this.requestEntity = Objects.requireNonNull(requestEntity); + } + + @Override + protected void executeRequest(AmazonBedrockBaseClient client) { + var converseRequest = getConverseRequest(); + + try { + SocketAccess.doPrivileged(() -> client.converse(converseRequest, listener)); + } catch (IOException e) { + listener.onFailure(new RuntimeException(e)); + } + } + + @Override + public TaskType taskType() { + return TaskType.COMPLETION; + } + + private ConverseRequest getConverseRequest() { + var converseRequest = new ConverseRequest().withModelId(amazonBedrockModel.model()); + converseRequest = requestEntity.addMessages(converseRequest); + converseRequest = requestEntity.addInferenceConfig(converseRequest); + converseRequest = requestEntity.addAdditionalModelFields(converseRequest); + return converseRequest; + } + + public void executeChatCompletionRequest( + AmazonBedrockBaseClient awsBedrockClient, + AmazonBedrockChatCompletionResponseListener chatCompletionResponseListener + ) { + this.listener = chatCompletionResponseListener; + this.executeRequest(awsBedrockClient); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java new file mode 100644 index 0000000000000..17a264ef820ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockCohereCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockCohereCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java new file mode 100644 index 0000000000000..fbd55e76e509b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; + +public interface AmazonBedrockConverseRequestEntity { + ConverseRequest addMessages(ConverseRequest request); + + ConverseRequest addInferenceConfig(ConverseRequest request); + + ConverseRequest addAdditionalModelFields(ConverseRequest request); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java new file mode 100644 index 0000000000000..2cfb56a94b319 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.Message; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest.USER_ROLE; + +public final class AmazonBedrockConverseUtils { + + public static List getConverseMessageList(List messages) { + List messageList = new ArrayList<>(); + for (String message : messages) { + var messageContent = new ContentBlock().withText(message); + var returnMessage = (new Message()).withRole(USER_ROLE).withContent(messageContent); + messageList.add(returnMessage); + } + return messageList; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java new file mode 100644 index 0000000000000..cdabdd4cbebff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockMetaCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockMetaCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java new file mode 100644 index 0000000000000..c68eaa1b81f54 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockMistralCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockMistralCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java new file mode 100644 index 0000000000000..d56035b80e9ef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockTitanCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockTitanCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..edca5bc1bdf9c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AmazonBedrockCohereEmbeddingsRequestEntity(List input) implements ToXContentObject { + + private static final String TEXTS_FIELD = "texts"; + private static final String INPUT_TYPE_FIELD = "input_type"; + private static final String INPUT_TYPE_SEARCH_DOCUMENT = "search_document"; + + public AmazonBedrockCohereEmbeddingsRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + builder.field(INPUT_TYPE_FIELD, INPUT_TYPE_SEARCH_DOCUMENT); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java new file mode 100644 index 0000000000000..a31b033507264 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Objects; + +public final class AmazonBedrockEmbeddingsEntityFactory { + public static ToXContent createEntity(AmazonBedrockEmbeddingsModel model, Truncator.TruncationResult truncationResult) { + Objects.requireNonNull(model); + Objects.requireNonNull(truncationResult); + + var serviceSettings = model.getServiceSettings(); + + var truncatedInput = truncationResult.input(); + if (truncatedInput == null || truncatedInput.isEmpty()) { + throw new ElasticsearchException("[input] cannot be null or empty"); + } + + switch (serviceSettings.provider()) { + case AMAZONTITAN -> { + if (truncatedInput.size() > 1) { + throw new ElasticsearchException("[input] cannot contain more than one string"); + } + return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); + } + case COHERE -> { + return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput); + } + default -> { + return null; + } + } + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java new file mode 100644 index 0000000000000..96d3b3a3cc057 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseListener; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest { + private final AmazonBedrockEmbeddingsModel embeddingsModel; + private final ToXContent requestEntity; + private final Truncator truncator; + private final Truncator.TruncationResult truncationResult; + private final AmazonBedrockProvider provider; + private ActionListener listener = null; + + public AmazonBedrockEmbeddingsRequest( + Truncator truncator, + Truncator.TruncationResult input, + AmazonBedrockEmbeddingsModel model, + ToXContent requestEntity, + @Nullable TimeValue timeout + ) { + super(model, timeout); + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + this.requestEntity = Objects.requireNonNull(requestEntity); + this.embeddingsModel = model; + this.provider = model.provider(); + } + + public AmazonBedrockProvider provider() { + return provider; + } + + @Override + protected void executeRequest(AmazonBedrockBaseClient client) { + try { + var jsonBuilder = new AmazonBedrockJsonBuilder(requestEntity); + var bodyAsString = jsonBuilder.getStringContent(); + + var charset = StandardCharsets.UTF_8; + var bodyBuffer = charset.encode(bodyAsString); + + var invokeModelRequest = new InvokeModelRequest().withModelId(embeddingsModel.model()).withBody(bodyBuffer); + + SocketAccess.doPrivileged(() -> client.invokeModel(invokeModelRequest, listener)); + } catch (IOException e) { + listener.onFailure(new RuntimeException(e)); + } + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + public void executeEmbeddingsRequest( + AmazonBedrockBaseClient awsBedrockClient, + AmazonBedrockEmbeddingsResponseListener embeddingsResponseListener + ) { + this.listener = embeddingsResponseListener; + this.executeRequest(awsBedrockClient); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..f55edd0442913 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText) implements ToXContentObject { + + private static final String INPUT_TEXT_FIELD = "inputText"; + + public AmazonBedrockTitanEmbeddingsRequestEntity { + Objects.requireNonNull(inputText); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_TEXT_FIELD, inputText); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java new file mode 100644 index 0000000000000..54b05137acda3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; + +public abstract class AmazonBedrockResponse { + public abstract InferenceServiceResults accept(AmazonBedrockRequest request); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java new file mode 100644 index 0000000000000..9dc15ea667c1d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +public abstract class AmazonBedrockResponseHandler implements ResponseHandler { + @Override + public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + // do nothing as the AWS SDK will take care of validation for us + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java new file mode 100644 index 0000000000000..ce4d6d1dea655 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; + +import java.util.Objects; + +public class AmazonBedrockResponseListener { + protected final AmazonBedrockRequest request; + protected final ActionListener inferenceResultsListener; + protected final AmazonBedrockResponseHandler responseHandler; + + public AmazonBedrockResponseListener( + AmazonBedrockRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + this.request = Objects.requireNonNull(request); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.inferenceResultsListener = Objects.requireNonNull(inferenceResultsListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java new file mode 100644 index 0000000000000..5b3872e2c416a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponse; + +import java.util.ArrayList; + +public class AmazonBedrockChatCompletionResponse extends AmazonBedrockResponse { + + private final ConverseResult result; + + public AmazonBedrockChatCompletionResponse(ConverseResult responseResult) { + this.result = responseResult; + } + + @Override + public InferenceServiceResults accept(AmazonBedrockRequest request) { + if (request instanceof AmazonBedrockChatCompletionRequest asChatCompletionRequest) { + return fromResponse(result); + } + + throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); + } + + public static ChatCompletionResults fromResponse(ConverseResult response) { + var responseMessage = response.getOutput().getMessage(); + + var messageContents = responseMessage.getContent(); + var resultTexts = new ArrayList(); + for (var messageContent : messageContents) { + resultTexts.add(new ChatCompletionResults.Result(messageContent.getText())); + } + + return new ChatCompletionResults(resultTexts); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..a24f54c50eef3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; + +public class AmazonBedrockChatCompletionResponseHandler extends AmazonBedrockResponseHandler { + + private ConverseResult responseResult; + + public AmazonBedrockChatCompletionResponseHandler() {} + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + var response = new AmazonBedrockChatCompletionResponse(responseResult); + return response.accept((AmazonBedrockRequest) request); + } + + @Override + public String getRequestType() { + return "Amazon Bedrock Chat Completion"; + } + + public void acceptChatCompletionResponseObject(ConverseResult response) { + this.responseResult = response; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java new file mode 100644 index 0000000000000..be03ba84571eb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; + +public class AmazonBedrockChatCompletionResponseListener extends AmazonBedrockResponseListener implements ActionListener { + + public AmazonBedrockChatCompletionResponseListener( + AmazonBedrockChatCompletionRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + super(request, responseHandler, inferenceResultsListener); + } + + @Override + public void onResponse(ConverseResult result) { + ((AmazonBedrockChatCompletionResponseHandler) responseHandler).acceptChatCompletionResponseObject(result); + inferenceResultsListener.onResponse(responseHandler.parseResult(request, null)); + } + + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException(e); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java new file mode 100644 index 0000000000000..83fa790acbe68 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.XContentUtils; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponse; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AmazonBedrockEmbeddingsResponse extends AmazonBedrockResponse { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Amazon Bedrock embeddings response"; + private final InvokeModelResult result; + + public AmazonBedrockEmbeddingsResponse(InvokeModelResult invokeModelResult) { + this.result = invokeModelResult; + } + + @Override + public InferenceServiceResults accept(AmazonBedrockRequest request) { + if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest) { + return fromResponse(result, asEmbeddingsRequest.provider()); + } + + throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); + } + + public static InferenceTextEmbeddingFloatResults fromResponse(InvokeModelResult response, AmazonBedrockProvider provider) { + var charset = StandardCharsets.UTF_8; + var bodyText = String.valueOf(charset.decode(response.getBody())); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, bodyText)) { + // move to the first token + jsonParser.nextToken(); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + var embeddingList = parseEmbeddings(jsonParser, provider); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + } catch (IOException e) { + throw new ElasticsearchException(e); + } + } + + private static List parseEmbeddings( + XContentParser jsonParser, + AmazonBedrockProvider provider + ) throws IOException { + switch (provider) { + case AMAZONTITAN -> { + return parseTitanEmbeddings(jsonParser); + } + case COHERE -> { + return parseCohereEmbeddings(jsonParser); + } + default -> throw new IOException("Unsupported provider [" + provider + "]"); + } + } + + private static List parseTitanEmbeddings(XContentParser parser) + throws IOException { + /* + Titan response: + { + "embedding": [float, float, ...], + "inputTextTokenCount": int + } + */ + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + var embeddingValues = InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + return List.of(embeddingValues); + } + + private static List parseCohereEmbeddings(XContentParser parser) + throws IOException { + /* + Cohere response: + { + "embeddings": [ + [< array of 1024 floats >], + ... + ], + "id": string, + "response_type" : "embeddings_floats", + "texts": [string] + } + */ + positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = parseList( + parser, + AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem + ); + + return embeddingList; + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseCohereEmbeddingsListItem(XContentParser parser) + throws IOException { + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..a3fb68ee23486 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; + +public class AmazonBedrockEmbeddingsResponseHandler extends AmazonBedrockResponseHandler { + + private InvokeModelResult invokeModelResult; + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + var responseParser = new AmazonBedrockEmbeddingsResponse(invokeModelResult); + return responseParser.accept((AmazonBedrockRequest) request); + } + + @Override + public String getRequestType() { + return "Amazon Bedrock Embeddings"; + } + + public void acceptEmbeddingsResult(InvokeModelResult result) { + this.invokeModelResult = result; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java new file mode 100644 index 0000000000000..36519ae31ff60 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; + +public class AmazonBedrockEmbeddingsResponseListener extends AmazonBedrockResponseListener implements ActionListener { + + public AmazonBedrockEmbeddingsResponseListener( + AmazonBedrockEmbeddingsRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + super(request, responseHandler, inferenceResultsListener); + } + + @Override + public void onResponse(InvokeModelResult result) { + ((AmazonBedrockEmbeddingsResponseHandler) responseHandler).acceptEmbeddingsResult(result); + inferenceResultsListener.onResponse(responseHandler.parseResult(request, null)); + } + + @Override + public void onFailure(Exception e) { + inferenceResultsListener.onFailure(e); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java new file mode 100644 index 0000000000000..1755dac2ac13f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +public class AmazonBedrockConstants { + public static final String ACCESS_KEY_FIELD = "access_key"; + public static final String SECRET_KEY_FIELD = "secret_key"; + public static final String REGION_FIELD = "region"; + public static final String MODEL_FIELD = "model"; + public static final String PROVIDER_FIELD = "provider"; + + public static final String TEMPERATURE_FIELD = "temperature"; + public static final String TOP_P_FIELD = "top_p"; + public static final String TOP_K_FIELD = "top_k"; + public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens"; + + public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0; + public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0; + + public static final int DEFAULT_MAX_CHUNK_SIZE = 2048; + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java new file mode 100644 index 0000000000000..13ca8bd7bd749 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; + +public abstract class AmazonBedrockModel extends Model { + + protected String region; + protected String model; + protected AmazonBedrockProvider provider; + protected RateLimitSettings rateLimitSettings; + + protected AmazonBedrockModel(ModelConfigurations modelConfigurations, ModelSecrets secrets) { + super(modelConfigurations, secrets); + setPropertiesFromServiceSettings((AmazonBedrockServiceSettings) modelConfigurations.getServiceSettings()); + } + + protected AmazonBedrockModel(Model model, TaskSettings taskSettings) { + super(model, taskSettings); + + if (model instanceof AmazonBedrockModel bedrockModel) { + setPropertiesFromServiceSettings(bedrockModel.getServiceSettings()); + } + } + + protected AmazonBedrockModel(Model model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + if (serviceSettings instanceof AmazonBedrockServiceSettings bedrockServiceSettings) { + setPropertiesFromServiceSettings(bedrockServiceSettings); + } + } + + protected AmazonBedrockModel(ModelConfigurations modelConfigurations) { + super(modelConfigurations); + setPropertiesFromServiceSettings((AmazonBedrockServiceSettings) modelConfigurations.getServiceSettings()); + } + + public String region() { + return region; + } + + public String model() { + return model; + } + + public AmazonBedrockProvider provider() { + return provider; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + private void setPropertiesFromServiceSettings(AmazonBedrockServiceSettings serviceSettings) { + this.region = serviceSettings.region(); + this.model = serviceSettings.model(); + this.provider = serviceSettings.provider(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + public abstract ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings); + + @Override + public AmazonBedrockServiceSettings getServiceSettings() { + return (AmazonBedrockServiceSettings) super.getServiceSettings(); + } + + @Override + public AmazonBedrockSecretSettings getSecretSettings() { + return (AmazonBedrockSecretSettings) super.getSecretSettings(); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java new file mode 100644 index 0000000000000..340a5a65f0969 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import java.util.Locale; + +public enum AmazonBedrockProvider { + AMAZONTITAN, + ANTHROPIC, + AI21LABS, + COHERE, + META, + MISTRAL; + + public static String NAME = "amazon_bedrock_provider"; + + public static AmazonBedrockProvider fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java new file mode 100644 index 0000000000000..28b10ef294bda --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.DEFAULT_MAX_CHUNK_SIZE; + +public final class AmazonBedrockProviderCapabilities { + private static final List embeddingProviders = List.of( + AmazonBedrockProvider.AMAZONTITAN, + AmazonBedrockProvider.COHERE + ); + + private static final List chatCompletionProviders = List.of( + AmazonBedrockProvider.AMAZONTITAN, + AmazonBedrockProvider.ANTHROPIC, + AmazonBedrockProvider.AI21LABS, + AmazonBedrockProvider.COHERE, + AmazonBedrockProvider.META, + AmazonBedrockProvider.MISTRAL + ); + + private static final List chatCompletionProvidersWithTopK = List.of( + AmazonBedrockProvider.ANTHROPIC, + AmazonBedrockProvider.COHERE, + AmazonBedrockProvider.MISTRAL + ); + + private static final Map embeddingsDefaultSimilarityMeasure = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + SimilarityMeasure.COSINE, + AmazonBedrockProvider.COHERE, + SimilarityMeasure.DOT_PRODUCT + ); + + private static final Map embeddingsDefaultChunkSize = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + 8192, + AmazonBedrockProvider.COHERE, + 2048 + ); + + private static final Map embeddingsMaxBatchSize = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + 1, + AmazonBedrockProvider.COHERE, + 96 + ); + + public static boolean providerAllowsTaskType(AmazonBedrockProvider provider, TaskType taskType) { + switch (taskType) { + case COMPLETION -> { + return chatCompletionProviders.contains(provider); + } + case TEXT_EMBEDDING -> { + return embeddingProviders.contains(provider); + } + default -> { + return false; + } + } + } + + public static boolean chatCompletionProviderHasTopKParameter(AmazonBedrockProvider provider) { + return chatCompletionProvidersWithTopK.contains(provider); + } + + public static SimilarityMeasure getProviderDefaultSimilarityMeasure(AmazonBedrockProvider provider) { + if (embeddingsDefaultSimilarityMeasure.containsKey(provider)) { + return embeddingsDefaultSimilarityMeasure.get(provider); + } + + return SimilarityMeasure.COSINE; + } + + public static int getEmbeddingsProviderDefaultChunkSize(AmazonBedrockProvider provider) { + if (embeddingsDefaultChunkSize.containsKey(provider)) { + return embeddingsDefaultChunkSize.get(provider); + } + + return DEFAULT_MAX_CHUNK_SIZE; + } + + public static int getEmbeddingsMaxBatchSize(AmazonBedrockProvider provider) { + if (embeddingsMaxBatchSize.containsKey(provider)) { + return embeddingsMaxBatchSize.get(provider); + } + + return 1; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java new file mode 100644 index 0000000000000..9e6328ce1c358 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.SECRET_KEY_FIELD; + +public class AmazonBedrockSecretSettings implements SecretSettings { + public static final String NAME = "amazon_bedrock_secret_settings"; + + public final SecureString accessKey; + public final SecureString secretKey; + + public static AmazonBedrockSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureAccessKey = extractRequiredSecureString( + map, + ACCESS_KEY_FIELD, + ModelSecrets.SECRET_SETTINGS, + validationException + ); + SecureString secureSecretKey = extractRequiredSecureString( + map, + SECRET_KEY_FIELD, + ModelSecrets.SECRET_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockSecretSettings(secureAccessKey, secureSecretKey); + } + + public AmazonBedrockSecretSettings(SecureString accessKey, SecureString secretKey) { + this.accessKey = Objects.requireNonNull(accessKey); + this.secretKey = Objects.requireNonNull(secretKey); + } + + public AmazonBedrockSecretSettings(StreamInput in) throws IOException { + this.accessKey = in.readSecureString(); + this.secretKey = in.readSecureString(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeSecureString(accessKey); + out.writeSecureString(secretKey); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(ACCESS_KEY_FIELD, accessKey.toString()); + builder.field(SECRET_KEY_FIELD, secretKey.toString()); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + AmazonBedrockSecretSettings that = (AmazonBedrockSecretSettings) object; + return Objects.equals(accessKey, that.accessKey) && Objects.equals(secretKey, that.secretKey); + } + + @Override + public int hashCode() { + return Objects.hash(accessKey, secretKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java new file mode 100644 index 0000000000000..dadcc8a40245e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -0,0 +1,350 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.chatCompletionProviderHasTopKParameter; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getEmbeddingsMaxBatchSize; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.providerAllowsTaskType; + +public class AmazonBedrockService extends SenderService { + public static final String NAME = "amazonbedrock"; + + private final Sender amazonBedrockSender; + + public AmazonBedrockService( + HttpRequestSender.Factory httpSenderFactory, + AmazonBedrockRequestSender.Factory amazonBedrockFactory, + ServiceComponents serviceComponents + ) { + super(httpSenderFactory, serviceComponents); + this.amazonBedrockSender = amazonBedrockFactory.createSender(); + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); + if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { + var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Amazon Bedrock service does not support inference with query input"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + ActionListener inferListener = listener.delegateFailureAndWrap( + (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response)) + ); + + var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); + if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { + var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); + var batchedRequests = new EmbeddingRequestChunker(input, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT) + .batchRequestsWithListeners(listener); + for (var request : batchedRequests) { + var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, inferListener); + } + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { + if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { + return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults); + } else if (inferenceResults instanceof ErrorInferenceResults error) { + return List.of(new ErrorChunkedInferenceResults(error.getException())); + } else { + throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName()); + } + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AmazonBedrockModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(modelId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + private static AmazonBedrockModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING -> { + var model = new AmazonBedrockEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider()); + return model; + } + case COMPLETION -> { + var model = new AmazonBedrockChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + checkProviderForTask(TaskType.COMPLETION, model.provider()); + checkChatCompletionProviderForTopKParameter(model); + return model; + } + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private AmazonBedrockEmbeddingsModel updateModelWithEmbeddingDetails(AmazonBedrockEmbeddingsModel model, int embeddingSize) { + AmazonBedrockEmbeddingsServiceSettings serviceSettings = model.getServiceSettings(); + if (serviceSettings.dimensionsSetByUser() + && serviceSettings.dimensions() != null + && serviceSettings.dimensions() != embeddingSize) { + throw new ElasticsearchStatusException( + Strings.format( + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " + + "Please recreate the [%s] configuration with the correct dimensions", + embeddingSize, + serviceSettings.dimensions(), + model.getConfigurations().getInferenceEntityId() + ), + RestStatus.BAD_REQUEST + ); + } + + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? getProviderDefaultSimilarityMeasure(model.provider()) : similarityFromModel; + + AmazonBedrockEmbeddingsServiceSettings settingsToUse = new AmazonBedrockEmbeddingsServiceSettings( + serviceSettings.region(), + serviceSettings.model(), + serviceSettings.provider(), + embeddingSize, + serviceSettings.dimensionsSetByUser(), + serviceSettings.maxInputTokens(), + similarityToUse, + serviceSettings.rateLimitSettings() + ); + + return new AmazonBedrockEmbeddingsModel(model, settingsToUse); + } + + private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvider provider) { + if (providerAllowsTaskType(provider, taskType) == false) { + throw new ElasticsearchStatusException( + Strings.format("The [%s] task type for provider [%s] is not available", taskType, provider), + RestStatus.BAD_REQUEST + ); + } + } + + private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) { + var taskSettings = model.getTaskSettings(); + if (taskSettings.topK() != null) { + if (chatCompletionProviderHasTopKParameter(model.provider()) == false) { + throw new ElasticsearchStatusException( + Strings.format("The [%s] task parameter is not available for provider [%s]", TOP_K_FIELD, model.provider()), + RestStatus.BAD_REQUEST + ); + } + } + } + + @Override + public void close() throws IOException { + super.close(); + IOUtils.closeWhileHandlingException(amazonBedrockSender); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java new file mode 100644 index 0000000000000..13c7c0a8c5938 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; + +public abstract class AmazonBedrockServiceSettings extends FilteredXContentObject implements ServiceSettings { + + protected static final String AMAZON_BEDROCK_BASE_NAME = "amazon_bedrock"; + + protected final String region; + protected final String model; + protected final AmazonBedrockProvider provider; + protected final RateLimitSettings rateLimitSettings; + + // the default requests per minute are defined as per-model in the "Runtime quotas" on AWS + // see: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html + // setting this to 240 requests per minute (4 requests / sec) is a sane default for us as it should be enough for + // decent throughput without exceeding the minimal for _most_ items. The user should consult + // the table above if using a model that might have a lesser limit (e.g. Anthropic Claude 3.5) + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240); + + protected static AmazonBedrockServiceSettings.BaseAmazonBedrockCommonSettings fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + String region = extractRequiredString(map, REGION_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + AmazonBedrockProvider provider = extractRequiredEnum( + map, + PROVIDER_FIELD, + ModelConfigurations.SERVICE_SETTINGS, + AmazonBedrockProvider::fromString, + EnumSet.allOf(AmazonBedrockProvider.class), + validationException + ); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AMAZON_BEDROCK_BASE_NAME, + context + ); + + return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings); + } + + protected record BaseAmazonBedrockCommonSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable RateLimitSettings rateLimitSettings + ) {} + + protected AmazonBedrockServiceSettings(StreamInput in) throws IOException { + this.region = in.readString(); + this.model = in.readString(); + this.provider = in.readEnum(AmazonBedrockProvider.class); + this.rateLimitSettings = new RateLimitSettings(in); + } + + protected AmazonBedrockServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.region = Objects.requireNonNull(region); + this.model = Objects.requireNonNull(model); + this.provider = Objects.requireNonNull(provider); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + public String region() { + return region; + } + + public String model() { + return model; + } + + public AmazonBedrockProvider provider() { + return provider; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(region); + out.writeString(model); + out.writeEnum(provider); + rateLimitSettings.writeTo(out); + } + + public void addBaseXContent(XContentBuilder builder, Params params) throws IOException { + toXContentFragmentOfExposedFields(builder, params); + } + + protected void addXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(REGION_FIELD, region); + builder.field(MODEL_FIELD, model); + builder.field(PROVIDER_FIELD, provider.name()); + rateLimitSettings.toXContent(builder, params); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java new file mode 100644 index 0000000000000..27dc607d671aa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; + +import java.util.Map; + +public class AmazonBedrockChatCompletionModel extends AmazonBedrockModel { + + public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionModel completionModel, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return completionModel; + } + + var requestTaskSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(taskSettings); + var taskSettingsToUse = AmazonBedrockChatCompletionTaskSettings.of(completionModel.getTaskSettings(), requestTaskSettings); + return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse); + } + + public AmazonBedrockChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String name, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + name, + AmazonBedrockChatCompletionServiceSettings.fromMap(serviceSettings, context), + AmazonBedrockChatCompletionTaskSettings.fromMap(taskSettings), + AmazonBedrockSecretSettings.fromMap(secretSettings) + ); + } + + public AmazonBedrockChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + AmazonBedrockChatCompletionServiceSettings serviceSettings, + AmazonBedrockChatCompletionTaskSettings taskSettings, + AmazonBedrockSecretSettings secrets + ) { + super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets)); + } + + public AmazonBedrockChatCompletionModel(Model model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + @Override + public AmazonBedrockChatCompletionServiceSettings getServiceSettings() { + return (AmazonBedrockChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public AmazonBedrockChatCompletionTaskSettings getTaskSettings() { + return (AmazonBedrockChatCompletionTaskSettings) super.getTaskSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java new file mode 100644 index 0000000000000..5985dcd56c5d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MIN_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; + +public record AmazonBedrockChatCompletionRequestTaskSettings( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens +) { + + public static final AmazonBedrockChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AmazonBedrockChatCompletionRequestTaskSettings( + null, + null, + null, + null + ); + + /** + * Extracts the task settings from a map. All settings are considered optional and the absence of a setting + * does not throw an error. + * + * @param map the settings received from a request + * @return a {@link AmazonBedrockChatCompletionRequestTaskSettings} + */ + public static AmazonBedrockChatCompletionRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + var temperature = extractOptionalDoubleInRange( + map, + TEMPERATURE_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var topP = extractOptionalDoubleInRange( + map, + TOP_P_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var topK = extractOptionalDoubleInRange( + map, + TOP_K_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Integer maxNewTokens = extractOptionalPositiveInteger( + map, + MAX_NEW_TOKENS_FIELD, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionRequestTaskSettings(temperature, topP, topK, maxNewTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..fc3d09c6eea7a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class AmazonBedrockChatCompletionServiceSettings extends AmazonBedrockServiceSettings { + public static final String NAME = "amazon_bedrock_chat_completion_service_settings"; + + public static AmazonBedrockChatCompletionServiceSettings fromMap( + Map serviceSettings, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + var baseSettings = AmazonBedrockServiceSettings.fromMap(serviceSettings, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionServiceSettings( + baseSettings.region(), + baseSettings.model(), + baseSettings.provider(), + baseSettings.rateLimitSettings() + ); + } + + public AmazonBedrockChatCompletionServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + RateLimitSettings rateLimitSettings + ) { + super(region, model, provider, rateLimitSettings); + } + + public AmazonBedrockChatCompletionServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + super.addBaseXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.addXContentFragmentOfExposedFields(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockChatCompletionServiceSettings that = (AmazonBedrockChatCompletionServiceSettings) o; + + return Objects.equals(region, that.region) + && Objects.equals(provider, that.provider) + && Objects.equals(model, that.model) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(region, model, provider, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java new file mode 100644 index 0000000000000..e689e68794e1f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MIN_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; + +public class AmazonBedrockChatCompletionTaskSettings implements TaskSettings { + public static final String NAME = "amazon_bedrock_chat_completion_task_settings"; + + public static final AmazonBedrockChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AmazonBedrockChatCompletionRequestTaskSettings( + null, + null, + null, + null + ); + + public static AmazonBedrockChatCompletionTaskSettings fromMap(Map settings) { + ValidationException validationException = new ValidationException(); + + Double temperature = extractOptionalDoubleInRange( + settings, + TEMPERATURE_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Double topP = extractOptionalDoubleInRange( + settings, + TOP_P_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Double topK = extractOptionalDoubleInRange( + settings, + TOP_K_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Integer maxNewTokens = extractOptionalPositiveInteger( + settings, + MAX_NEW_TOKENS_FIELD, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens); + } + + public static AmazonBedrockChatCompletionTaskSettings of( + AmazonBedrockChatCompletionTaskSettings originalSettings, + AmazonBedrockChatCompletionRequestTaskSettings requestSettings + ) { + var temperature = requestSettings.temperature() == null ? originalSettings.temperature() : requestSettings.temperature(); + var topP = requestSettings.topP() == null ? originalSettings.topP() : requestSettings.topP(); + var topK = requestSettings.topK() == null ? originalSettings.topK() : requestSettings.topK(); + var maxNewTokens = requestSettings.maxNewTokens() == null ? originalSettings.maxNewTokens() : requestSettings.maxNewTokens(); + + return new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens); + } + + private final Double temperature; + private final Double topP; + private final Double topK; + private final Integer maxNewTokens; + + public AmazonBedrockChatCompletionTaskSettings( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens + ) { + this.temperature = temperature; + this.topP = topP; + this.topK = topK; + this.maxNewTokens = maxNewTokens; + } + + public AmazonBedrockChatCompletionTaskSettings(StreamInput in) throws IOException { + this.temperature = in.readOptionalDouble(); + this.topP = in.readOptionalDouble(); + this.topK = in.readOptionalDouble(); + this.maxNewTokens = in.readOptionalVInt(); + } + + public Double temperature() { + return temperature; + } + + public Double topP() { + return topP; + } + + public Double topK() { + return topK; + } + + public Integer maxNewTokens() { + return maxNewTokens; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalDouble(temperature); + out.writeOptionalDouble(topP); + out.writeOptionalDouble(topK); + out.writeOptionalVInt(maxNewTokens); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + if (temperature != null) { + builder.field(TEMPERATURE_FIELD, temperature); + } + if (topP != null) { + builder.field(TOP_P_FIELD, topP); + } + if (topK != null) { + builder.field(TOP_K_FIELD, topK); + } + if (maxNewTokens != null) { + builder.field(MAX_NEW_TOKENS_FIELD, maxNewTokens); + } + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockChatCompletionTaskSettings that = (AmazonBedrockChatCompletionTaskSettings) o; + return Objects.equals(temperature, that.temperature) + && Objects.equals(topP, that.topP) + && Objects.equals(topK, that.topK) + && Objects.equals(maxNewTokens, that.maxNewTokens); + } + + @Override + public int hashCode() { + return Objects.hash(temperature, topP, topK, maxNewTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java new file mode 100644 index 0000000000000..0e3a954a03279 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; + +import java.util.Map; + +public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { + + public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { + if (taskSettings != null && taskSettings.isEmpty() == false) { + // no task settings allowed + var validationException = new ValidationException(); + validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings"); + throw validationException; + } + + return embeddingsModel; + } + + public AmazonBedrockEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context), + new EmptyTaskSettings(), + AmazonBedrockSecretSettings.fromMap(secretSettings) + ); + } + + public AmazonBedrockEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + AmazonBedrockEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + AmazonBedrockSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelSecrets(secrets) + ); + } + + public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + @Override + public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() { + return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..4bf037558c618 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java @@ -0,0 +1,220 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; + +public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockServiceSettings { + public static final String NAME = "amazon_bedrock_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + private final Integer dimensions; + private final Boolean dimensionsSetByUser; + private final Integer maxInputTokens; + private final SimilarityMeasure similarity; + + public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var settings = embeddingSettingsFromMap(map, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return settings; + } + + private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + var baseSettings = AmazonBedrockServiceSettings.fromMap(map, validationException, context); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Integer maxTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + + if (dims != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = false; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } + return new AmazonBedrockEmbeddingsServiceSettings( + baseSettings.region(), + baseSettings.model(), + baseSettings.provider(), + dims, + dimensionsSetByUser, + maxTokens, + similarity, + baseSettings.rateLimitSettings() + ); + } + + public AmazonBedrockEmbeddingsServiceSettings(StreamInput in) throws IOException { + super(in); + dimensions = in.readOptionalVInt(); + dimensionsSetByUser = in.readBoolean(); + maxInputTokens = in.readOptionalVInt(); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + } + + public AmazonBedrockEmbeddingsServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings + ) { + super(region, model, provider, rateLimitSettings); + this.dimensions = dimensions; + this.dimensionsSetByUser = dimensionsSetByUser; + this.maxInputTokens = maxInputTokens; + this.similarity = similarity; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalVInt(dimensions); + out.writeBoolean(dimensionsSetByUser); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(similarity); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + super.addBaseXContent(builder, params); + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.addXContentFragmentOfExposedFields(builder, params); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + return builder; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public boolean dimensionsSetByUser() { + return this.dimensionsSetByUser; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockEmbeddingsServiceSettings that = (AmazonBedrockEmbeddingsServiceSettings) o; + + return Objects.equals(region, that.region) + && Objects.equals(provider, that.provider) + && Objects.equals(model, that.model) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(region, model, provider, dimensions, dimensionsSetByUser, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy index f21a46521a7f7..a39fcf53be7f3 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy +++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy @@ -8,12 +8,18 @@ grant { // required by: com.google.api.client.json.JsonParser#parseValue + // also required by AWS SDK for client configuration permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "getClassLoader"; + // required by: com.google.api.client.json.GenericJson# + // also by AWS SDK for Jackson's ObjectMapper permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + // required to add google certs to the gcs client trustore permission java.lang.RuntimePermission "setFactory"; // gcs client opens socket connections for to access repository - permission java.net.SocketPermission "*", "connect"; + // also, AWS Bedrock client opens socket connections and needs resolve for to access to resources + permission java.net.SocketPermission "*", "connect,resolve"; }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java new file mode 100644 index 0000000000000..87d3a82b4aae6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + public void shutdown() throws IOException { + terminate(threadPool); + } + + public void testEmbeddingsRequestAction() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedFloatResults = List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new InferenceTextEmbeddingFloatResults(mockedFloatResults); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + + assertThat(sender.sendCount(), is(1)); + var sentInputs = sender.getInputs(); + assertThat(sentInputs.size(), is(1)); + assertThat(sentInputs.get(0), is("abc")); + } + } + + public void testEmbeddingsRequestAction_HandlesException() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedResult = new ElasticsearchException("mock exception"); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(sender.sendCount(), is(1)); + assertThat(sender.getInputs().size(), is(1)); + assertThat(thrownException.getMessage(), is("mock exception")); + } + } + + public void testCompletionRequestAction() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedChatCompletionResults = List.of(new ChatCompletionResults.Result("test input string")); + var mockedResult = new ChatCompletionResults(mockedChatCompletionResults); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); + + assertThat(sender.sendCount(), is(1)); + var sentInputs = sender.getInputs(); + assertThat(sentInputs.size(), is(1)); + assertThat(sentInputs.get(0), is("abc")); + } + } + + public void testChatCompletionRequestAction_HandlesException() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedResult = new ElasticsearchException("mock exception"); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(sender.sendCount(), is(1)); + assertThat(sender.getInputs().size(), is(1)); + assertThat(thrownException.getMessage(), is("mock exception")); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java new file mode 100644 index 0000000000000..9326d39cb657c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.ConverseOutput; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import com.amazonaws.services.bedrockruntime.model.Message; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockTitanCompletionRequestEntity; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; + +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.Charset; +import java.util.List; + +import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockExecutorTests extends ESTestCase { + public void testExecute_EmbeddingsRequest_ForAmazonTitan() throws CharacterCodingException { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + var truncator = createTruncator(); + var truncatedInput = truncator.truncate(List.of("abc")); + var requestEntity = new AmazonBedrockTitanEmbeddingsRequestEntity("abc"); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, model, requestEntity, null); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT), null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockEmbeddingsExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + + public void testExecute_EmbeddingsRequest_ForCohere() throws CharacterCodingException { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.COHERE, + "accesskey", + "secretkey" + ); + var requestEntity = new AmazonBedrockTitanEmbeddingsRequestEntity("abc"); + var truncator = createTruncator(); + var truncatedInput = truncator.truncate(List.of("abc")); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, model, requestEntity, null); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, getTestInvokeResult(TEST_COHERE_EMBEDDINGS_RESULT), null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockEmbeddingsExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + + public void testExecute_ChatCompletionRequest() throws CharacterCodingException { + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(getTestConverseResult("converse result"), null, null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("converse result")))); + } + + public void testExecute_FailsProperly_WithElasticsearchException() { + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, null, new ElasticsearchException("test exception")); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + + var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(new TimeValue(30000))); + assertThat(exceptionThrown.getMessage(), containsString("Failed to send request from inference entity id [id]")); + assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception")); + } + + public static ConverseResult getTestConverseResult(String resultText) { + var message = new Message().withContent(new ContentBlock().withText(resultText)); + var converseOutput = new ConverseOutput().withMessage(message); + return new ConverseResult().withOutput(converseOutput); + } + + public static InvokeModelResult getTestInvokeResult(String resultJson) throws CharacterCodingException { + var result = new InvokeModelResult(); + result.setContentType("application/json"); + var encoder = Charset.forName("UTF-8").newEncoder(); + result.setBody(encoder.encode(CharBuffer.wrap(resultJson))); + return result; + } + + public static final String TEST_AMAZON_TITAN_EMBEDDINGS_RESULT = """ + { + "embedding": [0.123, 0.456, 0.678, 0.789], + "inputTextTokenCount": int + }"""; + + public static final String TEST_COHERE_EMBEDDINGS_RESULT = """ + { + "embeddings": [ + [0.123, 0.456, 0.678, 0.789] + ], + "id": string, + "response_type" : "embeddings_floats", + "texts": [string] + } + """; +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java new file mode 100644 index 0000000000000..873b2e22497c6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; + +import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockInferenceClient.CLIENT_CACHE_EXPIRY_MINUTES; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class AmazonBedrockInferenceClientCacheTests extends ESTestCase { + public void testCache_ReturnsSameObject() throws IOException { + AmazonBedrockInferenceClientCache cacheInstance; + try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, null)) { + cacheInstance = cache; + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId", + "testregion", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access_key", + "secret_key" + ); + + var client = cache.getOrCreateClient(model, null); + + var secondModel = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId_two", + "testregion", + "a_different_model", + AmazonBedrockProvider.COHERE, + "access_key", + "secret_key" + ); + + var secondClient = cache.getOrCreateClient(secondModel, null); + assertThat(client, sameInstance(secondClient)); + + assertThat(cache.clientCount(), is(1)); + + var thirdClient = cache.getOrCreateClient(model, null); + assertThat(client, sameInstance(thirdClient)); + + assertThat(cache.clientCount(), is(1)); + } + assertThat(cacheInstance.clientCount(), is(0)); + } + + public void testCache_ItEvictsExpiredClients() throws IOException { + var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + AmazonBedrockInferenceClientCache cacheInstance; + try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, clock)) { + cacheInstance = cache; + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId", + "testregion", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access_key", + "secret_key" + ); + + var client = cache.getOrCreateClient(model, null); + + var secondModel = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId_two", + "some_other_region", + "a_different_model", + AmazonBedrockProvider.COHERE, + "other_access_key", + "other_secret_key" + ); + + assertThat(cache.clientCount(), is(1)); + + var secondClient = cache.getOrCreateClient(secondModel, null); + assertThat(client, not(sameInstance(secondClient))); + + assertThat(cache.clientCount(), is(2)); + + // set clock to after expiry + cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES + 1)), ZoneId.systemDefault())); + + // get another client, this will ensure flushExpiredClients is called + var regetSecondClient = cache.getOrCreateClient(secondModel, null); + assertThat(secondClient, sameInstance(regetSecondClient)); + + var regetFirstClient = cache.getOrCreateClient(model, null); + assertThat(client, not(sameInstance(regetFirstClient))); + } + assertThat(cacheInstance.clientCount(), is(0)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java new file mode 100644 index 0000000000000..912967a9012d7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.IOException; + +public class AmazonBedrockMockClientCache implements AmazonBedrockClientCache { + private ConverseResult converseResult = null; + private InvokeModelResult invokeModelResult = null; + private ElasticsearchException exceptionToThrow = null; + + public AmazonBedrockMockClientCache() {} + + public AmazonBedrockMockClientCache( + @Nullable ConverseResult converseResult, + @Nullable InvokeModelResult invokeModelResult, + @Nullable ElasticsearchException exceptionToThrow + ) { + this.converseResult = converseResult; + this.invokeModelResult = invokeModelResult; + this.exceptionToThrow = exceptionToThrow; + } + + @Override + public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, TimeValue timeout) { + var client = (AmazonBedrockMockInferenceClient) AmazonBedrockMockInferenceClient.create(model, timeout); + client.setConverseResult(converseResult); + client.setInvokeModelResult(invokeModelResult); + client.setExceptionToThrow(exceptionToThrow); + return client; + } + + @Override + public void close() throws IOException { + // nothing to do + } + + public void setConverseResult(ConverseResult converseResult) { + this.converseResult = converseResult; + } + + public void setInvokeModelResult(InvokeModelResult invokeModelResult) { + this.invokeModelResult = invokeModelResult; + } + + public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { + this.exceptionToThrow = exceptionToThrow; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java new file mode 100644 index 0000000000000..b0df8a40e2551 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Supplier; + +public class AmazonBedrockMockExecuteRequestSender extends AmazonBedrockExecuteOnlyRequestSender { + + private Queue results = new ConcurrentLinkedQueue<>(); + private Queue> inputs = new ConcurrentLinkedQueue<>(); + private int sendCounter = 0; + + public AmazonBedrockMockExecuteRequestSender(AmazonBedrockClientCache clientCache, ThrottlerManager throttlerManager) { + super(clientCache, throttlerManager); + } + + public void enqueue(Object result) { + results.add(result); + } + + public int sendCount() { + return sendCounter; + } + + public List getInputs() { + return inputs.remove(); + } + + @Override + protected AmazonBedrockExecutor createExecutor( + AmazonBedrockRequest awsRequest, + AmazonBedrockResponseHandler awsResponse, + Logger logger, + Supplier hasRequestTimedOutFunction, + ActionListener listener + ) { + setCacheResult(); + return super.createExecutor(awsRequest, awsResponse, logger, hasRequestTimedOutFunction, listener); + } + + private void setCacheResult() { + var mockCache = (AmazonBedrockMockClientCache) this.clientCache; + var result = results.remove(); + if (result instanceof ConverseResult converseResult) { + mockCache.setConverseResult(converseResult); + return; + } + + if (result instanceof InvokeModelResult invokeModelResult) { + mockCache.setInvokeModelResult(invokeModelResult); + return; + } + + if (result instanceof ElasticsearchException exception) { + mockCache.setExceptionToThrow(exception); + return; + } + + throw new RuntimeException("Unknown result type: " + result.getClass()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java new file mode 100644 index 0000000000000..dcbf8dfcbff01 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +public class AmazonBedrockMockInferenceClient extends AmazonBedrockInferenceClient { + private ConverseResult converseResult = null; + private InvokeModelResult invokeModelResult = null; + private ElasticsearchException exceptionToThrow = null; + + private Future converseResultFuture = new MockConverseResultFuture(); + private Future invokeModelResultFuture = new MockInvokeResultFuture(); + + public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + return new AmazonBedrockMockInferenceClient(model, timeout); + } + + protected AmazonBedrockMockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + super(model, timeout); + } + + public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { + this.exceptionToThrow = exceptionToThrow; + } + + public void setConverseResult(ConverseResult result) { + this.converseResult = result; + } + + public void setInvokeModelResult(InvokeModelResult result) { + this.invokeModelResult = result; + } + + @Override + protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var runtimeClient = mock(AmazonBedrockRuntimeAsync.class); + doAnswer(invocation -> invokeModelResultFuture).when(runtimeClient).invokeModelAsync(any()); + doAnswer(invocation -> converseResultFuture).when(runtimeClient).converseAsync(any()); + + return runtimeClient; + } + + @Override + void close() {} + + private class MockConverseResultFuture implements Future { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public ConverseResult get() throws InterruptedException, ExecutionException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return converseResult; + } + + @Override + public ConverseResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return converseResult; + } + } + + private class MockInvokeResultFuture implements Future { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public InvokeModelResult get() throws InterruptedException, ExecutionException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return invokeModelResult; + } + + @Override + public InvokeModelResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return invokeModelResult; + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java new file mode 100644 index 0000000000000..e68beaf4c1eb5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.io.IOException; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; + +public class AmazonBedrockMockRequestSender implements Sender { + + public static class Factory extends AmazonBedrockRequestSender.Factory { + private final Sender sender; + + public Factory(ServiceComponents serviceComponents, ClusterService clusterService) { + super(serviceComponents, clusterService); + this.sender = new AmazonBedrockMockRequestSender(); + } + + public Sender createSender() { + return sender; + } + } + + private Queue results = new ConcurrentLinkedQueue<>(); + private Queue> inputs = new ConcurrentLinkedQueue<>(); + private int sendCounter = 0; + + public void enqueue(Object result) { + results.add(result); + } + + public int sendCount() { + return sendCounter; + } + + public List getInputs() { + return inputs.remove(); + } + + @Override + public void start() { + // do nothing + } + + @Override + public void send( + RequestManager requestCreator, + InferenceInputs inferenceInputs, + TimeValue timeout, + ActionListener listener + ) { + sendCounter++; + var docsInput = (DocumentsOnlyInput) inferenceInputs; + inputs.add(docsInput.getInputs()); + + if (results.isEmpty()) { + listener.onFailure(new ElasticsearchException("No results found")); + } else { + var resultObject = results.remove(); + if (resultObject instanceof InferenceServiceResults inferenceResult) { + listener.onResponse(inferenceResult); + } else if (resultObject instanceof Exception e) { + listener.onFailure(e); + } else { + throw new RuntimeException("Unknown result type: " + resultObject.getClass()); + } + } + } + + @Override + public void close() throws IOException { + // do nothing + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java new file mode 100644 index 0000000000000..7fa8a09d5bf12 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AmazonBedrockRequestSenderTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + private final AtomicReference threadRef = new AtomicReference<>(); + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + threadRef.set(null); + } + + @After + public void shutdown() throws IOException, InterruptedException { + if (threadRef.get() != null) { + threadRef.get().join(TIMEOUT.millis()); + } + + terminate(threadPool); + } + + public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); + try (var sender = createSender(senderFactory, requestSender)) { + sender.start(); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var requestManager = new AmazonBedrockEmbeddingsRequestManager( + model, + serviceComponents.truncator(), + threadPool, + new TimeValue(30, TimeUnit.SECONDS) + ); + sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + } + + public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text")); + try (var sender = createSender(senderFactory, requestSender)) { + sender.start(); + + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); + sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); + } + } + + public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) { + return new AmazonBedrockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, settings), + mockClusterServiceEmpty() + ); + } + + public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) { + return factory.createSender(requestSender); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..b91aab5410048 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockAI21LabsCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..89d5fec7efba6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockAnthropicCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..8df5c7f32e529 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockCohereCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java new file mode 100644 index 0000000000000..cbbe3c5554967 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.Message; + +import org.elasticsearch.core.Strings; + +public final class AmazonBedrockConverseRequestUtils { + public static ConverseRequest getConverseRequest(String modelId, AmazonBedrockConverseRequestEntity requestEntity) { + var converseRequest = new ConverseRequest().withModelId(modelId); + converseRequest = requestEntity.addMessages(converseRequest); + converseRequest = requestEntity.addInferenceConfig(converseRequest); + converseRequest = requestEntity.addAdditionalModelFields(converseRequest); + return converseRequest; + } + + public static boolean doesConverseRequestHasMessage(ConverseRequest converseRequest, String expectedMessage) { + for (Message message : converseRequest.getMessages()) { + var content = message.getContent(); + for (ContentBlock contentBlock : content) { + if (contentBlock.getText().equals(expectedMessage)) { + return true; + } + } + } + return false; + } + + public static boolean doesConverseRequestHaveAnyTemperatureInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null + && converseRequest.getInferenceConfig().getTemperature() != null + && (converseRequest.getInferenceConfig().getTemperature().isNaN() == false); + } + + public static boolean doesConverseRequestHaveAnyTopPInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null + && converseRequest.getInferenceConfig().getTopP() != null + && (converseRequest.getInferenceConfig().getTopP().isNaN() == false); + } + + public static boolean doesConverseRequestHaveAnyMaxTokensInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null && converseRequest.getInferenceConfig().getMaxTokens() != null; + } + + public static boolean doesConverseRequestHaveTemperatureInput(ConverseRequest converseRequest, Double temperature) { + return doesConverseRequestHaveAnyTemperatureInput(converseRequest) + && converseRequest.getInferenceConfig().getTemperature().equals(temperature.floatValue()); + } + + public static boolean doesConverseRequestHaveTopPInput(ConverseRequest converseRequest, Double topP) { + return doesConverseRequestHaveAnyTopPInput(converseRequest) + && converseRequest.getInferenceConfig().getTopP().equals(topP.floatValue()); + } + + public static boolean doesConverseRequestHaveMaxTokensInput(ConverseRequest converseRequest, Integer maxTokens) { + return doesConverseRequestHaveAnyMaxTokensInput(converseRequest) + && converseRequest.getInferenceConfig().getMaxTokens().equals(maxTokens); + } + + public static boolean doesConverseRequestHaveAnyTopKInput(ConverseRequest converseRequest) { + if (converseRequest.getAdditionalModelResponseFieldPaths() == null) { + return false; + } + + for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + if (fieldPath.contains("{\"top_k\":")) { + return true; + } + } + return false; + } + + public static boolean doesConverseRequestHaveTopKInput(ConverseRequest converseRequest, Double topK) { + if (doesConverseRequestHaveAnyTopKInput(converseRequest) == false) { + return false; + } + + var checkString = Strings.format("{\"top_k\":%f}", topK.floatValue()); + for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + if (fieldPath.contains(checkString)) { + return true; + } + } + return false; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..fa482669a0bb2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockMetaCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..788625d3702b8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockMistralCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..79fa387876c8b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockTitanCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..fd8114f889d6a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase { + public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input")); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..da98fa251fdc8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; + +import java.io.IOException; + +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockTitanEmbeddingsRequestEntityTests extends ESTestCase { + public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockTitanEmbeddingsRequestEntity("test input"); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"inputText\":\"test input\"}")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java new file mode 100644 index 0000000000000..904851842a6c8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.SECRET_KEY_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockSecretSettingsTests extends AbstractBWCWireSerializationTestCase { + + public void testIt_CreatesSettings_ReturnsNullFromMap_null() { + var secrets = AmazonBedrockSecretSettings.fromMap(null); + assertNull(secrets); + } + + public void testIt_CreatesSettings_FromMap_WithValues() { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest")) + ); + assertThat( + secrets, + is(new AmazonBedrockSecretSettings(new SecureString("accesstest".toCharArray()), new SecureString("secrettest".toCharArray()))) + ); + } + + public void testIt_CreatesSettings_FromMap_IgnoresExtraKeys() { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest", "extrakey", "extravalue")) + ); + assertThat( + secrets, + is(new AmazonBedrockSecretSettings(new SecureString("accesstest".toCharArray()), new SecureString("secrettest".toCharArray()))) + ); + } + + public void testIt_FromMap_ThrowsValidationException_AccessKeyMissing() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockSecretSettings.fromMap(new HashMap<>(Map.of(SECRET_KEY_FIELD, "secrettest"))) + ); + + assertThat( + thrownException.getMessage(), + containsString(Strings.format("[secret_settings] does not contain the required setting [%s]", ACCESS_KEY_FIELD)) + ); + } + + public void testIt_FromMap_ThrowsValidationException_SecretKeyMissing() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockSecretSettings.fromMap(new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest"))) + ); + + assertThat( + thrownException.getMessage(), + containsString(Strings.format("[secret_settings] does not contain the required setting [%s]", SECRET_KEY_FIELD)) + ); + } + + public void testToXContent_CreatesProperContent() throws IOException { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest")) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + secrets.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, CoreMatchers.is(""" + {"access_key":"accesstest","secret_key":"secrettest"}""")); + } + + public static Map getAmazonBedrockSecretSettingsMap(String accessKey, String secretKey) { + return new HashMap(Map.of(ACCESS_KEY_FIELD, accessKey, SECRET_KEY_FIELD, secretKey)); + } + + @Override + protected AmazonBedrockSecretSettings mutateInstanceForVersion(AmazonBedrockSecretSettings instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockSecretSettings::new; + } + + @Override + protected AmazonBedrockSecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockSecretSettings mutateInstance(AmazonBedrockSecretSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockSecretSettingsTests::createRandom); + } + + private static AmazonBedrockSecretSettings createRandom() { + return new AmazonBedrockSecretSettings(new SecureString(randomAlphaOfLength(10)), new SecureString(randomAlphaOfLength(10))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java new file mode 100644 index 0000000000000..00a840c8d4812 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -0,0 +1,1131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettingsTests.createChatCompletionRequestSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AmazonBedrockServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + public void shutdown() throws IOException { + terminate(threadPool); + } + + public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "anthropic", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_TopKParameter_NotAvailable() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + createChatCompletionRequestSettingsMap("region", "model", "amazontitan"), + getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var config = getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ); + + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var serviceSettings = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret")); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_MovesModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", 512, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "amazontitan"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, Map.of(), secretSettingsMap); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.secrets().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + settingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnAmazonBedrockChatCompletionModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + settingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + taskSettingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + assertNull(model.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F })))); + } + } + } + + public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); + requestSender.enqueue(mockResults); + + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationCompletion(List.of("test result")))); + } + } + } + + public void testCheckModelConfig_IncludesMaxTokens_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + 100, + null, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + 100, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_HasSimilarity_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 3, + true, + null, + null, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is( + "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " + + "Please recreate the [id] configuration with the correct dimensions" + ) + ); + + var inputStrings = requestSender.getInputs(); + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 100, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testInfer_UnauthorizedResponse() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "us-east-1", + "amazon.titan-embed-text-v1", + AmazonBedrockProvider.AMAZONTITAN, + "_INVALID_AWS_ACCESS_KEY_", + "_INVALID_AWS_SECRET_KEY_" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exceptionThrown.getCause().getMessage(), containsString("The security token included in the request is invalid")); + } + } + + public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var mockResults = new InferenceTextEmbeddingFloatResults( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.456F, 0.987F }) + ) + ); + requestSender.enqueue(mockResults); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("abc", "xyz"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("abc", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.123F, 0.678F }, floatResult.chunks().get(0).embedding(), 0.0f); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("xyz", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.456F, 0.987F }, floatResult.chunks().get(0).embedding(), 0.0f); + } + } + } + } + + private AmazonBedrockService createAmazonBedrockService() { + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private Utils.PersistedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new Utils.PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java new file mode 100644 index 0000000000000..22173943ff432 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java @@ -0,0 +1,221 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class AmazonBedrockChatCompletionModelTests extends ESTestCase { + public void testOverrideWith_OverridesWithoutValues() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 1.0, + 0.5, + 0.6, + 512, + null, + "access_key", + "secret_key" + ); + var requestTaskSettingsMap = getChatCompletionTaskSettingsMap(null, null, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettingsMap); + + assertThat(overriddenModel, sameInstance(overriddenModel)); + } + + public void testOverrideWith_temperature() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 1.0, + null, + null, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(0.5, null, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 0.5, + null, + null, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_topP() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + 0.8, + null, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, 0.5, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + 0.5, + null, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_topK() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + 1.0, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, null, 0.8, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + 0.8, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_maxNewTokens() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + 512, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, null, null, 128); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + 128, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public static AmazonBedrockChatCompletionModel createModel( + String id, + String region, + String model, + AmazonBedrockProvider provider, + String accessKey, + String secretKey + ) { + return createModel(id, region, model, provider, null, null, null, null, null, accessKey, secretKey); + } + + public static AmazonBedrockChatCompletionModel createModel( + String id, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens, + @Nullable RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey + ) { + return new AmazonBedrockChatCompletionModel( + id, + TaskType.COMPLETION, + "amazonbedrock", + new AmazonBedrockChatCompletionServiceSettings(region, model, provider, rateLimitSettings), + new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens), + new AmazonBedrockSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..681088c786b6b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.MatcherAssert; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionRequestTaskSettingsTests extends ESTestCase { + public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertThat(settings, is(AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); + assertThat(settings, is(AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsTemperature() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TEMPERATURE_FIELD, 0.1))); + assertThat(settings.temperature(), is(0.1)); + } + + public void testFromMap_ReturnsTopP() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_P_FIELD, 0.1))); + assertThat(settings.topP(), is(0.1)); + } + + public void testFromMap_ReturnsDoSample() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_K_FIELD, 0.3))); + assertThat(settings.topK(), is(0.3)); + } + + public void testFromMap_ReturnsMaxNewTokens() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(MAX_NEW_TOKENS_FIELD, 512))); + assertThat(settings.maxNewTokens(), is(512)); + } + + public void testFromMap_TemperatureIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TEMPERATURE_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [temperature] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopPIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_P_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [top_p] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopKIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_K_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ); + } + + public void testFromMap_MaxTokensIsInvalidValue_ThrowsStatusException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(MAX_NEW_TOKENS_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [max_new_tokens] is not of the expected type. The value [invalid] cannot be converted to a [Integer]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..90868530d8df8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockChatCompletionServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap( + createChatCompletionRequestSettingsMap(region, model, provider), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null)) + ); + } + + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var settingsMap = createChatCompletionRequestSettingsMap(region, model, provider); + settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))); + + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, new RateLimitSettings(3))) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var settingsMap = createChatCompletionRequestSettingsMap(region, model, provider); + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null)) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new AmazonBedrockChatCompletionServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3}}""")); + } + + public static HashMap createChatCompletionRequestSettingsMap(String region, String model, String provider) { + return new HashMap(Map.of(REGION_FIELD, region, MODEL_FIELD, model, PROVIDER_FIELD, provider)); + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings mutateInstanceForVersion( + AmazonBedrockChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockChatCompletionServiceSettings::new; + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings mutateInstance(AmazonBedrockChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, AmazonBedrockChatCompletionServiceSettingsTests::createRandom); + } + + private static AmazonBedrockChatCompletionServiceSettings createRandom() { + return new AmazonBedrockChatCompletionServiceSettings( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomFrom(AmazonBedrockProvider.values()), + RateLimitSettingsTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java new file mode 100644 index 0000000000000..0d5440c6d2cf8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockChatCompletionTaskSettings> { + + public void testFromMap_AllValues() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + assertEquals( + new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.6, 512), + AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap) + ); + } + + public void testFromMap_TemperatureIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TEMPERATURE_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [temperature] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopPIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TOP_P_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [top_p] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopKIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TOP_K_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ); + } + + public void testFromMap_MaxNewTokensIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(MAX_NEW_TOKENS_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [max_new_tokens] is not of the expected type. The value [invalid] cannot be converted to a [Integer]") + ) + ); + } + + public void testFromMap_WithNoValues_DoesNotThrowException() { + var taskMap = AmazonBedrockChatCompletionTaskSettings.fromMap(new HashMap(Map.of())); + assertNull(taskMap.temperature()); + assertNull(taskMap.topP()); + assertNull(taskMap.topK()); + assertNull(taskMap.maxNewTokens()); + } + + public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, AmazonBedrockChatCompletionTaskSettings.EMPTY_SETTINGS); + MatcherAssert.assertThat(overrideSettings, is(settings)); + } + + public void testOverrideWith_UsesTemperatureOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(0.3, null, null, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(0.3, 0.5, 0.6, 512))); + } + + public void testOverrideWith_UsesTopPOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, 0.2, null, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.2, 0.6, 512))); + } + + public void testOverrideWith_UsesDoSampleOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, null, 0.1, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.1, 512))); + } + + public void testOverrideWith_UsesMaxNewTokensOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, null, null, 128) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.6, 128))); + } + + public void testToXContent_WithoutParameters() throws IOException { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(null, null, null, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is("{}")); + } + + public void testToXContent_WithParameters() throws IOException { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"temperature":1.0,"top_p":0.5,"top_k":0.6,"max_new_tokens":512}""")); + } + + public static Map getChatCompletionTaskSettingsMap( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens + ) { + var map = new HashMap(); + + if (temperature != null) { + map.put(TEMPERATURE_FIELD, temperature); + } + + if (topP != null) { + map.put(TOP_P_FIELD, topP); + } + + if (topK != null) { + map.put(TOP_K_FIELD, topK); + } + + if (maxNewTokens != null) { + map.put(MAX_NEW_TOKENS_FIELD, maxNewTokens); + } + + return map; + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings mutateInstanceForVersion( + AmazonBedrockChatCompletionTaskSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockChatCompletionTaskSettings::new; + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings mutateInstance(AmazonBedrockChatCompletionTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockChatCompletionTaskSettingsTests::createRandom); + } + + private static AmazonBedrockChatCompletionTaskSettings createRandom() { + return new AmazonBedrockChatCompletionTaskSettings( + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Integer[] { null, randomNonNegativeInt() }) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java new file mode 100644 index 0000000000000..711e3cbb5a511 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { + + public void testCreateModel_withTaskSettings_shouldFail() { + var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey"); + var thrownException = assertThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue")) + ); + assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings")); + } + + // model creation only - no tests to define, but we want to have the public createModel + // method available + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + String accessKey, + String secretKey + ) { + return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + } + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey + ) { + return new AmazonBedrockEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "amazonbedrock", + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + provider, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings + ), + new EmptyTaskSettings(), + new AmazonBedrockSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..a100b89e1db6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java @@ -0,0 +1,404 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockEmbeddingsServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap( + createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE); + settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))); + + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + new RateLimitSettings(3) + ) + ) + ); + } + + public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] does not allow the setting [%s];", DIMENSIONS_SET_BY_USER) + ) + ); + } + + public void testFromMap_Request_Dimensions_ShouldThrowWhenPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var dims = 128; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, dims, null, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString(Strings.format("[service_settings] does not allow the setting [%s]", DIMENSIONS)) + ); + } + + public void testFromMap_Request_MaxTokensShouldBePositiveInteger() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = -128; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString(Strings.format("[%s] must be a positive integer", MAX_INPUT_TOKENS)) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var dims = 1536; + var maxInputTokens = 512; + + var settingsMap = createEmbeddingsRequestSettingsMap( + region, + model, + provider, + dims, + false, + maxInputTokens, + SimilarityMeasure.COSINE + ); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + dims, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, null, null); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is(new AmazonBedrockEmbeddingsServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null, true, null, null, null)) + ); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, null, SimilarityMeasure.DOT_PRODUCT); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + true, + null, + SimilarityMeasure.DOT_PRODUCT, + null + ) + ) + ); + } + + public void testFromMap_PersistentContext_ThrowsException_WhenDimensionsSetByUserIsNull() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, 1, null, null, null); + + var exception = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT) + ); + + assertThat( + exception.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];") + ); + } + + public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + null, + true, + null, + null, + new RateLimitSettings(2) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":2},"dimensions_set_by_user":true}""")); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + 1024, + false, + 512, + null, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512,"dimensions_set_by_user":false}""")); + } + + public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + 1024, + false, + 512, + null, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}""")); + } + + public static HashMap createEmbeddingsRequestSettingsMap( + String region, + String model, + String provider, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarityMeasure + ) { + var map = new HashMap(Map.of(REGION_FIELD, region, MODEL_FIELD, model, PROVIDER_FIELD, provider)); + + if (dimensions != null) { + map.put(ServiceFields.DIMENSIONS, dimensions); + } + + if (dimensionsSetByUser != null) { + map.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser.equals(Boolean.TRUE)); + } + + if (maxTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxTokens); + } + + if (similarityMeasure != null) { + map.put(SIMILARITY, similarityMeasure.toString()); + } + + return map; + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings mutateInstanceForVersion( + AmazonBedrockEmbeddingsServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockEmbeddingsServiceSettings::new; + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings mutateInstance(AmazonBedrockEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockEmbeddingsServiceSettingsTests::createRandom); + } + + private static AmazonBedrockEmbeddingsServiceSettings createRandom() { + return new AmazonBedrockEmbeddingsServiceSettings( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomFrom(AmazonBedrockProvider.values()), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), + randomBoolean(), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), + randomFrom(new SimilarityMeasure[] { null, randomFrom(SimilarityMeasure.values()) }), + RateLimitSettingsTests.createRandom() + ); + } +}