Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport changes to model-access-control feature branch #837

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.annotation.MLAlgoOutput;
import org.opensearch.ml.common.annotation.MLAlgoParameter;
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutputType;
Expand All @@ -30,6 +31,7 @@ public class MLCommonsClassLoader {
private static Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeInputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> executeOutputClassMap = new HashMap<>();
private static Map<Enum<?>, Class<?>> mlInputClassMap = new HashMap<>();

static {
try {
Expand All @@ -51,6 +53,7 @@ public static void loadClassMapping() {
loadMLInputDataSetClassMapping();
loadExecuteInputClassMapping();
loadExecuteOutputClassMapping();
loadMLInputClassMapping();
} finally {
Thread.currentThread().setContextClassLoader(originalClassLoader);
}
Expand Down Expand Up @@ -160,6 +163,22 @@ private static void loadExecuteOutputClassMapping() {
}
}

/**
 * Scans the {@code org.opensearch.ml.common.input} package for types annotated with
 * {@link MLInput} and records each of them in {@code mlInputClassMap}, once per
 * function name listed by the annotation.
 */
private static void loadMLInputClassMapping() {
    Set<Class<?>> annotatedTypes =
            new Reflections("org.opensearch.ml.common.input").getTypesAnnotatedWith(MLInput.class);
    for (Class<?> candidate : annotatedTypes) {
        MLInput annotation = candidate.getAnnotation(MLInput.class);
        if (annotation == null) {
            // Nothing to register when the annotation is not present on the type itself.
            continue;
        }
        FunctionName[] supported = annotation.functionNames();
        if (supported == null) {
            continue;
        }
        // An empty array simply registers nothing.
        for (FunctionName functionName : supported) {
            mlInputClassMap.put(functionName, candidate);
        }
    }
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initMLInstance(T type, I in, Class<?> constructorParamClass) {
return init(parameterClassMap, type, in, constructorParamClass);
Expand Down Expand Up @@ -195,4 +214,33 @@ private static <T extends Enum<T>, S, I extends Object> S init(Map<Enum<?>, Clas
}
}

/**
 * Tells whether an ML input class has been registered for the given function name.
 *
 * @param functionName function name to look up
 * @return {@code true} if {@code mlInputClassMap} contains the function name as a key
 */
public static boolean canInitMLInput(FunctionName functionName) {
    final boolean registered = mlInputClassMap.containsKey(functionName);
    return registered;
}

/**
 * Creates an ML input instance for the given type by delegating to {@code init}
 * with the registered ML input classes.
 *
 * @param type key used to look up the class in {@code mlInputClassMap}
 * @param initArgs arguments handed to the constructor
 * @param constructorParameterTypes parameter types identifying the constructor to call
 * @return whatever {@code init} returns for these arguments
 */
@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S> S initMLInput(
        T type,
        Object[] initArgs,
        Class<?>... constructorParameterTypes) {
    S instance = init(mlInputClassMap, type, initArgs, constructorParameterTypes);
    return instance;
}

/**
 * Looks up the class registered for {@code type} in {@code map} and instantiates it
 * through the public constructor matching {@code constructorParameterTypes}.
 *
 * @param map registry from type to implementation class
 * @param type key to look up
 * @param initArgs arguments passed to the constructor
 * @param constructorParameterTypes parameter types identifying the constructor
 * @return the new instance, or {@code null} when instantiation fails for a reason other
 *         than an {@code MLException} or {@code IllegalArgumentException} cause
 *         (the failure is logged in that case)
 * @throws IllegalArgumentException if no class is registered for {@code type}, or if the
 *         failure's cause is an {@code IllegalArgumentException} (e.g. thrown by the
 *         invoked constructor)
 * @throws MLException if the failure's cause is an {@code MLException}
 */
@SuppressWarnings("unchecked")
private static <T extends Enum<T>, S> S init(Map<Enum<?>, Class<?>> map, T type,
        Object[] initArgs, Class<?>... constructorParameterTypes) {
    Class<?> clazz = map.get(type);
    if (clazz == null) {
        throw new IllegalArgumentException("Can't find class for type " + type);
    }
    try {
        Constructor<?> constructor = clazz.getConstructor(constructorParameterTypes);
        return (S) constructor.newInstance(initArgs);
    } catch (Exception e) {
        // Exceptions thrown inside the constructor arrive wrapped; inspect the cause.
        Throwable cause = e.getCause();
        if (cause instanceof MLException) {
            throw (MLException) cause;
        }
        if (cause instanceof IllegalArgumentException) {
            // Surface input validation failures instead of returning null, which the
            // caller would otherwise dereference.
            throw (IllegalArgumentException) cause;
        }
        log.error("Failed to init instance for type " + type, e);
        return null;
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.annotation;

import org.opensearch.ml.common.FunctionName;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Type-level annotation, retained at runtime, that declares which function names
 * the annotated class supports.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MLInput {
// supported algorithms
/**
 * @return the function names (algorithms) supported by the annotated class
 */
FunctionName[] functionNames();
}
17 changes: 13 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
Expand All @@ -32,9 +33,10 @@
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* ML input data: algirithm name, parameters and input data set.
* ML input data: algorithm name, parameters and input data set.
*/
@Data
@NoArgsConstructor
public class MLInput implements Input {

public static final String ALGORITHM_FIELD = "algorithm";
Expand All @@ -56,11 +58,11 @@ public class MLInput implements Input {
public static final String TEXT_DOCS_FIELD = "text_docs";

// Algorithm name
private FunctionName algorithm;
protected FunctionName algorithm;
// ML algorithm parameters
private MLAlgoParams parameters;
protected MLAlgoParams parameters;
// Input data to train model, run trained model to predict or run ML algorithms(no-model-based) directly.
private MLInputDataset inputDataset;
protected MLInputDataset inputDataset;

private int version = 1;

Expand Down Expand Up @@ -169,6 +171,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
FunctionName algorithm = FunctionName.from(algorithmName);

if (MLCommonsClassLoader.canInitMLInput(algorithm)) {
MLInput mlInput = MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class);
mlInput.setAlgorithm(algorithm);
return mlInput;
}

MLAlgoParams mlParameters = null;
SearchSourceBuilder searchSourceBuilder = null;
List<String> sourceIndices = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package org.opensearch.ml.common.input.nlp;

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelResultFilter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
 * ML input class which supports a list of text docs.
 * This class can be used for TEXT_EMBEDDING model.
 */
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
public class TextDocsMLInput extends MLInput {
    public static final String TEXT_DOCS_FIELD = "text_docs";
    public static final String RESULT_FILTER_FIELD = "result_filter";

    /**
     * Creates an input for the given algorithm and dataset; the parameters slot of the
     * superclass constructor is passed {@code null}.
     */
    public TextDocsMLInput(FunctionName algorithm, MLInputDataset inputDataset) {
        super(algorithm, null, inputDataset);
    }

    /** Reads the input from a stream by delegating to the superclass constructor. */
    public TextDocsMLInput(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
    }

    /**
     * Writes the algorithm name, the parameters (when set) and, when a dataset is set,
     * its non-empty text docs and its result filter.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, algorithm.name());
        if (parameters != null) {
            builder.field(ML_PARAMETERS_FIELD, parameters);
        }
        if (inputDataset != null) {
            TextDocsInputDataSet dataSet = (TextDocsInputDataSet) this.inputDataset;
            List<String> textDocs = dataSet.getDocs();
            ModelResultFilter filter = dataSet.getResultFilter();
            if (textDocs != null && !textDocs.isEmpty()) {
                builder.field(TEXT_DOCS_FIELD, textDocs.toArray(new String[0]));
            }
            if (filter != null) {
                writeResultFilter(builder, filter);
            }
        }
        builder.endObject();
        return builder;
    }

    /**
     * Parses the input from XContent. Filter settings may be given either as top-level
     * fields or inside a {@code result_filter} object; a parsed {@code result_filter}
     * object takes precedence over the top-level fields.
     *
     * @throws IllegalArgumentException if no text docs were parsed
     */
    public TextDocsMLInput(XContentParser parser, FunctionName functionName) throws IOException {
        super();
        this.algorithm = functionName;
        List<String> textDocs = new ArrayList<>();
        ModelResultFilter parsedFilter = null;

        // Defaults used when the fields are absent.
        boolean bytesRequested = false;
        boolean numberRequested = true;
        List<String> responseNames = new ArrayList<>();
        List<Integer> responsePositions = new ArrayList<>();

        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String field = parser.currentName();
            parser.nextToken();

            switch (field) {
                case RETURN_BYTES_FIELD:
                    bytesRequested = parser.booleanValue();
                    break;
                case RETURN_NUMBER_FIELD:
                    numberRequested = parser.booleanValue();
                    break;
                case TARGET_RESPONSE_FIELD:
                    readTextArray(parser, responseNames);
                    break;
                case TARGET_RESPONSE_POSITIONS_FIELD:
                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    for (XContentParser.Token token = parser.nextToken();
                            token != XContentParser.Token.END_ARRAY;
                            token = parser.nextToken()) {
                        responsePositions.add(parser.intValue());
                    }
                    break;
                case TEXT_DOCS_FIELD:
                    readTextArray(parser, textDocs);
                    break;
                case RESULT_FILTER_FIELD:
                    parsedFilter = ModelResultFilter.parse(parser);
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }

        ModelResultFilter effectiveFilter;
        if (parsedFilter != null) {
            effectiveFilter = parsedFilter;
        } else {
            effectiveFilter = ModelResultFilter.builder()
                    .returnBytes(bytesRequested)
                    .returnNumber(numberRequested)
                    .targetResponse(responseNames)
                    .targetResponsePositions(responsePositions)
                    .build();
        }

        if (textDocs.isEmpty()) {
            throw new IllegalArgumentException("Empty text docs");
        }
        inputDataset = new TextDocsInputDataSet(textDocs, effectiveFilter);
    }

    /** Writes the given result filter as a {@code result_filter} object. */
    private static void writeResultFilter(XContentBuilder builder, ModelResultFilter filter) throws IOException {
        builder.startObject(RESULT_FILTER_FIELD);
        builder.field(RETURN_BYTES_FIELD, filter.isReturnBytes());
        builder.field(RETURN_NUMBER_FIELD, filter.isReturnNumber());
        List<String> responseNames = filter.getTargetResponse();
        if (responseNames != null && !responseNames.isEmpty()) {
            builder.field(TARGET_RESPONSE_FIELD, responseNames.toArray(new String[0]));
        }
        List<Integer> responsePositions = filter.getTargetResponsePositions();
        if (responsePositions != null && !responsePositions.isEmpty()) {
            builder.field(TARGET_RESPONSE_POSITIONS_FIELD, responsePositions.toArray(new Integer[0]));
        }
        builder.endObject();
    }

    /** Appends every text value of the array at the parser's current position to {@code target}. */
    private static void readTextArray(XContentParser parser, List<String> target) throws IOException {
        ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
        for (XContentParser.Token token = parser.nextToken();
                token != XContentParser.Token.END_ARRAY;
                token = parser.nextToken()) {
            target.add(parser.text());
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,29 @@
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
* This class is to filter model results.
*/
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class ModelResultFilter implements Writeable {

public static final String RETURN_BYTES_FIELD = "return_bytes";
// Return numbers in model output. This can be used together with return_bytes.
public static final String RETURN_NUMBER_FIELD = "return_number";
// Filter target response with name in model output
public static final String TARGET_RESPONSE_FIELD = "target_response";
// Filter target response with position in model output
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";

// Return model output as bytes. This could be useful if client side prefer
// to parse the model output in its own way.
protected boolean returnBytes;
Expand Down Expand Up @@ -77,4 +88,41 @@ public void writeTo(StreamOutput streamOutput) throws IOException {
streamOutput.writeBoolean(false);
}
}

/**
 * Builds a {@code ModelResultFilter} from the object at the parser's current position.
 * Absent fields keep their defaults: {@code return_bytes} false, {@code return_number}
 * true, and empty target response lists. Unrecognized fields are skipped.
 *
 * @param parser parser positioned on a START_OBJECT token
 * @return the parsed filter
 * @throws IOException if reading from the parser fails
 */
public static ModelResultFilter parse(XContentParser parser) throws IOException {
    boolean bytesRequested = false;
    boolean numberRequested = true;
    List<String> responseNames = new ArrayList<>();
    List<Integer> responsePositions = new ArrayList<>();

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
        String field = parser.currentName();
        // Advance from the field name to its value.
        parser.nextToken();

        switch (field) {
            case RETURN_BYTES_FIELD:
                bytesRequested = parser.booleanValue();
                break;
            case RETURN_NUMBER_FIELD:
                numberRequested = parser.booleanValue();
                break;
            case TARGET_RESPONSE_FIELD:
                ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                for (XContentParser.Token token = parser.nextToken();
                        token != XContentParser.Token.END_ARRAY;
                        token = parser.nextToken()) {
                    responseNames.add(parser.text());
                }
                break;
            case TARGET_RESPONSE_POSITIONS_FIELD:
                ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                for (XContentParser.Token token = parser.nextToken();
                        token != XContentParser.Token.END_ARRAY;
                        token = parser.nextToken()) {
                    responsePositions.add(parser.intValue());
                }
                break;
            default:
                parser.skipChildren();
                break;
        }
    }
    return new ModelResultFilter(bytesRequested, numberRequested, responseNames, responsePositions);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.upload_chunk;

import org.opensearch.action.ActionType;

/**
 * Action type for the {@code register_model_meta} transport action, whose responses are
 * read with {@code MLRegisterModelMetaResponse::new}.
 * Use the shared {@link #INSTANCE}; the constructor is private.
 */
public class MLRegisterModelMetaAction extends ActionType<MLRegisterModelMetaResponse> {
    /** Name under which the action is registered. */
    public static final String NAME = "cluster:admin/opensearch/ml/register_model_meta";
    /** Singleton instance; final so it cannot be reassigned by other code. */
    public static final MLRegisterModelMetaAction INSTANCE = new MLRegisterModelMetaAction();

    private MLRegisterModelMetaAction() {
        super(NAME, MLRegisterModelMetaResponse::new);
    }

}
Loading