Skip to content

Commit

Permalink
[8.x] Adding chunking settings to GoogleVertexAiService, AzureAiStudi…
Browse files Browse the repository at this point in the history
…oService, and AlibabaCloudSearchService (#113981) (#114449)

* Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService (#113981)

* Adding chunking settings to GoogleVertexAiService, AzureAiStudioService, and AlibabaCloudSearchService

* Update docs/changelog/113981.yaml

* Updating AlibabaService chunkedInfer to handle sparse embedding task types

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>

* Fix enum switch case error in AlibabaSearchService (#114504)

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
dan-rubinstein and elasticmachine authored Oct 10, 2024
1 parent bb6470c commit 07d57e2
Show file tree
Hide file tree
Showing 16 changed files with 1,350 additions and 87 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/113981.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 113981
summary: "Adding chunking settings to `GoogleVertexAiService`, `AzureAiStudioService`,\
\ and `AlibabaCloudSearchService`"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand All @@ -24,6 +25,8 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
Expand Down Expand Up @@ -74,11 +77,19 @@ public void parseRequestConfig(
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
);
}

AlibabaCloudSearchModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
Expand All @@ -99,6 +110,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage
) {
Expand All @@ -107,6 +119,7 @@ private static AlibabaCloudSearchModel createModelWithoutLoggingDeprecations(
taskType,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
Expand All @@ -118,6 +131,7 @@ private static AlibabaCloudSearchModel createModel(
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
Expand All @@ -129,6 +143,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand All @@ -138,6 +153,7 @@ private static AlibabaCloudSearchModel createModel(
NAME,
serviceSettings,
taskSettings,
chunkingSettings,
secretSettings,
context
);
Expand Down Expand Up @@ -174,11 +190,17 @@ public AlibabaCloudSearchModel parsePersistedConfigWithSecrets(
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand All @@ -189,11 +211,17 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

ChunkingSettings chunkingSettings = null;
if (ChunkingSettingsFeatureFlag.isEnabled() && List.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING).contains(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
}

return createModelWithoutLoggingDeprecations(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
Expand Down Expand Up @@ -238,17 +266,36 @@ protected void doChunkedInfer(
AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model;
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());

var batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
EmbeddingRequestChunker.EmbeddingType.FLOAT
).batchRequestsWithListeners(listener);
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
if (ChunkingSettingsFeatureFlag.isEnabled()) {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType()),
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
} else {
batchedRequests = new EmbeddingRequestChunker(
inputs.getInputs(),
EMBEDDING_MAX_BATCH_SIZE,
getEmbeddingTypeFromTaskType(alibabaCloudSearchModel.getTaskType())
).batchRequestsWithListeners(listener);
}

for (var request : batchedRequests) {
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
}
}

/**
 * Translates an inference task type into the embedding type handed to the
 * {@link EmbeddingRequestChunker}.
 *
 * @param taskType the task type of the model being used for chunked inference
 * @return {@code FLOAT} for {@code TEXT_EMBEDDING}, {@code SPARSE} for {@code SPARSE_EMBEDDING}
 * @throws IllegalArgumentException if the task type is neither of the two embedding task types
 */
private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskType taskType) {
    switch (taskType) {
        case TEXT_EMBEDDING:
            return EmbeddingRequestChunker.EmbeddingType.FLOAT;
        case SPARSE_EMBEDDING:
            return EmbeddingRequestChunker.EmbeddingType.SPARSE;
        default:
            // Any other task type has no embedding representation to chunk.
            throw new IllegalArgumentException("Unsupported task type for chunking: " + taskType);
    }
}

/**
* For text embedding models get the embedding size and
* update the service settings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -39,6 +40,7 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -48,6 +50,7 @@ public AlibabaCloudSearchEmbeddingsModel(
service,
AlibabaCloudSearchEmbeddingsServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -59,10 +62,11 @@ public AlibabaCloudSearchEmbeddingsModel(
String service,
AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings,
AlibabaCloudSearchEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
Expand Down Expand Up @@ -81,10 +82,21 @@ public SimilarityMeasure getSimilarity() {
return similarity;
}

public Integer getDimensions() {
/**
 * Returns the value held in this settings object's {@code dimensions} field.
 * NOTE(review): presumably the embedding vector size, and possibly {@code null}
 * when not configured — confirm against the field declaration.
 */
@Override
public Integer dimensions() {
    return this.dimensions;
}

/**
 * Returns the {@link SimilarityMeasure} stored in this settings object's
 * {@code similarity} field.
 */
@Override
public SimilarityMeasure similarity() {
    return this.similarity;
}

/**
 * Reports the vector element type for these settings.
 *
 * @return always {@link DenseVectorFieldMapper.ElementType#FLOAT}; the value is fixed and
 *         does not depend on any configured setting
 */
@Override
public DenseVectorFieldMapper.ElementType elementType() {
return DenseVectorFieldMapper.ElementType.FLOAT;
}

/**
 * Returns the value held in this settings object's {@code maxInputTokens} field.
 * NOTE(review): boxed return type suggests the value may be {@code null} when unset — confirm.
 */
public Integer getMaxInputTokens() {
    return this.maxInputTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
Expand Down Expand Up @@ -39,6 +40,7 @@ public AlibabaCloudSearchSparseModel(
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secrets,
ConfigurationParseContext context
) {
Expand All @@ -48,6 +50,7 @@ public AlibabaCloudSearchSparseModel(
service,
AlibabaCloudSearchSparseServiceSettings.fromMap(serviceSettings, context),
AlibabaCloudSearchSparseTaskSettings.fromMap(taskSettings),
chunkingSettings,
DefaultSecretSettings.fromMap(secrets)
);
}
Expand All @@ -59,10 +62,11 @@ public AlibabaCloudSearchSparseModel(
String service,
AlibabaCloudSearchSparseServiceSettings serviceSettings,
AlibabaCloudSearchSparseTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
@Nullable DefaultSecretSettings secretSettings
) {
super(
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secretSettings),
serviceSettings.getCommonSettings()
);
Expand Down
Loading

0 comments on commit 07d57e2

Please sign in to comment.