Skip to content

Commit

Permalink
add text docs ML input (opensearch-project#830)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Mar 28, 2023
1 parent 466b136 commit a649ee6
Show file tree
Hide file tree
Showing 8 changed files with 376 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutputType;
Expand All @@ -30,6 +31,7 @@ public class MLCommonsClassLoader {
private static Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();

static {
try {
Expand All @@ -51,6 +53,7 @@ public static void loadClassMapping() {
loadMLInputDataSetClassMapping();
loadExecuteInputClassMapping();
loadExecuteOutputClassMapping();
loadMLInputClassMapping();
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
Expand Down Expand Up @@ -160,6 +163,22 @@ private static void loadExecuteOutputClassMapping() {
}
}

/**
 * Scans the {@code org.opensearch.ml.common.input} package for types carrying the
 * {@link MLInput} annotation and registers each of them in {@code mlInputClassMap},
 * once per function name declared on the annotation.
 */
private static void loadMLInputClassMapping() {
    Reflections scanner = new Reflections("org.opensearch.ml.common.input");
    Set<Class<?>> annotatedTypes = scanner.getTypesAnnotatedWith(MLInput.class);
    for (Class<?> candidate : annotatedTypes) {
        MLInput annotation = candidate.getAnnotation(MLInput.class);
        // The scan may return types on which the annotation is not directly present.
        if (annotation == null) {
            continue;
        }
        FunctionName[] supported = annotation.functionNames();
        if (supported == null || supported.length == 0) {
            continue;
        }
        for (FunctionName functionName : supported) {
            mlInputClassMap.put(functionName, candidate);
        }
    }
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initMLInstance(T type, I in, Class<?> constructorParamClass) {
return init(parameterClassMap, type, in, constructorParamClass);
Expand Down Expand Up @@ -195,4 +214,33 @@ private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Clas
}
}

/**
 * Tells whether a dedicated ML input class has been registered for the given function.
 *
 * @param functionName function name to look up
 * @return {@code true} if {@code mlInputClassMap} has an entry for {@code functionName}
 */
public static boolean canInitMLInput(FunctionName functionName) {
    boolean registered = mlInputClassMap.containsKey(functionName);
    return registered;
}

/**
 * Creates the ML input instance registered for {@code type} by invoking the public
 * constructor that matches {@code constructorParameterTypes}.
 *
 * @param type                      enum key used to look up the registered class
 * @param initArgs                  arguments handed to the constructor
 * @param constructorParameterTypes parameter types identifying the constructor
 * @return whatever the shared {@code init} helper returns for the ML input class map
 */
@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S> S initMLInput(T type, Object[] initArgs,
        Class<?>... constructorParameterTypes) {
    S created = init(mlInputClassMap, type, initArgs, constructorParameterTypes);
    return created;
}

/**
 * Looks up the class registered for {@code type} in {@code map} and instantiates it
 * through the public constructor matching {@code constructorParameterTypes}.
 *
 * @param map                       registry from enum key to implementation class
 * @param type                      key to look up
 * @param initArgs                  arguments passed to the constructor
 * @param constructorParameterTypes parameter types identifying the constructor
 * @return the new instance, or {@code null} when instantiation failed for a reason
 *         other than an {@link MLException} or {@link IllegalArgumentException}
 *         cause (the failure is logged in that case)
 * @throws IllegalArgumentException if no class is registered for {@code type}, or if
 *                                  the invoked constructor rejected its arguments
 * @throws MLException              if the invoked constructor threw one
 */
@SuppressWarnings("unchecked")
private static <T extends Enum<T>, S> S init(Map<Enum<?>, Class<?>> map, T type,
        Object[] initArgs, Class<?>... constructorParameterTypes) {
    Class<?> clazz = map.get(type);
    if (clazz == null) {
        throw new IllegalArgumentException("Can't find class for type " + type);
    }
    try {
        Constructor<?> constructor = clazz.getConstructor(constructorParameterTypes);
        return (S) constructor.newInstance(initArgs);
    } catch (Exception e) {
        // Exceptions thrown inside the constructor arrive wrapped (InvocationTargetException),
        // so the real failure is the cause.
        Throwable cause = e.getCause();
        if (cause instanceof MLException) {
            throw (MLException) cause;
        } else if (cause instanceof IllegalArgumentException) {
            // Surface input-validation failures (e.g. an empty document list) to the caller
            // instead of turning them into a null that is dereferenced later.
            throw (IllegalArgumentException) cause;
        } else {
            log.error("Failed to init instance for type " + type, e);
            return null;
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.annotation;

import org.opensearch.ml.common.FunctionName;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Type-level annotation, retained at runtime, that declares which functions
 * the annotated class supports.
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface MLInput {
    /**
     * Supported algorithms.
     *
     * @return the function names the annotated type is declared for
     */
    FunctionName[] functionNames();
}
17 changes: 13 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
Expand All @@ -32,9 +33,10 @@
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* ML input data: algirithm name, parameters and input data set.
* ML input data: algorithm name, parameters and input data set.
*/
@Data
@NoArgsConstructor
public class MLInput implements Input {

public static final String ALGORITHM_FIELD = "algorithm";
Expand All @@ -56,11 +58,11 @@ public class MLInput implements Input {
public static final String TEXT_DOCS_FIELD = "text_docs";

// Algorithm name
private FunctionName algorithm;
protected FunctionName algorithm;
// ML algorithm parameters
private MLAlgoParams parameters;
protected MLAlgoParams parameters;
// Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly.
private MLInputDataset inputDataset;
protected MLInputDataset inputDataset;

private int version = 1;

Expand Down Expand Up @@ -169,6 +171,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
FunctionName algorithm = FunctionName.from(algorithmName);

if (MLCommonsClassLoader.canInitMLInput(algorithm)) {
MLInput mlInput = MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class);
mlInput.setAlgorithm(algorithm);
return mlInput;
}

MLAlgoParams mlParameters = null;
SearchSourceBuilder searchSourceBuilder = null;
List<String> sourceIndices = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package org.opensearch.ml.common.input.nlp;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelResultFilter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
 * ML input class which supports a list of text docs.
 * This class can be used for TEXT_EMBEDDING model.
 */
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
public class TextDocsMLInput extends MLInput {
    public static final String TEXT_DOCS_FIELD = "text_docs";
    public static final String RESULT_FILTER_FIELD = "result_filter";

    /**
     * Builds an input from an algorithm and a dataset; {@code null} is passed as the
     * second argument of the parent constructor.
     */
    public TextDocsMLInput(FunctionName algorithm, MLInputDataset inputDataset) {
        super(algorithm, null, inputDataset);
    }

    /** Reads the input from a stream by delegating entirely to the parent class. */
    public TextDocsMLInput(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Parses the input from XContent. The parser must be positioned on the START_OBJECT
     * token of the request body.
     *
     * <p>Result-filter settings may be given either as top-level fields or inside a nested
     * {@code result_filter} object; when the nested object is present it is the one used.
     *
     * @param parser       parser positioned at the start of the object
     * @param functionName algorithm to assign to this input
     * @throws IOException              on parsing errors
     * @throws IllegalArgumentException if no text docs were supplied
     */
    public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws IOException {
        super();
        this.algorithm = functionName;

        List<String> textDocs = new ArrayList<>();
        List<String> responseNames = new ArrayList<>();
        List<Integer> responsePositions = new ArrayList<>();
        boolean asBytes = false;
        boolean asNumber = true;
        ModelResultFilter nestedFilter = null;

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String field = parser.currentName();
            // Advance from the field name to its value.
            parser.nextToken();

            switch (field) {
                case RETURN_BYTES_FIELD:
                    asBytes = parser.booleanValue();
                    break;
                case RETURN_NUMBER_FIELD:
                    asNumber = parser.booleanValue();
                    break;
                case TARGET_RESPONSE_FIELD:
                    appendTextValues(parser, responseNames);
                    break;
                case TARGET_RESPONSE_POSITIONS_FIELD:
                    appendIntValues(parser, responsePositions);
                    break;
                case TEXT_DOCS_FIELD:
                    appendTextValues(parser, textDocs);
                    break;
                case RESULT_FILTER_FIELD:
                    nestedFilter = ModelResultFilter.parse(parser);
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }

        ModelResultFilter effectiveFilter;
        if (nestedFilter != null) {
            effectiveFilter = nestedFilter;
        } else {
            effectiveFilter = ModelResultFilter.builder()
                    .returnBytes(asBytes)
                    .returnNumber(asNumber)
                    .targetResponse(responseNames)
                    .targetResponsePositions(responsePositions)
                    .build();
        }

        if (textDocs.isEmpty()) {
            throw new IllegalArgumentException("Empty text docs");
        }
        inputDataset = new TextDocsInputDataSet(textDocs, effectiveFilter);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
    }

    /**
     * Writes the algorithm, the parameters (when set), the text docs (when non-empty) and
     * the result filter (when set) into a single object.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, algorithm.name());
        if (parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, parameters);
        }
        if (inputDataset != null) {
            TextDocsInputDataSet textDataSet = (TextDocsInputDataSet) this.inputDataset;
            List<String> textDocs = textDataSet.getDocs();
            ModelResultFilter filter = textDataSet.getResultFilter();
            if (textDocs != null && !textDocs.isEmpty()) {
                builder.field(TEXT_DOCS_FIELD, textDocs.toArray(new String[0]));
            }
            if (filter != null) {
                writeResultFilter(builder, filter);
            }
        }
        builder.endObject();
        return builder;
    }

    /** Emits the {@code result_filter} object for the given filter. */
    private static void writeResultFilter(XContentBuilder builder, ModelResultFilter filter) throws IOException {
        builder.startObject(RESULT_FILTER_FIELD);
        builder.field(RETURN_BYTES_FIELD, filter.isReturnBytes());
        builder.field(RETURN_NUMBER_FIELD, filter.isReturnNumber());
        List<String> responseNames = filter.getTargetResponse();
        if (responseNames != null && !responseNames.isEmpty()) {
            builder.field(TARGET_RESPONSE_FIELD, responseNames.toArray(new String[0]));
        }
        List<Integer> responsePositions = filter.getTargetResponsePositions();
        if (responsePositions != null && !responsePositions.isEmpty()) {
            builder.field(TARGET_RESPONSE_POSITIONS_FIELD, responsePositions.toArray(new Integer[0]));
        }
        builder.endObject();
    }

    /** Appends every element of the array the parser is positioned on, read as text, to {@code sink}. */
    private static void appendTextValues(XContentParser parser, List<String> sink) throws IOException {
        ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
            sink.add(parser.text());
        }
    }

    /** Appends every element of the array the parser is positioned on, read as int, to {@code sink}. */
    private static void appendIntValues(XContentParser parser, List<Integer> sink) throws IOException {
        ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
            sink.add(parser.intValue());
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,29 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* This class is to filter model results.
*/
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class ModelResultFilter implements Writeable {

public static final String RETURN_BYTES_FIELD = "return_bytes";
// Return number in model output. This can be used together with return_bytes.
public static final String RETURN_NUMBER_FIELD = "return_number";
// Filter target response with name in model output
public static final String TARGET_RESPONSE_FIELD = "target_response";
// Filter target response with position in model output
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";

// Return model output as bytes. This could be useful if client side prefer
// to parse the model output in its own way.
protected boolean returnBytes;
Expand Down Expand Up @@ -77,4 +88,41 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeBoolean(false);
}
}

/**
 * Parses a result filter from XContent. The parser must be positioned on the
 * START_OBJECT token of the filter object; unrecognized fields are skipped.
 *
 * <p>When a field is absent, {@code return_bytes} is {@code false},
 * {@code return_number} is {@code true} and both target lists are empty.
 *
 * @param parser parser positioned at the start of the filter object
 * @return the parsed filter
 * @throws IOException on parsing errors
 */
public static ModelResultFilter parse(XContentParser parser) throws IOException {
    List<String> responseNames = new ArrayList<>();
    List<Integer> responsePositions = new ArrayList<>();
    boolean asBytes = false;
    boolean asNumber = true;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    for (XContentParser.Token token = parser.nextToken();
            token != XContentParser.Token.END_OBJECT;
            token = parser.nextToken()) {
        String field = parser.currentName();
        // Advance from the field name to its value.
        parser.nextToken();

        switch (field) {
            case RETURN_BYTES_FIELD:
                asBytes = parser.booleanValue();
                break;
            case RETURN_NUMBER_FIELD:
                asNumber = parser.booleanValue();
                break;
            case TARGET_RESPONSE_FIELD:
                ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                for (XContentParser.Token item = parser.nextToken();
                        item != XContentParser.Token.END_ARRAY;
                        item = parser.nextToken()) {
                    responseNames.add(parser.text());
                }
                break;
            case TARGET_RESPONSE_POSITIONS_FIELD:
                ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                for (XContentParser.Token item = parser.nextToken();
                        item != XContentParser.Token.END_ARRAY;
                        item = parser.nextToken()) {
                    responsePositions.add(parser.intValue());
                }
                break;
            default:
                parser.skipChildren();
                break;
        }
    }
    return new ModelResultFilter(asBytes, asNumber, responseNames, responsePositions);
}
}
Loading

0 comments on commit a649ee6

Please sign in to comment.