-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add multi modal default preprocess function (#2500)
* Add multi modal default preprocess function Signed-off-by: zane-neo <zaniu@amazon.com> * Address comments Signed-off-by: zane-neo <zaniu@amazon.com> * address comments Signed-off-by: zane-neo <zaniu@amazon.com> * add IT Signed-off-by: zane-neo <zaniu@amazon.com> * Fix IT Signed-off-by: zane-neo <zaniu@amazon.com> * Update common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java Co-authored-by: Yaliang Wu <ylwu@amazon.com> Signed-off-by: zane-neo <zaniu@amazon.com> * fix test Signed-off-by: Yaliang Wu <ylwu@amazon.com> * Add more ITs Signed-off-by: zane-neo <zaniu@amazon.com> * Fix failure ITs Signed-off-by: zane-neo <zaniu@amazon.com> * fix failure IT Signed-off-by: zane-neo <zaniu@amazon.com> * Fix failure ITs Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * Add error response to make it esay to figure out the failure root cause Signed-off-by: zane-neo <zaniu@amazon.com> * format code Signed-off-by: zane-neo <zaniu@amazon.com> * rebase main Signed-off-by: zane-neo <zaniu@amazon.com> --------- Signed-off-by: zane-neo <zaniu@amazon.com> Signed-off-by: Yaliang Wu <ylwu@amazon.com> Co-authored-by: Yaliang Wu <ylwu@amazon.com> (cherry picked from commit 0e89c17)
- Loading branch information
1 parent
c3b0d8c
commit 1c795fc
Showing
7 changed files
with
427 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
...earch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* | ||
* * Copyright OpenSearch Contributors | ||
* * SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString; | ||
|
||
/** | ||
* This class provides a pre-processing function for multi-modal input data. | ||
* It takes an instance of {@link MLInput} as input and returns an instance of {@link RemoteInferenceInputDataSet}. | ||
* The input data is expected to be of type {@link TextDocsInputDataSet}, with the first document representing text input and the second document representing an image input. | ||
* The function validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. | ||
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. | ||
*/ | ||
public class MultiModalConnectorPreProcessFunction extends ConnectorPreProcessFunction { | ||
|
||
public MultiModalConnectorPreProcessFunction() { | ||
this.returnDirectlyForRemoteInferenceInput = true; | ||
} | ||
|
||
@Override | ||
public void validate(MLInput mlInput) { | ||
validateTextDocsInput(mlInput); | ||
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); | ||
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) { | ||
throw new IllegalArgumentException("No input text or image provided"); | ||
} | ||
} | ||
|
||
/** | ||
* @param mlInput The input data to be processed. | ||
* This method validates the input data and then processes it to create a {@link RemoteInferenceInputDataSet} object. | ||
* If the input data is already of type {@link RemoteInferenceInputDataSet}, it is returned directly. | ||
* The inputText will always show up in the first document, even it's null. | ||
*/ | ||
@Override | ||
public RemoteInferenceInputDataSet process(MLInput mlInput) { | ||
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); | ||
Map<String, String> parametersMap = new HashMap<>(); | ||
parametersMap.put("inputText", inputData.getDocs().get(0)); | ||
if (inputData.getDocs().size() > 1) { | ||
parametersMap.put("inputImage", inputData.getDocs().get(1)); | ||
} | ||
return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap))).build(); | ||
|
||
} | ||
} |
99 changes: 99 additions & 0 deletions
99
...h/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunctionTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.connector.functions.preprocess; | ||
|
||
import org.junit.Before; | ||
import org.junit.Rule; | ||
import org.junit.Test; | ||
import org.junit.rules.ExpectedException; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
|
||
public class MultiModalConnectorPreProcessFunctionTest { | ||
@Rule | ||
public ExpectedException exceptionRule = ExpectedException.none(); | ||
|
||
MultiModalConnectorPreProcessFunction function; | ||
|
||
TextSimilarityInputDataSet textSimilarityInputDataSet; | ||
TextDocsInputDataSet textDocsInputDataSet; | ||
RemoteInferenceInputDataSet remoteInferenceInputDataSet; | ||
|
||
MLInput textEmbeddingInput; | ||
MLInput textSimilarityInput; | ||
MLInput remoteInferenceInput; | ||
|
||
@Before | ||
public void setUp() { | ||
function = new MultiModalConnectorPreProcessFunction(); | ||
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build(); | ||
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello", "world")).build(); | ||
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("inputText", "value1", "inputImage", "value2")).build(); | ||
|
||
textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build(); | ||
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build(); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenNullInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("Preprocess function input can't be null"); | ||
function.apply(null); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenWrongInput_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet"); | ||
function.apply(textSimilarityInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenCorrectInput_expectCorrectOutput() { | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(2, dataSet.getParameters().size()); | ||
assertEquals("hello", dataSet.getParameters().get("inputText")); | ||
assertEquals("world", dataSet.getParameters().get("inputImage")); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenInputTextOnly_expectInputTextShowUp() { | ||
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
assertEquals(1, dataSet.getParameters().size()); | ||
assertEquals("hello", dataSet.getParameters().get("inputText")); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenInputTextIsnull_expectIllegalArgumentException() { | ||
exceptionRule.expect(IllegalArgumentException.class); | ||
exceptionRule.expectMessage("No input text or image provided"); | ||
List<String> docs = new ArrayList<>(); | ||
docs.add(null); | ||
TextDocsInputDataSet textDocsInputDataSet1 = TextDocsInputDataSet.builder().docs(docs).build(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet1).build(); | ||
RemoteInferenceInputDataSet dataSet = function.apply(mlInput); | ||
} | ||
|
||
@Test | ||
public void testProcess_whenRemoteInferenceInput_expectRemoteInferenceInputDataSet() { | ||
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput); | ||
assertEquals(remoteInferenceInputDataSet, dataSet); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.