forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add text docs ML input (opensearch-project#830)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
- Loading branch information
Showing
8 changed files
with
376 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.common.annotation; | ||
|
||
import org.opensearch.ml.common.FunctionName; | ||
|
||
import java.lang.annotation.ElementType; | ||
import java.lang.annotation.Retention; | ||
import java.lang.annotation.RetentionPolicy; | ||
import java.lang.annotation.Target; | ||
|
||
@Retention(RetentionPolicy.RUNTIME) | ||
@Target(ElementType.TYPE) | ||
public @interface MLInput { | ||
// supported algorithms | ||
FunctionName[] functionNames(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package org.opensearch.ml.common.input.nlp; | ||
|
||
import org.opensearch.common.io.stream.StreamInput; | ||
import org.opensearch.common.io.stream.StreamOutput; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.common.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
import org.opensearch.ml.common.dataset.TextDocsInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.output.model.ModelResultFilter; | ||
|
||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
/** | ||
* ML input class which supports a list fo text docs. | ||
* This class can be used for TEXT_EMBEDDING model. | ||
*/ | ||
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING}) | ||
public class TextDocsMLInput extends MLInput { | ||
public static final String TEXT_DOCS_FIELD = "text_docs"; | ||
public static final String RESULT_FILTER_FIELD = "result_filter"; | ||
|
||
public TextDocsMLInput(FunctionName algorithm, MLInputDataset inputDataset) { | ||
super(algorithm, null, inputDataset); | ||
} | ||
|
||
public TextDocsMLInput(StreamInput in) throws IOException { | ||
super(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(ALGORITHM_FIELD, algorithm.name()); | ||
if (parameters != null) { | ||
builder.field(ML_PARAMETERS_FIELD, parameters); | ||
} | ||
if (inputDataset != null) { | ||
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset; | ||
List<String> docs = textInputDataSet.getDocs(); | ||
ModelResultFilter resultFilter = textInputDataSet.getResultFilter(); | ||
if (docs != null && docs.size() > 0) { | ||
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0])); | ||
} | ||
if (resultFilter != null) { | ||
builder.startObject(RESULT_FILTER_FIELD); | ||
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes()); | ||
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber()); | ||
List<String> targetResponse = resultFilter.getTargetResponse(); | ||
if (targetResponse != null && targetResponse.size() > 0) { | ||
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0])); | ||
} | ||
List<Integer> targetPositions = resultFilter.getTargetResponsePositions(); | ||
if (targetPositions != null && targetPositions.size() > 0) { | ||
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0])); | ||
} | ||
builder.endObject(); | ||
} | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws IOException { | ||
super(); | ||
this.algorithm = functionName; | ||
List<String> docs = new ArrayList<>(); | ||
ModelResultFilter resultFilter = null; | ||
|
||
boolean returnBytes = false; | ||
boolean returnNumber = true; | ||
List<String> targetResponse = new ArrayList<>(); | ||
List<Integer> targetResponsePositions = new ArrayList<>(); | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case RETURN_BYTES_FIELD: | ||
returnBytes = parser.booleanValue(); | ||
break; | ||
case RETURN_NUMBER_FIELD: | ||
returnNumber = parser.booleanValue(); | ||
break; | ||
case TARGET_RESPONSE_FIELD: | ||
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_ARRAY) { | ||
targetResponse.add(parser.text()); | ||
} | ||
break; | ||
case TARGET_RESPONSE_POSITIONS_FIELD: | ||
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_ARRAY) { | ||
targetResponsePositions.add(parser.intValue()); | ||
} | ||
break; | ||
case TEXT_DOCS_FIELD: | ||
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_ARRAY) { | ||
docs.add(parser.text()); | ||
} | ||
break; | ||
case RESULT_FILTER_FIELD: | ||
resultFilter = ModelResultFilter.parse(parser); | ||
break; | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
ModelResultFilter filter = resultFilter != null ? resultFilter : ModelResultFilter.builder().returnBytes(returnBytes) | ||
.returnNumber(returnNumber).targetResponse(targetResponse).targetResponsePositions(targetResponsePositions) | ||
.build(); | ||
|
||
if (docs.size() == 0) { | ||
throw new IllegalArgumentException("Empty text docs"); | ||
} | ||
inputDataset = new TextDocsInputDataSet(docs, filter); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.