Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] add tags url param to GET #51330

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,9 @@ static Request getTrainedModels(GetTrainedModelsRequest getTrainedModelsRequest)
params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
}
if (getTrainedModelsRequest.getTags() != null) {
params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
}
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
request.addParameters(params.asMap());
return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
import org.elasticsearch.common.Nullable;

import java.util.Arrays;
Expand All @@ -34,12 +35,14 @@ public class GetTrainedModelsRequest implements Validatable {
public static final String ALLOW_NO_MATCH = "allow_no_match";
public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
public static final String DECOMPRESS_DEFINITION = "decompress_definition";
public static final String TAGS = "tags";

private final List<String> ids;
private Boolean allowNoMatch;
private Boolean includeDefinition;
private Boolean decompressDefinition;
private PageParams pageParams;
private List<String> tags;

/**
* Helper method to create a request that will get ALL TrainedModelConfigs
Expand Down Expand Up @@ -111,6 +114,29 @@ public GetTrainedModelsRequest setDecompressDefinition(Boolean decompressDefinit
return this;
}

public List<String> getTags() {
return tags;
}

/**
* The tags that the trained model must match. These correspond to {@link TrainedModelConfig#getTags()}.
*
* The models returned will match ALL tags supplied.
* If none are provided, only the provided ids are used to find models
* @param tags The tags to match when finding models
*/
public GetTrainedModelsRequest setTags(List<String> tags) {
this.tags = tags;
return this;
}

/**
* See {@link GetTrainedModelsRequest#setTags(List)}
*/
public GetTrainedModelsRequest setTags(String... tags) {
return setTags(Arrays.asList(tags));
}

@Override
public Optional<ValidationException> validate() {
if (ids == null || ids.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ public void testGetTrainedModels() {
.setAllowNoMatch(false)
.setDecompressDefinition(true)
.setIncludeDefinition(false)
.setTags("tag1", "tag2")
.setPageParams(new PageParams(100, 300));

Request request = MLRequestConverters.getTrainedModels(getRequest);
Expand All @@ -845,6 +846,7 @@ public void testGetTrainedModels() {
hasEntry("size", "300"),
hasEntry("allow_no_match", "false"),
hasEntry("decompress_definition", "true"),
hasEntry("tags", "tag1,tag2"),
hasEntry("include_model_definition", "false")
));
assertNull(request.getEntity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3587,8 +3587,10 @@ public void testGetTrainedModels() throws Exception {
.setPageParams(new PageParams(0, 1)) // <2>
.setIncludeDefinition(false) // <3>
.setDecompressDefinition(false) // <4>
.setAllowNoMatch(true); // <5>
.setAllowNoMatch(true) // <5>
.setTags("regression"); // <6>
// end::get-trained-models-request
request.setTags((List<String>)null);

// tag::get-trained-models-execute
GetTrainedModelsResponse response = client.machineLearning().getTrainedModels(request, RequestOptions.DEFAULT);
Expand Down
3 changes: 3 additions & 0 deletions docs/java-rest/high-level/ml/get-trained-models.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ include-tagged::{doc-tests-file}[{api}-request]
<5> Allow empty response if no Trained Models match the provided ID patterns.
If false, an error will be thrown if no Trained Models match the
ID patterns.
<6> An optional list of tags used to narrow the model search. A Trained Model
can have many tags or none. The trained models in the response will
contain all the provided tags.

include::../execution.asciidoc[]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=include-model-definition]
(Optional, integer)
include::{docdir}/ml/ml-shared.asciidoc[tag=size]

`tags`::
(Optional, string)
include::{docdir}/ml/ml-shared.asciidoc[tag=tags]

[[ml-get-inference-response-codes]]
==== {api-response-codes-title}
Expand All @@ -97,4 +100,4 @@ The following example gets configuration information for all the trained models:
--------------------------------------------------
GET _ml/inference/
--------------------------------------------------
// TEST[skip:TBD]
// TEST[skip:TBD]
6 changes: 6 additions & 0 deletions docs/reference/ml/ml-shared.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,12 @@ to `false`. When `true`, only a single model must match the ID patterns
provided, otherwise a bad request is returned.
end::include-model-definition[]

tag::tags[]
A comma delimited string of tags. A {infer} model can have many tags, or none.
When supplied, only {infer} models that contain all the supplied tags are
returned.
end::tags[]

tag::indices[]
An array of index names. Wildcards are supported. For example:
`["it_ops_metrics", "server*"]`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -33,18 +34,26 @@ public static class Request extends AbstractGetResourcesRequest {

public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public static final ParseField TAGS = new ParseField("tags");

private final boolean includeModelDefinition;
private final List<String> tags;

public Request(String id, boolean includeModelDefinition) {
public Request(String id, boolean includeModelDefinition, List<String> tags) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
this.tags = tags == null ? Collections.emptyList() : tags;
}

public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
this.tags = in.readStringList();
} else {
this.tags = Collections.emptyList();
}
}

@Override
Expand All @@ -56,15 +65,22 @@ public boolean isIncludeModelDefinition() {
return includeModelDefinition;
}

public List<String> getTags() {
return tags;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeStringCollection(tags);
}
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
return Objects.hash(super.hashCode(), includeModelDefinition, tags);
}

@Override
Expand All @@ -76,7 +92,7 @@ public boolean equals(Object obj) {
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas

@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
Request request = new Request(randomAlphaOfLength(20),
randomBoolean(),
randomBoolean() ? null :
randomList(10, () -> randomAlphaOfLength(10)));
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;


Expand Down Expand Up @@ -70,7 +71,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
listener::onFailure
);

provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
provider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
new HashSet<>(request.getTags()),
idExpansionListener);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -94,7 +95,11 @@ protected void doExecute(Task task,
listener::onFailure
);

trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
trainedModelProvider.expandIds(request.getResourceId(),
request.isAllowNoResources(),
request.getPageParams(),
Collections.emptySet(),
idsListener);
}

