Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi modal default preprocess function #2500

Merged
merged 15 commits into from
Jul 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.ml.common.connector.functions.preprocess.BedrockEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereEmbeddingPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand All @@ -22,6 +23,7 @@ public class MLPreProcessFunction {
public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding";
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.multimodal.embedding";
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding";
public static final String TEXT_SIMILARITY_TO_COHERE_RERANK_INPUT = "connector.pre_process.cohere.rerank";
public static final String TEXT_SIMILARITY_TO_DEFAULT_INPUT = "connector.pre_process.default.rerank";
Expand All @@ -34,7 +36,9 @@ public class MLPreProcessFunction {
OpenAIEmbeddingPreProcessFunction openAIEmbeddingPreProcessFunction = new OpenAIEmbeddingPreProcessFunction();
BedrockEmbeddingPreProcessFunction bedrockEmbeddingPreProcessFunction = new BedrockEmbeddingPreProcessFunction();
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAIEmbeddingPreProcessFunction);
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,35 @@
import org.opensearch.script.TemplateScript;

import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

import static org.opensearch.ml.common.utils.StringUtils.addDefaultMethod;

/**
* This abstract class represents a pre-processing function for a connector.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link MLInput}, and the pre-processing function can be customized by implementing the {@link #validate(MLInput)} and {@link #process(MLInput)} methods.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it can be returned directly by setting the {@link #returnDirectlyForRemoteInferenceInput} flag to true.
*/
@Log4j2
public abstract class ConnectorPreProcessFunction implements Function<MLInput, RemoteInferenceInputDataSet> {

/**
* This is a flag that can be used to determine if the pre-process function should return the input directly for RemoteInferenceInputDataSet.
* If this is true and the input is already of type RemoteInferenceInputDataSet, it will be returned directly, otherwise it will be processed.
*/
protected boolean returnDirectlyForRemoteInferenceInput;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add comment for this variable too? What does return directly mean here?


/**
* Applies the pre-processing function to the given MLInput object and returns the resulting RemoteInferenceInputDataSet.
*
* @param mlInput the MLInput object to be processed
* @return the RemoteInferenceInputDataSet resulting from the pre-processing function
* @throws IllegalArgumentException if the input MLInput object is null
*/
@Override
public RemoteInferenceInputDataSet apply(MLInput mlInput) {
if (mlInput == null) {
Expand All @@ -42,9 +61,17 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) {

public abstract RemoteInferenceInputDataSet process(MLInput mlInput);

/**
* Validates the input of a pre-process function for text documents.
*
* @param mlInput the input data to be validated
* @throws IllegalArgumentException if the input dataset is not an instance of TextDocsInputDataSet
* or if there is no input text or image provided
*/
public void validateTextDocsInput(MLInput mlInput) {
if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet");
log.error(String.format(Locale.ROOT, "This pre_process_function can only support TextDocsInputDataSet, actual input type is: %s", mlInput.getInputDataset().getClass().getName()));
throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet which including a list of string with key 'text_docs'");
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;

/**
* This class provides a pre-processing function for multi-modal input data.
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}.
* The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input.
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
*/
public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction {

public MultiModalConnectorPreProcessFunction() {
this.returnDirectlyForRemoteInferenceInput = true;
}

@Override
public void validate(MLInput mlInput) {
validateTextDocsInput(mlInput);
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
throw new IllegalArgumentException("No input text or image provided");
}
}

/**
* @param mlInput The input data to be processed.
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object.
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly.
* The inputText will always show up in the first document, even it's null.
*/
@Override
public RemoteInferenceInputDataSet process(MLInput mlInput) {
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
Map<String, String> parametersMap = new HashMap<>();
parametersMap.put("inputText", inputData.getDocs().get(0));
if (inputData.getDocs().size() > 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want size to be only 2, right? May be check that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine not to check this as we're using fixed index number to fetch the input docs.

parametersMap.put("inputImage", inputData.getDocs().get(1));
}
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector.functions.preprocess;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

public class MultiModalConnectorPreProcessFunctionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

MultiModalConnectorPreProcessFunction function;

TextSimilarityInputDataSet textSimilarityInputDataSet;
TextDocsInputDataSet textDocsInputDataSet;
RemoteInferenceInputDataSet remoteInferenceInputDataSet;

MLInput textEmbeddingInput;
MLInput textSimilarityInput;
MLInput remoteInferenceInput;

@Before
public void setUp() {
function = new MultiModalConnectorPreProcessFunction();
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build();
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build();

textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
}

@Test
public void testProcess_whenNullInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Preprocess function input can't be null");
function.apply(null);
}

@Test
public void testProcess_whenWrongInput_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
function.apply(textSimilarityInput);
}

@Test
public void testProcess_whenCorrectInput_expectCorrectOutput() {
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(2, dataSet.getParameters().size());
assertEquals("hello", dataSet.getParameters().get("inputText"));
assertEquals("world", dataSet.getParameters().get("inputImage"));
}

@Test
public void testProcess_whenInputTextOnly_expectInputTextShowUp() {
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
assertEquals(1, dataSet.getParameters().size());
assertEquals("hello", dataSet.getParameters().get("inputText"));
}

@Test
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("No input text or image provided");
List<String> docs = new ArrayList<>();
docs.add(null);
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build();
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
}

@Test
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() {
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
assertEquals(remoteInferenceInputDataSet, dataSet);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.client.RestClient;
import org.opensearch.client.RestClientBuilder;
import org.opensearch.common.io.PathUtils;
Expand Down Expand Up @@ -103,6 +104,9 @@
import com.google.gson.Gson;
import com.google.gson.JsonArray;

import lombok.extern.log4j.Log4j2;

@Log4j2
public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {
protected Gson gson = new Gson();
public static long CUSTOM_MODEL_TIMEOUT = 20_000; // 20 seconds
Expand Down Expand Up @@ -915,8 +919,14 @@ public Map predictTextEmbedding(String modelId) throws IOException {

public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException {
String requestBody = TestHelper.toJsonString(input);
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
Response response = null;
try {
response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
} catch (ResponseException e) {
log.error(e.getMessage(), e);
response = e.getResponse();
}
return parseResponseToMap(response);
}

Expand Down
Loading
Loading