diff --git a/docs/changelog/113981.yaml b/docs/changelog/113981.yaml new file mode 100644 index 0000000000000..38f3a6f04ae46 --- /dev/null +++ b/docs/changelog/113981.yaml @@ -0,0 +1,6 @@ +pr: 113981 +summary: "Adding chunking settings to `GoogleVertexAiService`, `AzureAiStudioService`,\ + \ and `AlibabaCloudSearchService`" +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 0bd0eee1aa9a1..c5c88ad978d63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -24,6 +25,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -74,11 +77,19 @@ public void parseRequestConfig( Map
serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + AlibabaCloudSearchModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage ) { @@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = 
removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelWithoutLoggingDeprecations( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelWithoutLoggingDeprecations( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -238,17 +266,36 @@ protected void doChunkedInfer( AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + 
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()), + alibabaCloudSearchModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()) + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); } } + private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskType taskType) { + return switch (taskType) { + case TEXT_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.FLOAT; + case SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE; + default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType); + }; + } + /** * For text embedding models get the embedding size and * update the service settings. 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java index 87e5e59ae3434..2654ee4d22ce6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel( service, AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context), AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel( String service, AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings, AlibabaCloudSearchEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new 
ModelSecrets(secretSettings), serviceSettings.getCommonSettings() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java index 76dfd01f333da..8896e983d3e7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; @@ -81,10 +82,21 @@ public SimilarityMeasure getSimilarity() { return similarity; } - public Integer getDimensions() { + @Override + public Integer dimensions() { return dimensions; } + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + public Integer getMaxInputTokens() { return maxInputTokens; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java index b551ba389136b..0155d8fbc1f08 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel( service, AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context), AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel( String service, AlibabaCloudSearchSparseServiceSettings serviceSettings, AlibabaCloudSearchSparseTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), serviceSettings.getCommonSettings() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 7981fb393a842..c1ca50d41268e 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -24,6 +25,8 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -90,11 +93,23 @@ protected void doChunkedInfer( ) { if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) { var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + baseAzureAiStudioModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + 
inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); @@ -115,11 +130,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + AzureAiStudioModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -146,11 +169,17 @@ public AzureAiStudioModel parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -161,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = 
removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -186,6 +221,7 @@ private static AzureAiStudioModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -198,6 +234,7 @@ private static AzureAiStudioModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); @@ -235,6 +272,7 @@ private AzureAiStudioModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -243,6 +281,7 @@ private AzureAiStudioModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java index a999b9f0312e6..edbefe07cff02 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -44,9 +45,13 @@ public AzureAiStudioEmbeddingsModel( String service, AzureAiStudioEmbeddingsServiceSettings serviceSettings, AzureAiStudioEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, DefaultSecretSettings secrets ) { - super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets)); + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secrets) + ); } public AzureAiStudioEmbeddingsModel( @@ -55,6 +60,7 @@ public AzureAiStudioEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secrets, ConfigurationParseContext context ) { @@ -64,6 +70,7 @@ public AzureAiStudioEmbeddingsModel( service, AzureAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context), AzureAiStudioEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, DefaultSecretSettings.fromMap(secrets) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index d9d8850048564..ae9219ba38499 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -23,6 +24,8 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -70,11 +73,19 @@ public void parseRequestConfig( Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + GoogleVertexAiModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), ConfigurationParseContext.REQUEST @@ -101,11 +112,17 @@ public Model parsePersistedConfigWithSecrets( Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); Map 
secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -116,11 +133,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -179,11 +202,22 @@ protected void doChunkedInfer( GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model; var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); - var batchedRequests = new EmbeddingRequestChunker( - inputs.getInputs(), - EMBEDDING_MAX_BATCH_SIZE, - EmbeddingRequestChunker.EmbeddingType.FLOAT - ).batchRequestsWithListeners(listener); + List batchedRequests; + if (ChunkingSettingsFeatureFlag.isEnabled()) { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT, + googleVertexAiModel.getConfigurations().getChunkingSettings() + 
).batchRequestsWithListeners(listener); + } else { + batchedRequests = new EmbeddingRequestChunker( + inputs.getInputs(), + EMBEDDING_MAX_BATCH_SIZE, + EmbeddingRequestChunker.EmbeddingType.FLOAT + ).batchRequestsWithListeners(listener); + } + for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); @@ -225,6 +259,7 @@ private static GoogleVertexAiModel createModelFromPersistent( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secretSettings, String failureMessage ) { @@ -233,6 +268,7 @@ private static GoogleVertexAiModel createModelFromPersistent( taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT @@ -244,6 +280,7 @@ private static GoogleVertexAiModel createModel( TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, String failureMessage, ConfigurationParseContext context @@ -255,6 +292,7 @@ private static GoogleVertexAiModel createModel( NAME, serviceSettings, taskSettings, + chunkingSettings, secretSettings, context ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java index 99110045fc3da..3a5fae09b40ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java @@ -9,6 +9,7 @@ import org.apache.http.client.utils.URIBuilder; import 
org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskType; @@ -35,6 +36,7 @@ public GoogleVertexAiEmbeddingsModel( String service, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, Map secrets, ConfigurationParseContext context ) { @@ -44,6 +46,7 @@ public GoogleVertexAiEmbeddingsModel( service, GoogleVertexAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettings), + chunkingSettings, GoogleVertexAiSecretSettings.fromMap(secrets) ); } @@ -59,10 +62,11 @@ public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, Google String service, GoogleVertexAiEmbeddingsServiceSettings serviceSettings, GoogleVertexAiEmbeddingsTaskSettings taskSettings, + ChunkingSettings chunkingSettings, @Nullable GoogleVertexAiSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets), serviceSettings ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index e8c34eec96171..7cedc36ffa5f0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -7,12 +7,14 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch; +import 
org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -20,9 +22,12 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -31,28 +36,34 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; +import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; -import org.hamcrest.CoreMatchers; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; @@ -99,6 +110,233 @@ public void testParseRequestConfig_CreatesAnEmbeddingsModel() throws IOException } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), 
is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, e -> fail("Model 
parsing should have succeeded " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new 
AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap() + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var model = service.parsePersistedConfig( + "id", + TaskType.TEXT_EMBEDDING, + getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null) + ).config() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + 
assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature 
flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = new AlibabaCloudSearchService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool))) { + var persistedConfig = getPersistedConfigMap( + AlibabaCloudSearchEmbeddingsServiceSettingsTests.getServiceSettingsMap("service_id", "host", "default"), + AlibabaCloudSearchEmbeddingsTaskSettingsTests.getTaskSettingsMap(null), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + 
var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AlibabaCloudSearchEmbeddingsModel.class)); + var embeddingsModel = (AlibabaCloudSearchEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("service_id")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getHost(), is("host")); + assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getWorkspaceName(), is("default")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testCheckModelConfig() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -165,49 +403,71 @@ public void doInfer( } } - public void testChunkedInfer_Batches() throws IOException { - var input = List.of("foo", "bar"); + public void testChunkedInfer_TextEmbeddingBatches() throws IOException { + testChunkedInfer(TaskType.TEXT_EMBEDDING, null); + } + public void testChunkedInfer_TextEmbeddingChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); + } + + public void testChunkedInfer_TextEmbeddingChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.TEXT_EMBEDDING, null); + } + + public void testChunkedInfer_SparseEmbeddingBatches() throws IOException { + testChunkedInfer(TaskType.SPARSE_EMBEDDING, null); + } + + public void 
testChunkedInfer_SparseEmbeddingChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.SPARSE_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); + } + + public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + testChunkedInfer(TaskType.SPARSE_EMBEDDING, null); + } + + public void testChunkedInfer_InvalidTaskType() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { - Map serviceSettingsMap = new HashMap<>(); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + var model = AlibabaCloudSearchCompletionModelTests.createModel( + randomAlphaOfLength(10), + TaskType.COMPLETION, + AlibabaCloudSearchCompletionServiceSettingsTests.createRandom(), + AlibabaCloudSearchCompletionTaskSettingsTests.createRandom(), + null + ); - Map taskSettingsMap = new HashMap<>(); + PlainActionFuture> listener = new PlainActionFuture<>(); + try { + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + } catch (Exception e) { + assertThat(e, instanceOf(IllegalArgumentException.class)); + } + } + } - Map secretSettingsMap = new HashMap<>(); - secretSettingsMap.put("api_key", "secret"); + private void 
testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { + var input = List.of("foo", "bar"); - var model = new AlibabaCloudSearchEmbeddingsModel( - "service", - TaskType.TEXT_EMBEDDING, - AlibabaCloudSearchUtils.SERVICE_NAME, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap, - null - ) { - public ExecutableAction accept( - AlibabaCloudSearchActionVisitor visitor, - Map taskSettings, - InputType inputType - ) { - return (inferenceInputs, timeout, listener) -> { - InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( - List.of( - new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), - new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) - ) - ); - - listener.onResponse(results); - }; - } - }; + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + var model = createModelForTaskType(taskType, chunkingSettings); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( @@ -222,26 +482,101 @@ public ExecutableAction accept( ); var results = listener.actionGet(TIMEOUT); + assertThat(results, instanceOf(List.class)); assertThat(results, hasSize(2)); + var firstResult = results.get(0); + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + assertThat(firstResult, instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + assertThat(firstResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + } + } + } + + private AlibabaCloudSearchModel createModelForTaskType(TaskType taskType, ChunkingSettings chunkingSettings) { + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + 
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new HashMap<>(); + + secretSettingsMap.put("api_key", "secret"); + return switch (taskType) { + case TEXT_EMBEDDING -> createEmbeddingsModel(serviceSettingsMap, taskSettingsMap, chunkingSettings, secretSettingsMap); + case SPARSE_EMBEDDING -> createSparseEmbeddingsModel(serviceSettingsMap, taskSettingsMap, chunkingSettings, secretSettingsMap); + default -> throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType); + }; + } - // first result - { - assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); + private AlibabaCloudSearchModel createEmbeddingsModel( + Map serviceSettingsMap, + Map taskSettingsMap, + ChunkingSettings chunkingSettings, + Map secretSettingsMap + ) { + return new AlibabaCloudSearchEmbeddingsModel( + "service", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + null + ) { + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return (inferenceInputs, timeout, listener) -> { + InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), + new 
InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) + ) + ); + + listener.onResponse(results); + }; } + }; + } - // second result - { - assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); + private AlibabaCloudSearchModel createSparseEmbeddingsModel( + Map serviceSettingsMap, + Map taskSettingsMap, + ChunkingSettings chunkingSettings, + Map secretSettingsMap + ) { + return new AlibabaCloudSearchSparseModel( + "service", + TaskType.SPARSE_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettingsMap, + taskSettingsMap, + chunkingSettings, + secretSettingsMap, + null + ) { + public ExecutableAction accept(AlibabaCloudSearchActionVisitor visitor, Map taskSettings, InputType inputType) { + return (inferenceInputs, timeout, listener) -> { + listener.onResponse(SparseEmbeddingResultsTests.createRandomResults(2, 1)); + }; } - } + }; + } + + public Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; } private Map getRequestConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java index fca0ee11e5c78..957b7149b14f1 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsModelTests.java @@ -47,6 +47,7 @@ public static AlibabaCloudSearchEmbeddingsModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secrets, null ); @@ -65,6 +66,7 @@ public static AlibabaCloudSearchEmbeddingsModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secretSettings ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java index 4e9179b66c36f..4a89e1fc924a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseModelTests.java @@ -47,6 +47,7 @@ public static AlibabaCloudSearchSparseModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secrets, null ); @@ -65,6 +66,7 @@ public static AlibabaCloudSearchSparseModel createModel( AlibabaCloudSearchUtils.SERVICE_NAME, serviceSettings, taskSettings, + null, secretSettings ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 1df457b3211ea..683f32710bcb3 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -29,6 +30,7 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; @@ -62,6 +64,8 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; @@ -124,6 +128,90 @@ public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModel() throw } } + public void 
testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null); + + var config = getRequestConfigMap( + serviceSettings, + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + } + } + + public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + 
assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null), + getEmbeddingsTaskSettingsMap("user"), + getSecretSettingsMap("secret") + ), + modelVerificationListener + ); + } + } + public void 
testParseRequestConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap(model -> { @@ -461,6 +549,89 @@ public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() thr } } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void 
testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + 
getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -651,6 +822,84 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro } } + public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWithoutChunkingSettingsFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + 
getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + createRandomChunkingSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + 
assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null), + getEmbeddingsTaskSettingsMap("user"), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (AzureAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local")); + assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI)); + assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN)); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + public void 
testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( @@ -843,6 +1092,61 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc } public void testChunkedInfer() throws IOException { + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + createRandomChunkingSettings(), + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + public void testChunkedInfer_ChunkingSettingsNotSetAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var model = AzureAiStudioEmbeddingsModelTests.createModel( + "id", + getUrl(webServer), + AzureAiStudioProvider.OPENAI, + AzureAiStudioEndpointType.TOKEN, + null, + "apikey", + null, + false, + null, + null, + "user", + null + ); + testChunkedInfer(model); + } + + private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { @@ -877,19 +1181,6 @@ public void testChunkedInfer() throws IOException { """; webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - false, - null, - null, - "user", - null - ); PlainActionFuture> listener = new PlainActionFuture<>(); service.chunkedInfer( model, @@ -1020,6 +1311,18 @@ private AzureAiStudioService createService() { return new AzureAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java index 5a450f03b4e01..c9b0f905abaa4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModelTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -104,6 +105,51 @@ public static AzureAiStudioEmbeddingsModel createModel( return 
createModel(inferenceId, target, provider, endpointType, apiKey, null, false, null, null, null, null); } + public static AzureAiStudioEmbeddingsModel createModel( + String inferenceId, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + ChunkingSettings chunkingSettings, + String apiKey + ) { + return createModel(inferenceId, target, provider, endpointType, chunkingSettings, apiKey, null, false, null, null, null, null); + } + + public static AzureAiStudioEmbeddingsModel createModel( + String inferenceId, + String target, + AzureAiStudioProvider provider, + AzureAiStudioEndpointType endpointType, + ChunkingSettings chunkingSettings, + String apiKey, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + @Nullable String user, + RateLimitSettings rateLimitSettings + ) { + return new AzureAiStudioEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "azureaistudio", + new AzureAiStudioEmbeddingsServiceSettings( + target, + provider, + endpointType, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings + ), + new AzureAiStudioEmbeddingsTaskSettings(user), + chunkingSettings, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + public static AzureAiStudioEmbeddingsModel createModel( String inferenceId, String target, @@ -132,6 +178,7 @@ public static AzureAiStudioEmbeddingsModel createModel( rateLimitSettings ), new AzureAiStudioEmbeddingsTaskSettings(user), + null, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java index 6a96d289a8190..70ec6522c0fcb 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java @@ -10,12 +10,14 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -37,7 +39,9 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -107,6 +111,130 @@ public void testParseRequestConfig_CreatesGoogleVertexAiEmbeddingsModel() throws } } + public void testParseRequestConfig_ThrowsElasticsearchStatusExceptionWhenChunkingSettingsProvidedAndFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 
'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + try (var service = createGoogleVertexAiService()) { + var config = getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + "model", + GoogleVertexAiServiceFields.LOCATION, + "location", + GoogleVertexAiServiceFields.PROJECT_ID, + "project" + ) + ), + getTaskSettingsMap(true), + createRandomChunkingSettingsMap(), + getSecretSettingsMap("{}") + ); + + var failureListener = ActionListener.wrap(model -> fail("Expected exception, but got model: " + model), exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), containsString("Model configuration contains settings")); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, failureListener); + } + } + + public void testParseRequestConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + }, e -> 
fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + new HashMap<>(Map.of()), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + + public void testParseRequestConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + }, e -> fail("Model parsing should succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId + ) + ), + new HashMap<>(Map.of()), + 
getSecretSettingsMap(serviceAccountJson) + ), + modelListener + ); + } + } + public void testParseRequestConfig_CreatesGoogleVertexAiRerankModel() throws IOException { var projectId = "project"; var serviceAccountJson = """ @@ -321,6 +449,161 @@ public void testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiEmbeddingsM } } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + 
assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap(), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + 
assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + var serviceAccountJson = """ + { + "some json" + } + """; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + getSecretSettingsMap(serviceAccountJson) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson)); + } + } + public void 
testParsePersistedConfigWithSecrets_CreatesGoogleVertexAiRerankModel() throws IOException { var projectId = "project"; var topN = 1; @@ -550,12 +833,142 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } + public void testParsePersistedConfig_CreatesAGoogleVertexAiEmbeddingsModelWithoutChunkingSettingsWhenFeatureFlagDisabled() + throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is disabled", ChunkingSettingsFeatureFlag.isEnabled() == false); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertNull(embeddingsModel.getConfigurations().getChunkingSettings()); + } + } + + public void testParsePersistedConfig_CreatesAGoogleVertexAiEmbeddingsModelWhenChunkingSettingsProvidedAndFeatureFlagEnabled() + throws IOException { + assumeTrue("Only if 
'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate), + createRandomChunkingSettingsMap() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + + public void testParsePersistedConfig_CreatesAnEmbeddingsModelWhenChunkingSettingsNotProvidedAndFeatureFlagEnabled() throws IOException { + assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled()); + var projectId = "project"; + var location = "location"; + var modelId = "model"; + var autoTruncate = true; + + try (var service = createGoogleVertexAiService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + 
GoogleVertexAiServiceFields.LOCATION, + location, + GoogleVertexAiServiceFields.PROJECT_ID, + projectId, + GoogleVertexAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, + true + ) + ), + getTaskSettingsMap(autoTruncate) + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleVertexAiEmbeddingsModel.class)); + + var embeddingsModel = (GoogleVertexAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getServiceSettings().location(), is(location)); + assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId)); + assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE)); + assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate))); + assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); + } + } + // testInfer tested via end-to-end notebook tests in AppEx repo private GoogleVertexAiService createGoogleVertexAiService() { return new GoogleVertexAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); } + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map chunkingSettings, + Map secretSettings + ) { + var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings); + requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings); + + return requestConfigMap; + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java index 
ca38bdb6e2c6c..68d03d350d06e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java @@ -79,8 +79,8 @@ public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullabl null ), new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate), + null, new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray())) ); } - }