Add ml licence check to the pipeline inference agg. (#59213)
Ensures the licence is sufficient for the model used in inference
davidkyle authored Jul 13, 2020
1 parent 55b6c1a commit 3202f46
Showing 7 changed files with 134 additions and 27 deletions.
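
The core of the change is a license check in InferencePipelineAggregationBuilder#rewrite: the builder now carries an XPackLicenseState, and once the referenced model has been loaded it only completes the async rewrite action if the active license either enables the machine learning feature or covers the license level recorded on the trained model; otherwise it fails with a compliance exception. A condensed sketch of that gate, paraphrased from the rewrite method in the diff below (orientation only, not a drop-in snippet):

    context.registerAsyncAction((client, listener) -> {
        modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
            loadedModel.set(model);
            // Allowed if the license enables ML outright, or if it covers the model's own license level.
            boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)
                || licenseState.isAllowedByLicense(model.getLicenseLevel());
            if (isLicensed) {
                delegate.onResponse(null);
            } else {
                delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
            }
        }));
    });

A search that uses the inference pipeline aggregation under an insufficient license therefore fails with an ElasticsearchSecurityException and status FORBIDDEN, which the new testInferenceAggRestricted test asserts. The full diff follows.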

@@ -5,14 +5,18 @@
*/
package org.elasticsearch.license;

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.bytes.BytesArray;
@@ -21,6 +25,8 @@
import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MlConfigIndex;
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
@@ -46,12 +52,15 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
@@ -140,7 +149,7 @@ public void testMachineLearningOpenJobActionRestricted() throws Exception {
assertNotNull(response2);
}

public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
public void testMachineLearningPutDatafeedActionRestricted() {
String jobId = "testmachinelearningputdatafeedactionrestricted";
String datafeedId = jobId + "-datafeed";
assertMLAllowed(true);
@@ -431,7 +440,7 @@ public void testMachineLearningCloseJobActionNotRestricted() throws Exception {
}
}

public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
public void testMachineLearningDeleteJobActionNotRestricted() {
String jobId = "testmachinelearningclosejobactionnotrestricted";
assertMLAllowed(true);
// test that license restricted apis do now work
@@ -449,7 +458,7 @@ public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
listener.actionGet();
}

public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
public void testMachineLearningDeleteDatafeedActionNotRestricted() {
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
String datafeedId = jobId + "-datafeed";
assertMLAllowed(true);
@@ -474,7 +483,7 @@ public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Except
listener.actionGet();
}

public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
public void testMachineLearningCreateInferenceProcessorRestricted() {
String modelId = "modelprocessorlicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
@@ -606,7 +615,7 @@ public void testMachineLearningCreateInferenceProcessorRestricted() throws Excep
.actionGet();
}

public void testMachineLearningInferModelRestricted() throws Exception {
public void testMachineLearningInferModelRestricted() {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
@@ -668,20 +677,71 @@ public void testMachineLearningInferModelRestricted() throws Exception {
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}

public void testInferenceAggRestricted() {
String modelId = "inference-agg-restricted";
assertMLAllowed(true);
putInferenceModel(modelId);

// index some data
String index = "inference-agg-licence-test";
client().admin().indices().prepareCreate(index).setMapping("feature1", "type=double", "feature2", "type=keyword").get();
client().prepareBulk(index)
.add(new IndexRequest().source("feature1", "10.0", "feature2", "foo"))
.add(new IndexRequest().source("feature1", "20.0", "feature2", "foo"))
.add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
.add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();

TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
termsAgg.subAggregation(avgAgg);

XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);

Map<String, String> bucketPaths = new HashMap<>();
bucketPaths.put("feature1", "avg_feature1");
InferencePipelineAggregationBuilder inferenceAgg =
new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
inferenceAgg.setModelId(modelId);

termsAgg.subAggregation(inferenceAgg);

SearchRequest search = new SearchRequest(index);
search.source().aggregation(termsAgg);
client().search(search).actionGet();

// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);

// inferring against a model should now fail
SearchRequest invalidSearch = new SearchRequest(index);
invalidSearch.source().aggregation(termsAgg);
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> client().search(invalidSearch).actionGet());

assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
}

private void putInferenceModel(String modelId) {
TrainedModelConfig config = TrainedModelConfig.builder()
.setParsedDefinition(
new TrainedModelDefinition.Builder()
.setTrainedModel(
Tree.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(Arrays.asList("feature1"))
.setFeatureNames(Collections.singletonList("feature1"))
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
.build())
.setPreProcessors(Collections.emptyList()))
.setModelId(modelId)
.setDescription("test model for classification")
.setInput(new TrainedModelInput(Arrays.asList("feature1")))
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();

@@ -988,10 +988,9 @@ public Map<String, AnalysisProvider<TokenizerFactory>> getTokenizers() {
@Override
public List<PipelineAggregationSpec> getPipelineAggregations() {
PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
(ContextParser<String, ? extends PipelineAggregationBuilder>)
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
));
(parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
spec.addResultReader(InternalInferenceAggregation::new);

return Collections.singletonList(spec);

@@ -10,15 +10,17 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
static final String AGGREGATIONS_RESULTS_FIELD = "value";

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
NAME, false,
(args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
new ConstructingObjectParser<>(NAME, false,
(args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
context.licenseState, (Map<String, String>) args[0])
);

static {
@@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
private final Map<String, String> bucketPathMap;
private String modelId;
private InferenceConfigUpdate inferenceConfig;
private final XPackLicenseState licenseState;
private final SetOnce<ModelLoadingService> modelLoadingService;
/**
* The model. Set to a non-null value during the rewrite phase.
*/
private final Supplier<LocalModel> model;

private static class ParserSupplement {
final XPackLicenseState licenseState;
final SetOnce<ModelLoadingService> modelLoadingService;
final String name;

ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
this.name = name;
this.licenseState = licenseState;
this.modelLoadingService = modelLoadingService;
}
}
public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
String pipelineAggregatorName,
XContentParser parser) {
Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
return PARSER.apply(parser, context);
return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
}

public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
public InferencePipelineAggregationBuilder(String name,
SetOnce<ModelLoadingService> modelLoadingService,
XPackLicenseState licenseState,
Map<String, String> bucketsPath) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
this.modelLoadingService = modelLoadingService;
this.bucketPathMap = bucketsPath;
this.model = null;
this.licenseState = licenseState;
}

public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
public InferencePipelineAggregationBuilder(StreamInput in,
XPackLicenseState licenseState,
SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
super(in, NAME);
modelId = in.readString();
bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
this.modelLoadingService = modelLoadingService;
this.model = null;
this.licenseState = licenseState;
}

/**
@@ -98,7 +118,8 @@ private InferencePipelineAggregationBuilder(
Map<String, String> bucketsPath,
Supplier<LocalModel> model,
String modelId,
InferenceConfigUpdate inferenceConfig
InferenceConfigUpdate inferenceConfig,
XPackLicenseState licenseState
) {
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
modelLoadingService = null;
@@ -113,13 +134,14 @@ private InferencePipelineAggregationBuilder(
*/
this.modelId = modelId;
this.inferenceConfig = inferenceConfig;
this.licenseState = licenseState;
}

void setModelId(String modelId) {
public void setModelId(String modelId) {
this.modelId = modelId;
}

void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
this.inferenceConfig = inferenceConfig;
}

@@ -160,18 +182,25 @@ protected void doWriteTo(StreamOutput out) throws IOException {
}

@Override
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException {
public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
if (model != null) {
return this;
}
SetOnce<LocalModel> loadedModel = new SetOnce<>();
context.registerAsyncAction((client, listener) -> {
modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
loadedModel.set(model);
delegate.onResponse(null);

boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
licenseState.isAllowedByLicense(model.getLicenseLevel());
if (isLicensed) {
delegate.onResponse(null);
} else {
delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
}
}));
});
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig);
return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
}

@Override

@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.inference.loadingservice;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -38,13 +39,15 @@ public class LocalModel {
private volatile long persistenceQuotient = 100;
private final LongAdder currentInferenceCount;
private final InferenceConfig inferenceConfig;
private final License.OperationMode licenseLevel;

public LocalModel(String modelId,
String nodeId,
InferenceDefinition trainedModelDefinition,
TrainedModelInput input,
Map<String, String> defaultFieldMap,
InferenceConfig modelInferenceConfig,
License.OperationMode licenseLevel,
TrainedModelStatsService trainedModelStatsService) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
@@ -56,6 +59,7 @@ public LocalModel(String modelId,
this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
this.currentInferenceCount = new LongAdder();
this.inferenceConfig = modelInferenceConfig;
this.licenseLevel = licenseLevel;
}

long ramBytesUsed() {
@@ -66,6 +70,10 @@ public String getModelId() {
return modelId;
}

public License.OperationMode getLicenseLevel() {
return licenseLevel;
}

public InferenceStats getLatestStatsAndReset() {
return statsAccumulator.currentStatsAndReset();
}

@@ -309,6 +309,7 @@ private void loadWithoutCaching(String modelId, ActionListener<LocalModel> model
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService));
},
// Failure getting the definition, remove the initial estimation value
@@ -337,6 +338,7 @@ private void handleLoadSuccess(String modelId,
trainedModelConfig.getInput(),
trainedModelConfig.getDefaultFieldMap(),
inferenceConfig,
trainedModelConfig.getLicenseLevel(),
modelStatsService);
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);

(diffs for the remaining changed files are not shown)