static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,24 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
public void expandIds(String idExpression,
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
boolean allowNoResources,
@Nullable PageParams pageParams,
Set<String> tags,
ActionListener<Tuple<Long, Set<String>>> idsListener) {
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
BoolQueryBuilder tagQuery = QueryBuilders.boolQuery();
for(String tag : tags) {
tagQuery.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
}

QueryBuilder query = QueryBuilders.constantScoreQuery(
QueryBuilders.boolQuery()
.filter(tagQuery.hasClauses() ? tagQuery : QueryBuilders.matchAllQuery())
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())));
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
// If there are no resources, there might be no mapping for the id field.
// This makes sure we don't get an error if that happens.
.unmappedType("long"))
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
.query(query);
if (pageParams != null) {
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
}
Expand All @@ -404,13 +414,23 @@ public void expandIds(String idExpression,
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
Set<String> foundResourceIds = new LinkedHashSet<>();
if (tags.isEmpty()) {
foundResourceIds.addAll(matchedResourceIds(tokens));
} else {
for(String resourceId : matchedResourceIds(tokens)) {
// Does the model as a resource have all the tags?
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved
foundResourceIds.add(resourceId);
}
}
}

executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;

import static org.elasticsearch.rest.RestRequest.Method.GET;
Expand Down Expand Up @@ -47,7 +49,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
List<String> tags = Arrays.asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,21 @@
"description": "Should the model definition be decompressed into valid JSON or returned in a custom compressed format. Defaults to true."
},
"from":{
"required": false,
"type":"int",
"description":"skips a number of trained models",
"default":0
},
"size":{
"required": false,
"type":"int",
"description":"specifies a max number of trained models to get",
"default":100
},
"tags": {
"required": false,
"type":"list",
"description":"A comma-separated list of tags that the model must have."
}
}
}
Expand Down
Loading