diff --git a/docs/changelog/111336.yaml b/docs/changelog/111336.yaml
new file mode 100644
index 0000000000000..d5bf602cb7a88
--- /dev/null
+++ b/docs/changelog/111336.yaml
@@ -0,0 +1,5 @@
+pr: 111336
+summary: Use the same chunking configurations for models in the Elasticsearch service
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/docs/changelog/113812.yaml b/docs/changelog/113812.yaml
new file mode 100644
index 0000000000000..04498b4ae5f7e
--- /dev/null
+++ b/docs/changelog/113812.yaml
@@ -0,0 +1,5 @@
+pr: 113812
+summary: Add Streaming Inference spec
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/docs/changelog/114080.yaml b/docs/changelog/114080.yaml
new file mode 100644
index 0000000000000..395768c46369a
--- /dev/null
+++ b/docs/changelog/114080.yaml
@@ -0,0 +1,5 @@
+pr: 114080
+summary: Stream Cohere Completion
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/docs/changelog/114231.yaml b/docs/changelog/114231.yaml
new file mode 100644
index 0000000000000..61c447688edcf
--- /dev/null
+++ b/docs/changelog/114231.yaml
@@ -0,0 +1,17 @@
+pr: 114231
+summary: Remove cluster state from `/_cluster/reroute` response
+area: Allocation
+type: breaking
+issues:
+  - 88978
+breaking:
+  title: Remove cluster state from `/_cluster/reroute` response
+  area: REST API
+  details: >-
+    The `POST /_cluster/reroute` API no longer returns the cluster state in its
+    response. The `?metric` query parameter to this API now has no effect and
+    its use will be forbidden in a future version.
+  impact: >-
+    Cease usage of the `?metric` query parameter when calling the
+    `POST /_cluster/reroute` API.
+  notable: false
diff --git a/docs/reference/cluster/reroute.asciidoc b/docs/reference/cluster/reroute.asciidoc
index b4e4809ae73b4..429070f80b9bf 100644
--- a/docs/reference/cluster/reroute.asciidoc
+++ b/docs/reference/cluster/reroute.asciidoc
@@ -10,7 +10,7 @@ Changes the allocation of shards in a cluster.
 [[cluster-reroute-api-request]]
 ==== {api-request-title}
 
-`POST /_cluster/reroute?metric=none`
+`POST /_cluster/reroute`
 
 [[cluster-reroute-api-prereqs]]
 ==== {api-prereq-title}
@@ -193,7 +193,7 @@ This is a short example of a simple reroute API call:
 
 [source,console]
 --------------------------------------------------
-POST /_cluster/reroute?metric=none
+POST /_cluster/reroute
 {
   "commands": [
     {
diff --git a/docs/reference/commands/shard-tool.asciidoc b/docs/reference/commands/shard-tool.asciidoc
index a2d9d557adf5e..b1e63740cede0 100644
--- a/docs/reference/commands/shard-tool.asciidoc
+++ b/docs/reference/commands/shard-tool.asciidoc
@@ -95,7 +95,7 @@ Changing allocation id V8QXk-QXSZinZMT-NvEq4w to tjm9Ve6uTBewVFAlfUMWjA
 
 You should run the following command to allocate this shard:
 
-POST /_cluster/reroute?metric=none
+POST /_cluster/reroute
 {
   "commands" : [
     {
diff --git a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc
index 99c3ecad03a9d..7acbc0bd23859 100644
--- a/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc
+++ b/docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc
@@ -225,6 +225,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
@@ -301,6 +312,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
@@ -397,6 +419,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`span`::::
+(Optional, integer)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span]
+
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
@@ -517,6 +554,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`span`::::
+(Optional, integer)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span]
+
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
@@ -608,6 +660,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
@@ -687,6 +750,21 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, integer)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span]
 
+`with_special_tokens`::::
+(Optional, boolean)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
+=======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`span`::::
+(Optional, integer)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-span]
+
 `with_special_tokens`::::
 (Optional, boolean)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert-with-special-tokens]
@@ -790,6 +868,17 @@ include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizatio
 (Optional, string)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
 =======
+`deberta_v2`::::
+(Optional, object)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-deberta-v2]
++
+.Properties of deberta_v2
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate-deberta-v2]
+=======
 `roberta`::::
 (Optional, object)
 include::{es-ref-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc
index fabd495cdc525..993bb8cb894f9 100644
--- a/docs/reference/rest-api/common-parms.asciidoc
+++ b/docs/reference/rest-api/common-parms.asciidoc
@@ -1298,10 +1298,11 @@ tag::wait_for_active_shards[]
 `wait_for_active_shards`::
 +
 --
-(Optional, string) The number of shard copies that must be active before
-proceeding with the operation. Set to `all` or any positive integer up
-to the total number of shards in the index (`number_of_replicas+1`).
-Default: 1, the primary shard.
+(Optional, string) The number of copies of each shard that must be active
+before proceeding with the operation. Set to `all` or any non-negative integer
+up to the total number of copies of each shard in the index
+(`number_of_replicas+1`). Defaults to `1`, meaning to wait just for each
+primary shard to be active.
 See <>.
 --
diff --git a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc
index f1bd238a64fbf..dbcfbb1b615f9 100644
--- a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc
+++ b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc
@@ -89,6 +89,16 @@ PUT semantic-embeddings
 It will be used to generate the embeddings based on the input text.
 Every time you ingest data into the related `semantic_text` field, this endpoint will be used for creating the vector representation of the text.
 
+[NOTE]
+====
+If you're using web crawlers or connectors to generate indices, you have to
+<> for these indices to
+include the `semantic_text` field. Once the mapping is updated, you'll need to run
+a full web crawl or a full connector sync. This ensures that all existing
+documents are reprocessed and updated with the new semantic embeddings,
+enabling semantic search on the updated data.
+====
+
 
 [discrete]
 [[semantic-text-load-data]]
@@ -118,6 +128,13 @@ Create the embeddings from the text by reindexing the data from the `test-data`
 
 The data in the `content` field will be reindexed into the `content` semantic text field of the destination index.
 The reindexed data will be processed by the {infer} endpoint associated with the `content` semantic text field.
 
+[NOTE]
+====
+This step uses the reindex API to simulate data ingestion. If you are working with data that has
+already been indexed, rather than with the `test-data` set, reindexing is still required to ensure
+that the data is processed by the {infer} endpoint and the necessary embeddings are generated.
+====
+
 [source,console]
 ------------------------------------------------------------
 POST _reindex?wait_for_completion=false
diff --git a/docs/reference/troubleshooting/common-issues/red-yellow-cluster-status.asciidoc b/docs/reference/troubleshooting/common-issues/red-yellow-cluster-status.asciidoc
index cae4eb99dd54a..eb56a37562c31 100644
--- a/docs/reference/troubleshooting/common-issues/red-yellow-cluster-status.asciidoc
+++ b/docs/reference/troubleshooting/common-issues/red-yellow-cluster-status.asciidoc
@@ -2,12 +2,12 @@
 === Red or yellow cluster health status
 
 A red or yellow cluster health status indicates one or more shards are not assigned to
-a node.
+a node.
 
 * **Red health status**: The cluster has some unassigned primary shards, which
-means that some operations such as searches and indexing may fail.
-* **Yellow health status**: The cluster has no unassigned primary shards but some
-unassigned replica shards. This increases your risk of data loss and can degrade
+means that some operations such as searches and indexing may fail.
+* **Yellow health status**: The cluster has no unassigned primary shards but some
+unassigned replica shards. This increases your risk of data loss and can degrade
 cluster performance.
 
 When your cluster has a red or yellow health status, it will continue to process
@@ -16,8 +16,8 @@ cleanup activities until the cluster returns to green health status. For instanc
 some <> actions require the index on which they operate to have a green
 health status.
-In many cases, your cluster will recover to green health status automatically.
-If the cluster doesn't automatically recover, then you must <>
+In many cases, your cluster will recover to green health status automatically.
+If the cluster doesn't automatically recover, then you must <>
 the remaining problems so management and cleanup activities can proceed.
 
 [discrete]
@@ -107,7 +107,7 @@ asynchronously in the background.
 
 [source,console]
 ----
-POST _cluster/reroute?metric=none
+POST _cluster/reroute
 ----
 
 [discrete]
@@ -231,10 +231,10 @@ unassigned. See <>.
 
 If a node containing a primary shard is lost, {es} can typically replace it
 using a replica on another node. If you can't recover the node and replicas
-don't exist or are irrecoverable, <> will report `no_valid_shard_copy` and you'll need to do one of the following:
+don't exist or are irrecoverable, <> will report `no_valid_shard_copy` and you'll need to do one of the following:
 
-* restore the missing data from <>
+* restore the missing data from <>
 * index the missing data from its original data source
 * accept data loss on the index-level by running <>
 * accept data loss on the shard-level by executing <> allocate_stale_primary or allocate_empty_primary command with `accept_data_loss: true`
@@ -246,7 +246,7 @@ resulting in data loss.
 +
 [source,console]
 ----
-POST _cluster/reroute?metric=none
+POST _cluster/reroute
 {
   "commands": [
     {
diff --git a/libs/simdvec/build.gradle b/libs/simdvec/build.gradle
index 8b676a15038c1..dab5c25b34679 100644
--- a/libs/simdvec/build.gradle
+++ b/libs/simdvec/build.gradle
@@ -23,14 +23,15 @@ dependencies {
   }
 }
 
-tasks.named("compileMain21Java").configure {
+// compileMain21Java does not exist within IDEA (see MrJarPlugin) so we cannot reference it directly by name
+tasks.matching { it.name == "compileMain21Java" }.configureEach {
   options.compilerArgs << '--add-modules=jdk.incubator.vector'
   // we remove Werror, since incubating suppression (-Xlint:-incubating)
   // is only supported since JDK 22
   options.compilerArgs -= '-Werror'
 }
 
-test {
+tasks.named('test').configure {
   if (JavaVersion.current().majorVersion.toInteger() >= 21) {
     jvmArgs '--add-modules=jdk.incubator.vector'
   }
diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java
index 52ca5eea52c1a..128c16e163764 100644
--- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java
+++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/Database.java
@@ -22,12 +22,17 @@
 *

 * A database has a set of properties that are valid to use with it (see {@link Database#properties()}),
 * as well as a list of default properties to use if no properties are specified (see {@link Database#defaultProperties()}).
+ *

+ * Some database providers have similar concepts but might have slightly different properties associated with those types.
+ * This can be accommodated, for example, by having a Foo value and a separate FooV2 value where the 'V' should be read as
+ * 'variant' or 'variation'. A V-less Database type is inherently the first variant/variation (i.e. V1).
 */
enum Database {

    City(
        Set.of(
            Property.IP,
+           Property.COUNTRY_IN_EUROPEAN_UNION,
            Property.COUNTRY_ISO_CODE,
            Property.CONTINENT_CODE,
            Property.COUNTRY_NAME,
@@ -36,7 +41,9 @@ enum Database {
            Property.REGION_NAME,
            Property.CITY_NAME,
            Property.TIMEZONE,
-           Property.LOCATION
+           Property.LOCATION,
+           Property.POSTAL_CODE,
+           Property.ACCURACY_RADIUS
        ),
        Set.of(
            Property.COUNTRY_ISO_CODE,
@@ -49,7 +56,14 @@ enum Database {
        )
    ),
    Country(
-       Set.of(Property.IP, Property.CONTINENT_CODE, Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE),
+       Set.of(
+           Property.IP,
+           Property.CONTINENT_CODE,
+           Property.CONTINENT_NAME,
+           Property.COUNTRY_NAME,
+           Property.COUNTRY_IN_EUROPEAN_UNION,
+           Property.COUNTRY_ISO_CODE
+       ),
        Set.of(Property.CONTINENT_NAME, Property.COUNTRY_NAME, Property.COUNTRY_ISO_CODE)
    ),
    Asn(
@@ -80,12 +94,15 @@ enum Database {
    Enterprise(
        Set.of(
            Property.IP,
+           Property.COUNTRY_CONFIDENCE,
+           Property.COUNTRY_IN_EUROPEAN_UNION,
            Property.COUNTRY_ISO_CODE,
            Property.COUNTRY_NAME,
            Property.CONTINENT_CODE,
            Property.CONTINENT_NAME,
            Property.REGION_ISO_CODE,
            Property.REGION_NAME,
+           Property.CITY_CONFIDENCE,
            Property.CITY_NAME,
            Property.TIMEZONE,
            Property.LOCATION,
@@ -104,7 +121,10 @@ enum Database {
            Property.MOBILE_COUNTRY_CODE,
            Property.MOBILE_NETWORK_CODE,
            Property.USER_TYPE,
-           Property.CONNECTION_TYPE
+           Property.CONNECTION_TYPE,
+           Property.POSTAL_CODE,
+           Property.POSTAL_CONFIDENCE,
+           Property.ACCURACY_RADIUS
        ),
        Set.of(
            Property.COUNTRY_ISO_CODE,
@@ -137,6 +157,18 @@ enum Database {
            Property.MOBILE_COUNTRY_CODE,
            Property.MOBILE_NETWORK_CODE
        )
+   ),
+   AsnV2(
+       Set.of(
+           Property.IP,
+           Property.ASN,
+           Property.ORGANIZATION_NAME,
+           Property.NETWORK,
+           Property.DOMAIN,
+           Property.COUNTRY_ISO_CODE,
+           Property.TYPE
+       ),
+       Set.of(Property.IP, Property.ASN, Property.ORGANIZATION_NAME, Property.NETWORK)
    );

    private final Set<Property> properties;
@@ -187,12 +219,15 @@ public Set<Property> parseProperties(@Nullable final List<String> propertyNames)
    enum Property {
        IP,
+       COUNTRY_CONFIDENCE,
+       COUNTRY_IN_EUROPEAN_UNION,
        COUNTRY_ISO_CODE,
        COUNTRY_NAME,
        CONTINENT_CODE,
        CONTINENT_NAME,
        REGION_ISO_CODE,
        REGION_NAME,
+       CITY_CONFIDENCE,
        CITY_NAME,
        TIMEZONE,
        LOCATION,
@@ -211,7 +246,11 @@ enum Property {
        MOBILE_COUNTRY_CODE,
        MOBILE_NETWORK_CODE,
        CONNECTION_TYPE,
-       USER_TYPE;
+       USER_TYPE,
+       TYPE,
+       POSTAL_CODE,
+       POSTAL_CONFIDENCE,
+       ACCURACY_RADIUS;

        /**
         * Parses a string representation of a property into an actual Property instance. Not all properties that exist are
diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java
index 990788978a0ca..3379fdff0633a 100644
--- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java
+++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpDataLookupFactories.java
@@ -76,6 +76,7 @@ static Database getDatabase(final String databaseType) {
         return database;
     }
 
+    @Nullable
     static Function<Set<Database.Property>, IpDataLookup> getMaxmindLookup(final Database database) {
         return switch (database) {
             case City -> MaxmindIpDataLookups.City::new;
@@ -86,6 +87,7 @@ static Function<Set<Database.Property>, IpDataLookup> getMaxmindLookup(final Dat
             case Domain -> MaxmindIpDataLookups.Domain::new;
             case Enterprise -> MaxmindIpDataLookups.Enterprise::new;
             case Isp -> MaxmindIpDataLookups.Isp::new;
+            default -> null;
         };
     }
 
@@ -97,7 +99,6 @@ static IpDataLookupFactory get(final String databaseType, final String databaseF
 
         final Function<Set<Database.Property>, IpDataLookup> factoryMethod = getMaxmindLookup(database);
 
-        // note: this can't presently be null, but keep this check -- it will be useful in the near future
         if (factoryMethod == null) {
             throw new IllegalArgumentException("Unsupported database type [" + databaseType + "] for file [" + databaseFile + "]");
         }
diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java
new file mode 100644
index 0000000000000..ac7f56468f37e
--- /dev/null
+++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookups.java
@@ -0,0 +1,235 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.ingest.geoip;
+
+import com.maxmind.db.DatabaseRecord;
+import com.maxmind.db.MaxMindDbConstructor;
+import com.maxmind.db.MaxMindDbParameter;
+import com.maxmind.db.Reader;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.network.InetAddresses;
+import org.elasticsearch.common.network.NetworkAddress;
+import org.elasticsearch.core.Nullable;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A collection of {@link IpDataLookup} implementations for IPinfo databases
+ */
+final class IpinfoIpDataLookups {
+
+    private IpinfoIpDataLookups() {
+        // utility class
+    }
+
+    private static final Logger logger = LogManager.getLogger(IpinfoIpDataLookups.class);
+
+    /**
+     * Laxly parses a string that (ideally) looks like 'AS123' into a Long like 123L (or null, if such parsing isn't possible).
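+     * <p>For example (illustrative): {@code parseAsn("AS123")} yields {@code 123L}, {@code parseAsn(" as 456 ")}
+     * yields {@code 456L}, and {@code parseAsn("anythingelse")} yields {@code null}.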
+     * @param asn a potentially empty (or null) ASN string that is expected to contain 'AS' and then a parsable long
+     * @return the parsed asn
+     */
+    static Long parseAsn(final String asn) {
+        if (asn == null || Strings.hasText(asn) == false) {
+            return null;
+        } else {
+            String stripped = asn.toUpperCase(Locale.ROOT).replaceAll("AS", "").trim();
+            try {
+                return Long.parseLong(stripped);
+            } catch (NumberFormatException e) {
+                logger.trace("Unable to parse non-compliant ASN string [{}]", asn);
+                return null;
+            }
+        }
+    }
+
+    public record AsnResult(
+        Long asn,
+        @Nullable String country, // not present in the free asn database
+        String domain,
+        String name,
+        @Nullable String type // not present in the free asn database
+    ) {
+        @SuppressWarnings("checkstyle:RedundantModifier")
+        @MaxMindDbConstructor
+        public AsnResult(
+            @MaxMindDbParameter(name = "asn") String asn,
+            @Nullable @MaxMindDbParameter(name = "country") String country,
+            @MaxMindDbParameter(name = "domain") String domain,
+            @MaxMindDbParameter(name = "name") String name,
+            @Nullable @MaxMindDbParameter(name = "type") String type
+        ) {
+            this(parseAsn(asn), country, domain, name, type);
+        }
+    }
+
+    public record CountryResult(
+        @MaxMindDbParameter(name = "continent") String continent,
+        @MaxMindDbParameter(name = "continent_name") String continentName,
+        @MaxMindDbParameter(name = "country") String country,
+        @MaxMindDbParameter(name = "country_name") String countryName
+    ) {
+        @MaxMindDbConstructor
+        public CountryResult {}
+    }
+
+    static class Asn extends AbstractBase<AsnResult> {
+        Asn(Set<Database.Property> properties) {
+            super(properties, AsnResult.class);
+        }
+
+        @Override
+        protected Map<String, Object> transform(final Result<AsnResult> result) {
+            AsnResult response = result.result;
+            Long asn = response.asn;
+            String organizationName = response.name;
+            String network = result.network;
+
+            Map<String, Object> data = new HashMap<>();
+            for (Database.Property property : this.properties) {
+                switch (property) {
+                    case IP -> data.put("ip", result.ip);
+                    case ASN -> {
+                        if (asn != null) {
+                            data.put("asn", asn);
+                        }
+                    }
+                    case ORGANIZATION_NAME -> {
+                        if (organizationName != null) {
+                            data.put("organization_name", organizationName);
+                        }
+                    }
+                    case NETWORK -> {
+                        if (network != null) {
+                            data.put("network", network);
+                        }
+                    }
+                    case COUNTRY_ISO_CODE -> {
+                        if (response.country != null) {
+                            data.put("country_iso_code", response.country);
+                        }
+                    }
+                    case DOMAIN -> {
+                        if (response.domain != null) {
+                            data.put("domain", response.domain);
+                        }
+                    }
+                    case TYPE -> {
+                        if (response.type != null) {
+                            data.put("type", response.type);
+                        }
+                    }
+                }
+            }
+            return data;
+        }
+    }
+
+    static class Country extends AbstractBase<CountryResult> {
+        Country(Set<Database.Property> properties) {
+            super(properties, CountryResult.class);
+        }
+
+        @Override
+        protected Map<String, Object> transform(final Result<CountryResult> result) {
+            CountryResult response = result.result;
+
+            Map<String, Object> data = new HashMap<>();
+            for (Database.Property property : this.properties) {
+                switch (property) {
+                    case IP -> data.put("ip", result.ip);
+                    case COUNTRY_ISO_CODE -> {
+                        String countryIsoCode = response.country;
+                        if (countryIsoCode != null) {
+                            data.put("country_iso_code", countryIsoCode);
+                        }
+                    }
+                    case COUNTRY_NAME -> {
+                        String countryName = response.countryName;
+                        if (countryName != null) {
+                            data.put("country_name", countryName);
+                        }
+                    }
+                    case CONTINENT_CODE -> {
+                        String continentCode = response.continent;
+                        if (continentCode != null) {
+                            data.put("continent_code", continentCode);
+                        }
+                    }
+                    case CONTINENT_NAME -> {
+                        String continentName = response.continentName;
+                        if (continentName != null) {
data.put("continent_name", continentName); + } + } + } + } + return data; + } + } + + /** + * Just a little record holder -- there's the data that we receive via the binding to our record objects from the Reader via the + * getRecord call, but then we also need to capture the passed-in ip address that came from the caller as well as the network for + * the returned DatabaseRecord from the Reader. + */ + public record Result(T result, String ip, String network) {} + + /** + * The {@link IpinfoIpDataLookups.AbstractBase} is an abstract base implementation of {@link IpDataLookup} that + * provides common functionality for getting a {@link IpinfoIpDataLookups.Result} that wraps a record from a {@link IpDatabase}. + * + * @param the record type that will be wrapped and returned + */ + private abstract static class AbstractBase implements IpDataLookup { + + protected final Set properties; + protected final Class clazz; + + AbstractBase(final Set properties, final Class clazz) { + this.properties = Set.copyOf(properties); + this.clazz = clazz; + } + + @Override + public Set getProperties() { + return this.properties; + } + + @Override + public final Map getData(final IpDatabase ipDatabase, final String ipAddress) { + final Result response = ipDatabase.getResponse(ipAddress, this::lookup); + return (response == null || response.result == null) ? Map.of() : transform(response); + } + + @Nullable + private Result lookup(final Reader reader, final String ipAddress) throws IOException { + final InetAddress ip = InetAddresses.forString(ipAddress); + final DatabaseRecord record = reader.getRecord(ip, clazz); + final RESPONSE data = record.getData(); + return (data == null) ? null : new Result<>(data, NetworkAddress.format(ip), record.getNetwork().toString()); + } + + /** + * Extract the configured properties from the retrieved response + * @param response the non-null response that was retrieved + * @return a mapping of properties for the ip from the response + */ + protected abstract Map transform(Result response); + } +} diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java index 5b22b3f4005a9..e7c3481938033 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java @@ -23,6 +23,7 @@ import com.maxmind.geoip2.model.IspResponse; import com.maxmind.geoip2.record.Continent; import com.maxmind.geoip2.record.Location; +import com.maxmind.geoip2.record.Postal; import com.maxmind.geoip2.record.Subdivision; import org.elasticsearch.common.network.InetAddresses; @@ -139,11 +140,18 @@ protected Map transform(final CityResponse response) { Location location = response.getLocation(); Continent continent = response.getContinent(); Subdivision subdivision = response.getMostSpecificSubdivision(); + Postal postal = response.getPostal(); Map data = new HashMap<>(); for (Database.Property property : this.properties) { switch (property) { case IP -> data.put("ip", response.getTraits().getIpAddress()); + case COUNTRY_IN_EUROPEAN_UNION -> { + if (country.getIsoCode() != null) { + // isInEuropeanUnion is a boolean so it can't be null. 
+}
diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java
index 5b22b3f4005a9..e7c3481938033 100644
--- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java
+++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/MaxmindIpDataLookups.java
@@ -23,6 +23,7 @@
 import com.maxmind.geoip2.model.IspResponse;
 import com.maxmind.geoip2.record.Continent;
 import com.maxmind.geoip2.record.Location;
+import com.maxmind.geoip2.record.Postal;
 import com.maxmind.geoip2.record.Subdivision;
 
 import org.elasticsearch.common.network.InetAddresses;
@@ -139,11 +140,18 @@ protected Map<String, Object> transform(final CityResponse response) {
         Location location = response.getLocation();
         Continent continent = response.getContinent();
         Subdivision subdivision = response.getMostSpecificSubdivision();
+        Postal postal = response.getPostal();
 
         Map<String, Object> data = new HashMap<>();
         for (Database.Property property : this.properties) {
             switch (property) {
                 case IP -> data.put("ip", response.getTraits().getIpAddress());
+                case COUNTRY_IN_EUROPEAN_UNION -> {
+                    if (country.getIsoCode() != null) {
+                        // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country
+                        data.put("country_in_european_union", country.isInEuropeanUnion());
+                    }
+                }
                 case COUNTRY_ISO_CODE -> {
                     String countryIsoCode = country.getIsoCode();
                     if (countryIsoCode != null) {
@@ -206,6 +214,17 @@ protected Map<String, Object> transform(final CityResponse response) {
                         data.put("location", locationObject);
                     }
                 }
+                case ACCURACY_RADIUS -> {
+                    Integer accuracyRadius = location.getAccuracyRadius();
+                    if (accuracyRadius != null) {
+                        data.put("accuracy_radius", accuracyRadius);
+                    }
+                }
+                case POSTAL_CODE -> {
+                    if (postal != null && postal.getCode() != null) {
+                        data.put("postal_code", postal.getCode());
+                    }
+                }
             }
         }
         return data;
@@ -254,6 +273,12 @@ protected Map<String, Object> transform(final CountryResponse response) {
         for (Database.Property property : this.properties) {
             switch (property) {
                 case IP -> data.put("ip", response.getTraits().getIpAddress());
+                case COUNTRY_IN_EUROPEAN_UNION -> {
+                    if (country.getIsoCode() != null) {
+                        // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country
+                        data.put("country_in_european_union", country.isInEuropeanUnion());
+                    }
+                }
                 case COUNTRY_ISO_CODE -> {
                     String countryIsoCode = country.getIsoCode();
                     if (countryIsoCode != null) {
@@ -324,6 +349,7 @@ protected Map<String, Object> transform(final EnterpriseResponse response) {
         Location location = response.getLocation();
         Continent continent = response.getContinent();
         Subdivision subdivision = response.getMostSpecificSubdivision();
+        Postal postal = response.getPostal();
 
         Long asn = response.getTraits().getAutonomousSystemNumber();
         String organizationName = response.getTraits().getAutonomousSystemOrganization();
@@ -351,6 +377,18 @@ protected Map<String, Object> transform(final EnterpriseResponse response) {
         for (Database.Property property : this.properties) {
             switch (property) {
                 case IP -> data.put("ip", response.getTraits().getIpAddress());
+                case COUNTRY_CONFIDENCE -> {
+                    Integer countryConfidence = country.getConfidence();
+                    if (countryConfidence != null) {
+                        data.put("country_confidence", countryConfidence);
+                    }
+                }
+                case COUNTRY_IN_EUROPEAN_UNION -> {
+                    if (country.getIsoCode() != null) {
+                        // isInEuropeanUnion is a boolean so it can't be null. But it really only makes sense if we have a country
+                        data.put("country_in_european_union", country.isInEuropeanUnion());
+                    }
+                }
                 case COUNTRY_ISO_CODE -> {
                     String countryIsoCode = country.getIsoCode();
                     if (countryIsoCode != null) {
@@ -391,6 +429,12 @@ protected Map<String, Object> transform(final EnterpriseResponse response) {
                         data.put("region_name", subdivisionName);
                     }
                 }
+                case CITY_CONFIDENCE -> {
+                    Integer cityConfidence = city.getConfidence();
+                    if (cityConfidence != null) {
+                        data.put("city_confidence", cityConfidence);
+                    }
+                }
                 case CITY_NAME -> {
                     String cityName = city.getName();
                     if (cityName != null) {
@@ -413,6 +457,23 @@ protected Map<String, Object> transform(final EnterpriseResponse response) {
                         data.put("location", locationObject);
                     }
                 }
+                case ACCURACY_RADIUS -> {
+                    Integer accuracyRadius = location.getAccuracyRadius();
+                    if (accuracyRadius != null) {
+                        data.put("accuracy_radius", accuracyRadius);
+                    }
+                }
+                case POSTAL_CODE -> {
+                    if (postal != null && postal.getCode() != null) {
+                        data.put("postal_code", postal.getCode());
+                    }
+                }
+                case POSTAL_CONFIDENCE -> {
+                    Integer postalConfidence = postal.getConfidence();
+                    if (postalConfidence != null) {
+                        data.put("postal_confidence", postalConfidence);
+                    }
+                }
                 case ASN -> {
                     if (asn != null) {
                         data.put("asn", asn);
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java
index 9972db26b3642..cfea54d2520bd 100644
--- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorFactoryTests.java
@@ -195,7 +195,7 @@ public void testBuildWithCountryDbAndAsnFields() {
             equalTo(
                 "[properties] illegal property value ["
                     + asnProperty
-                    + "]. valid values are [IP, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME]"
+                    + "]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME]"
             )
         );
     }
@@ -273,8 +273,9 @@ public void testBuildIllegalFieldOption() {
         assertThat(
             e.getMessage(),
             equalTo(
-                "[properties] illegal property value [invalid]. valid values are [IP, COUNTRY_ISO_CODE, "
-                    + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, LOCATION]"
+                "[properties] illegal property value [invalid]. valid values are [IP, COUNTRY_IN_EUROPEAN_UNION, COUNTRY_ISO_CODE, "
+                    + "COUNTRY_NAME, CONTINENT_CODE, CONTINENT_NAME, REGION_ISO_CODE, REGION_NAME, CITY_NAME, TIMEZONE, "
+                    + "LOCATION, POSTAL_CODE, ACCURACY_RADIUS]"
             )
         );
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java
index 793754ec316b2..ffc40324bd886 100644
--- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpProcessorTests.java
@@ -24,6 +24,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.elasticsearch.ingest.IngestDocumentMatcher.assertIngestDocument;
@@ -64,8 +65,16 @@ public void testDatabasePropertyInvariants() {
         assertThat(Sets.difference(Database.Asn.properties(), Database.Isp.properties()), is(empty()));
         assertThat(Sets.difference(Database.Asn.defaultProperties(), Database.Isp.defaultProperties()), is(empty()));
 
-        // the enterprise database is like everything joined together
-        for (Database type : Database.values()) {
+        // the enterprise database is like these other databases joined together
+        for (Database type : Set.of(
+            Database.City,
+            Database.Country,
+            Database.Asn,
+            Database.AnonymousIp,
+            Database.ConnectionType,
+            Database.Domain,
+            Database.Isp
+        )) {
             assertThat(Sets.difference(type.properties(), Database.Enterprise.properties()), is(empty()));
         }
         // but in terms of the default fields, it's like a drop-in replacement for the city database
@@ -97,8 +106,9 @@ public void testCity() throws Exception {
         @SuppressWarnings("unchecked")
         Map<String, Object> geoData = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("target_field");
         assertThat(geoData, notNullValue());
-        assertThat(geoData.size(), equalTo(7));
+        assertThat(geoData.size(), equalTo(9));
         assertThat(geoData.get("ip"), equalTo(ip));
+        assertThat(geoData.get("country_in_european_union"), equalTo(false));
         assertThat(geoData.get("country_iso_code"), equalTo("US"));
         assertThat(geoData.get("country_name"), equalTo("United States"));
         assertThat(geoData.get("continent_code"), equalTo("NA"));
@@ -213,8 +223,9 @@ public void testCity_withIpV6() throws Exception {
         @SuppressWarnings("unchecked")
         Map<String, Object> geoData = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("target_field");
         assertThat(geoData, notNullValue());
-        assertThat(geoData.size(), equalTo(10));
+        assertThat(geoData.size(), equalTo(13));
         assertThat(geoData.get("ip"), equalTo(ip));
+        assertThat(geoData.get("country_in_european_union"), equalTo(false));
         assertThat(geoData.get("country_iso_code"), equalTo("US"));
         assertThat(geoData.get("country_name"), equalTo("United States"));
         assertThat(geoData.get("continent_code"), equalTo("NA"));
@@ -224,6 +235,8 @@ public void testCity_withIpV6() throws Exception {
         assertThat(geoData.get("city_name"), equalTo("Homestead"));
         assertThat(geoData.get("timezone"), equalTo("America/New_York"));
         assertThat(geoData.get("location"), equalTo(Map.of("lat", 25.4573d, "lon", -80.4572d)));
+        assertThat(geoData.get("accuracy_radius"), equalTo(50));
+        assertThat(geoData.get("postal_code"), equalTo("33035"));
     }
 
     public void testCityWithMissingLocation() throws Exception {
@@ -278,8 +291,9 @@ public void testCountry() throws Exception {
         @SuppressWarnings("unchecked")
         Map<String, Object> geoData = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("target_field");
         assertThat(geoData, notNullValue());
-        assertThat(geoData.size(), equalTo(5));
+        assertThat(geoData.size(), equalTo(6));
         assertThat(geoData.get("ip"), equalTo(ip));
+        assertThat(geoData.get("country_in_european_union"), equalTo(true));
         assertThat(geoData.get("country_iso_code"), equalTo("NL"));
         assertThat(geoData.get("country_name"), equalTo("Netherlands"));
         assertThat(geoData.get("continent_code"), equalTo("EU"));
@@ -461,17 +475,22 @@ public void testEnterprise() throws Exception {
         @SuppressWarnings("unchecked")
         Map<String, Object> geoData = (Map<String, Object>) ingestDocument.getSourceAndMetadata().get("target_field");
         assertThat(geoData, notNullValue());
-        assertThat(geoData.size(), equalTo(24));
+        assertThat(geoData.size(), equalTo(30));
         assertThat(geoData.get("ip"), equalTo(ip));
+        assertThat(geoData.get("country_confidence"), equalTo(99));
+        assertThat(geoData.get("country_in_european_union"), equalTo(false));
         assertThat(geoData.get("country_iso_code"), equalTo("US"));
         assertThat(geoData.get("country_name"), equalTo("United States"));
         assertThat(geoData.get("continent_code"), equalTo("NA"));
         assertThat(geoData.get("continent_name"), equalTo("North America"));
         assertThat(geoData.get("region_iso_code"), equalTo("US-NY"));
         assertThat(geoData.get("region_name"), equalTo("New York"));
+        assertThat(geoData.get("city_confidence"), equalTo(11));
         assertThat(geoData.get("city_name"), equalTo("Chatham"));
         assertThat(geoData.get("timezone"), equalTo("America/New_York"));
         assertThat(geoData.get("location"), equalTo(Map.of("lat", 42.3478, "lon", -73.5549)));
+        assertThat(geoData.get("accuracy_radius"), equalTo(27));
+        assertThat(geoData.get("postal_code"), equalTo("12037"));
         assertThat(geoData.get("asn"), equalTo(14671L));
         assertThat(geoData.get("organization_name"), equalTo("FairPoint Communications"));
         assertThat(geoData.get("network"), equalTo("74.209.16.0/20"));
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java
new file mode 100644
index 0000000000000..905eb027626a1
--- /dev/null
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/IpinfoIpDataLookupsTests.java
@@ -0,0 +1,223 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.ingest.geoip;
+
+import com.maxmind.db.DatabaseRecord;
+import com.maxmind.db.Networks;
+import com.maxmind.db.Reader;
+
+import org.elasticsearch.common.network.NetworkAddress;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.watcher.ResourceWatcherService;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.File;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.nio.file.Path;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.BiConsumer;
+
+import static java.util.Map.entry;
+import static org.elasticsearch.ingest.geoip.GeoIpTestUtils.copyDatabase;
+import static org.elasticsearch.ingest.geoip.IpinfoIpDataLookups.parseAsn;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.Matchers.startsWith;
+
+public class IpinfoIpDataLookupsTests extends ESTestCase {
+
+    private ThreadPool threadPool;
+    private ResourceWatcherService resourceWatcherService;
+
+    @Before
+    public void setup() {
+        threadPool = new TestThreadPool(ConfigDatabases.class.getSimpleName());
+        Settings settings = Settings.builder().put("resource.reload.interval.high", TimeValue.timeValueMillis(100)).build();
+        resourceWatcherService = new ResourceWatcherService(settings, threadPool);
+    }
+
+    @After
+    public void cleanup() {
+        resourceWatcherService.close();
+        threadPool.shutdownNow();
+    }
+
+    public void testDatabasePropertyInvariants() {
+        // the second ASN variant database is like a specialization of the ASN database
+        assertThat(Sets.difference(Database.Asn.properties(), Database.AsnV2.properties()), is(empty()));
+        assertThat(Database.Asn.defaultProperties(), equalTo(Database.AsnV2.defaultProperties()));
+    }
+
+    public void testParseAsn() {
+        // expected case: "AS123" is 123
+        assertThat(parseAsn("AS123"), equalTo(123L));
+        // defensive cases: null and empty become null; this is not expected in practice
+        assertThat(parseAsn(null), nullValue());
+        assertThat(parseAsn(""), nullValue());
+        // defensive cases: we strip whitespace and ignore case
+        assertThat(parseAsn(" as 456 "), equalTo(456L));
+        // defensive cases: we ignore the absence of the 'AS' prefix
+        assertThat(parseAsn("123"), equalTo(123L));
+        // bottom case: a non-parsable string is null
+        assertThat(parseAsn("anythingelse"), nullValue());
+    }
+
+    public void testAsn() throws IOException {
+        Path configDir = createTempDir();
+        copyDatabase("ipinfo/ip_asn_sample.mmdb", configDir.resolve("ip_asn_sample.mmdb"));
+        copyDatabase("ipinfo/asn_sample.mmdb", configDir.resolve("asn_sample.mmdb"));
+
+        GeoIpCache cache = new GeoIpCache(1000); // real cache to test purging of entries upon a reload
+        ConfigDatabases configDatabases = new ConfigDatabases(configDir, cache);
+        configDatabases.initialize(resourceWatcherService);
+
+        // this is the 'free' ASN database (sample)
+        {
+            DatabaseReaderLazyLoader loader = configDatabases.getDatabase("ip_asn_sample.mmdb");
+            IpDataLookup lookup = new IpinfoIpDataLookups.Asn(Set.of(Database.Property.values()));
+            Map<String, Object> data = lookup.getData(loader, "5.182.109.0");
+            assertThat(
+                data,
+                equalTo(
+                    Map.ofEntries(
+                        entry("ip", "5.182.109.0"),
+                        entry("organization_name", "M247 Europe SRL"),
+                        entry("asn", 9009L),
+                        entry("network", "5.182.109.0/24"),
+                        entry("domain", "m247.com")
+                    )
+                )
+            );
+        }
+
+        // this is the non-free or 'standard' ASN database (sample)
+        {
+            DatabaseReaderLazyLoader loader = configDatabases.getDatabase("asn_sample.mmdb");
+            IpDataLookup lookup = new IpinfoIpDataLookups.Asn(Set.of(Database.Property.values()));
+            Map<String, Object> data = lookup.getData(loader, "23.53.116.0");
+            assertThat(
+                data,
+                equalTo(
+                    Map.ofEntries(
+                        entry("ip", "23.53.116.0"),
+                        entry("organization_name", "Akamai Technologies, Inc."),
+                        entry("asn", 32787L),
+                        entry("network", "23.53.116.0/24"),
+                        entry("domain", "akamai.com"),
+                        entry("type", "hosting"),
+                        entry("country_iso_code", "US")
+                    )
+                )
+            );
+        }
+    }
+
+    public void testAsnInvariants() {
+        Path configDir = createTempDir();
+        copyDatabase("ipinfo/ip_asn_sample.mmdb", configDir.resolve("ip_asn_sample.mmdb"));
+        copyDatabase("ipinfo/asn_sample.mmdb", configDir.resolve("asn_sample.mmdb"));
+
+        {
+            final Set<String> expectedColumns = Set.of("network", "asn", "name", "domain");
+
+            Path databasePath = configDir.resolve("ip_asn_sample.mmdb");
+            assertDatabaseInvariants(databasePath, (ip, row) -> {
+                assertThat(row.keySet(), equalTo(expectedColumns));
+                String asn = (String) row.get("asn");
+                assertThat(asn, startsWith("AS"));
+                assertThat(asn, equalTo(asn.trim()));
+                Long parsed = parseAsn(asn);
+                assertThat(parsed, notNullValue());
+                assertThat(asn, equalTo("AS" + parsed)); // reverse it
+            });
+        }
+
+        {
+            final Set<String> expectedColumns = Set.of("network", "asn", "name", "domain", "country", "type");
+
+            Path databasePath = configDir.resolve("asn_sample.mmdb");
+            assertDatabaseInvariants(databasePath, (ip, row) -> {
+                assertThat(row.keySet(), equalTo(expectedColumns));
+                String asn = (String) row.get("asn");
+                assertThat(asn, startsWith("AS"));
+                assertThat(asn, equalTo(asn.trim()));
+                Long parsed = parseAsn(asn);
+                assertThat(parsed, notNullValue());
+                assertThat(asn, equalTo("AS" + parsed)); // reverse it
+            });
+        }
+    }
+
+    public void testCountry() throws IOException {
+        Path configDir = createTempDir();
+        copyDatabase("ipinfo/ip_country_sample.mmdb", configDir.resolve("ip_country_sample.mmdb"));
+
+        GeoIpCache cache = new GeoIpCache(1000); // real cache to test purging of entries upon a reload
+        ConfigDatabases configDatabases = new ConfigDatabases(configDir, cache);
+        configDatabases.initialize(resourceWatcherService);
+
+        // this is the 'free' Country database (sample)
+        {
+            DatabaseReaderLazyLoader loader = configDatabases.getDatabase("ip_country_sample.mmdb");
+            IpDataLookup lookup = new IpinfoIpDataLookups.Country(Set.of(Database.Property.values()));
+            Map<String, Object> data = lookup.getData(loader, "4.221.143.168");
+            assertThat(
+                data,
+                equalTo(
+                    Map.ofEntries(
+                        entry("ip", "4.221.143.168"),
+                        entry("country_name", "South Africa"),
+                        entry("country_iso_code", "ZA"),
+                        entry("continent_name", "Africa"),
+                        entry("continent_code", "AF")
+                    )
+                )
+            );
+        }
+    }
+
+    private static void assertDatabaseInvariants(final Path databasePath, final BiConsumer<InetAddress, Map<String, Object>> rowConsumer) {
+        try (Reader reader = new Reader(pathToFile(databasePath))) {
+            Networks<?> networks = reader.networks(Map.class);
+            while (networks.hasNext()) {
+                DatabaseRecord<?> dbr = networks.next();
+                InetAddress address = dbr.getNetwork().getNetworkAddress();
+                @SuppressWarnings("unchecked")
+                Map<String, Object> result = reader.get(address, Map.class);
+                try {
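+                    // run the per-row assertions; on failure, report which sample-database address tripped them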
+                    rowConsumer.accept(address, result);
+                } catch (AssertionError e) {
+                    fail(e, "Assert failed for address [%s]", NetworkAddress.format(address));
+                } catch (Exception e) {
+                    fail(e, "Exception handling address [%s]", NetworkAddress.format(address));
+                }
+            }
+        } catch (Exception e) {
+            fail(e);
+        }
+    }
+
+    @SuppressForbidden(reason = "Maxmind API requires java.io.File")
+    private static File pathToFile(Path databasePath) {
+        return databasePath.toFile();
+    }
+}
diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java
index 84ea5fd584352..1e05cf2b3ba33 100644
--- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java
+++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/MaxMindSupportTests.java
@@ -78,13 +78,16 @@ public class MaxMindSupportTests extends ESTestCase {
         "city.name",
         "continent.code",
         "continent.name",
+        "country.inEuropeanUnion",
         "country.isoCode",
         "country.name",
+        "location.accuracyRadius",
         "location.latitude",
         "location.longitude",
         "location.timeZone",
         "mostSpecificSubdivision.isoCode",
-        "mostSpecificSubdivision.name"
+        "mostSpecificSubdivision.name",
+        "postal.code"
     );
     private static final Set<String> CITY_UNSUPPORTED_FIELDS = Set.of(
         "city.confidence",
@@ -94,14 +97,12 @@ public class MaxMindSupportTests extends ESTestCase {
         "continent.names",
         "country.confidence",
         "country.geoNameId",
-        "country.inEuropeanUnion",
         "country.names",
         "leastSpecificSubdivision.confidence",
         "leastSpecificSubdivision.geoNameId",
         "leastSpecificSubdivision.isoCode",
         "leastSpecificSubdivision.name",
         "leastSpecificSubdivision.names",
-        "location.accuracyRadius",
         "location.averageIncome",
         "location.metroCode",
         "location.populationDensity",
@@ -109,7 +110,6 @@ public class MaxMindSupportTests extends ESTestCase {
         "mostSpecificSubdivision.confidence",
         "mostSpecificSubdivision.geoNameId",
         "mostSpecificSubdivision.names",
-        "postal.code",
         "postal.confidence",
         "registeredCountry.confidence",
         "registeredCountry.geoNameId",
@@ -159,6 +159,7 @@ public class MaxMindSupportTests extends ESTestCase {
 
     private static final Set<String> COUNTRY_SUPPORTED_FIELDS = Set.of(
         "continent.name",
+        "country.inEuropeanUnion",
         "country.isoCode",
         "continent.code",
         "country.name"
@@ -168,7 +169,6 @@ public class MaxMindSupportTests extends ESTestCase {
         "continent.names",
         "country.confidence",
         "country.geoNameId",
-        "country.inEuropeanUnion",
         "country.names",
         "maxMind",
         "registeredCountry.confidence",
@@ -213,16 +213,22 @@ public class MaxMindSupportTests extends ESTestCase {
 
     private static final Set<String> DOMAIN_UNSUPPORTED_FIELDS = Set.of("ipAddress", "network");
 
     private static final Set<String> ENTERPRISE_SUPPORTED_FIELDS = Set.of(
+        "city.confidence",
         "city.name",
         "continent.code",
         "continent.name",
+        "country.confidence",
+        "country.inEuropeanUnion",
         "country.isoCode",
         "country.name",
+        "location.accuracyRadius",
         "location.latitude",
         "location.longitude",
         "location.timeZone",
         "mostSpecificSubdivision.isoCode",
         "mostSpecificSubdivision.name",
+        "postal.code",
+        "postal.confidence",
         "traits.anonymous",
         "traits.anonymousVpn",
         "traits.autonomousSystemNumber",
@@ -241,21 +247,17 @@ public class MaxMindSupportTests extends ESTestCase {
         "traits.userType"
     );
     private static final Set<String> ENTERPRISE_UNSUPPORTED_FIELDS = Set.of(
-        "city.confidence",
         "city.geoNameId",
         "city.names",
         "continent.geoNameId",
         "continent.names",
-        "country.confidence",
         "country.geoNameId",
-        "country.inEuropeanUnion",
"country.inEuropeanUnion", "country.names", "leastSpecificSubdivision.confidence", "leastSpecificSubdivision.geoNameId", "leastSpecificSubdivision.isoCode", "leastSpecificSubdivision.name", "leastSpecificSubdivision.names", - "location.accuracyRadius", "location.averageIncome", "location.metroCode", "location.populationDensity", @@ -263,8 +265,6 @@ public class MaxMindSupportTests extends ESTestCase { "mostSpecificSubdivision.confidence", "mostSpecificSubdivision.geoNameId", "mostSpecificSubdivision.names", - "postal.code", - "postal.confidence", "registeredCountry.confidence", "registeredCountry.geoNameId", "registeredCountry.inEuropeanUnion", @@ -361,8 +361,14 @@ public class MaxMindSupportTests extends ESTestCase { private static final Set> KNOWN_UNSUPPORTED_RESPONSE_CLASSES = Set.of(IpRiskResponse.class); + private static final Set KNOWN_UNSUPPORTED_DATABASE_VARIANTS = Set.of(Database.AsnV2); + public void testMaxMindSupport() { for (Database databaseType : Database.values()) { + if (KNOWN_UNSUPPORTED_DATABASE_VARIANTS.contains(databaseType)) { + continue; + } + Class maxMindClass = TYPE_TO_MAX_MIND_CLASS.get(databaseType); Set supportedFields = TYPE_TO_SUPPORTED_FIELDS_MAP.get(databaseType); Set unsupportedFields = TYPE_TO_UNSUPPORTED_FIELDS_MAP.get(databaseType); diff --git a/modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb b/modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb new file mode 100644 index 0000000000000..916a8252a5df1 Binary files /dev/null and b/modules/ingest-geoip/src/test/resources/ipinfo/asn_sample.mmdb differ diff --git a/modules/ingest-geoip/src/test/resources/ipinfo/ip_asn_sample.mmdb b/modules/ingest-geoip/src/test/resources/ipinfo/ip_asn_sample.mmdb new file mode 100644 index 0000000000000..3e1fc49ba48a5 Binary files /dev/null and b/modules/ingest-geoip/src/test/resources/ipinfo/ip_asn_sample.mmdb differ diff --git a/modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb b/modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb new file mode 100644 index 0000000000000..88428315ee8d6 Binary files /dev/null and b/modules/ingest-geoip/src/test/resources/ipinfo/ip_country_sample.mmdb differ diff --git a/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryMetricsTests.java b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryMetricsTests.java new file mode 100644 index 0000000000000..a9bf0afa37e18 --- /dev/null +++ b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryMetricsTests.java @@ -0,0 +1,468 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */
+
+package org.elasticsearch.repositories.azure;
+
+import com.sun.net.httpserver.HttpExchange;
+import com.sun.net.httpserver.HttpHandler;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.blobstore.BlobContainer;
+import org.elasticsearch.common.blobstore.BlobPath;
+import org.elasticsearch.common.blobstore.OperationPurpose;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.plugins.PluginsService;
+import org.elasticsearch.repositories.RepositoriesMetrics;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.blobstore.BlobStoreRepository;
+import org.elasticsearch.repositories.blobstore.RequestedRangeNotSatisfiedException;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.telemetry.Measurement;
+import org.elasticsearch.telemetry.TestTelemetryPlugin;
+import org.junit.After;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.elasticsearch.repositories.azure.AbstractAzureServerTestCase.randomBlobContent;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+
+@SuppressForbidden(reason = "we use a HttpServer to emulate Azure")
+public class AzureBlobStoreRepositoryMetricsTests extends AzureBlobStoreRepositoryTests {
+
+    private static final Predicate<HttpExchange> GET_BLOB_REQUEST_PREDICATE = request -> GET_BLOB_PATTERN.test(
+        request.getRequestMethod() + " " + request.getRequestURI()
+    );
+    private static final int MAX_RETRIES = 3;
+
+    private final Queue<ResponseInjectingAzureHttpHandler.RequestHandler> requestHandlers = new ConcurrentLinkedQueue<>();
+
+    @Override
+    protected Map<String, HttpHandler> createHttpHandlers() {
+        Map<String, HttpHandler> httpHandlers = super.createHttpHandlers();
+        assert httpHandlers.size() == 1 : "This assumes there's a single handler";
+        return httpHandlers.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, e -> new ResponseInjectingAzureHttpHandler(requestHandlers, e.getValue())));
+    }
+
+    /**
+     * We want to control the errors in this test
+     */
+    @Override
+    protected HttpHandler createErroneousHttpHandler(HttpHandler delegate) {
+        return delegate;
+    }
+
+    @After
+    public void checkRequestHandlerQueue() {
+        if (requestHandlers.isEmpty() == false) {
+            fail("There were unused request handlers left in the queue, this is probably a broken test");
+        }
+    }
+
+    private static BlobContainer getBlobContainer(String dataNodeName, String repository) {
+        final var blobStoreRepository = (BlobStoreRepository) internalCluster().getInstance(RepositoriesService.class, dataNodeName)
+            .repository(repository);
+        return blobStoreRepository.blobStore().blobContainer(BlobPath.EMPTY.add(randomIdentifier()));
+    }
+
+    public void testThrottleResponsesAreCountedInMetrics() throws IOException {
+        final String repository = createRepository(randomRepositoryName());
+        final String dataNodeName = internalCluster().getNodeNameThat(DiscoveryNode::canContainData);
+        final BlobContainer blobContainer = getBlobContainer(dataNodeName, repository);
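+        // Test pattern (summary): seed a blob, reset telemetry, queue canned responses on the emulated endpoint,
+        // run a single blob operation, then assert the recorded request/throttle/exception metrics.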
+
+        // Create a blob
+        final String blobName = "index-" + randomIdentifier();
+        final OperationPurpose purpose = randomFrom(OperationPurpose.values());
+        blobContainer.writeBlob(purpose, blobName, BytesReference.fromByteBuffer(ByteBuffer.wrap(randomBlobContent())), false);
+        clearMetrics(dataNodeName);
+
+        // Queue up some throttle responses
+        final int numThrottles = randomIntBetween(1, MAX_RETRIES);
+        IntStream.range(0, numThrottles).forEach(i -> requestHandlers.offer(new FixedRequestHandler(RestStatus.TOO_MANY_REQUESTS)));
+
+        // Check that the blob exists
+        blobContainer.blobExists(purpose, blobName);
+
+        // Correct metrics are recorded
+        metricsAsserter(dataNodeName, purpose, AzureBlobStore.Operation.GET_BLOB_PROPERTIES, repository).expectMetrics()
+            .withRequests(numThrottles + 1)
+            .withThrottles(numThrottles)
+            .withExceptions(numThrottles)
+            .forResult(MetricsAsserter.Result.Success);
+    }
+
+    public void testRangeNotSatisfiedAreCountedInMetrics() throws IOException {
+        final String repository = createRepository(randomRepositoryName());
+        final String dataNodeName = internalCluster().getNodeNameThat(DiscoveryNode::canContainData);
+        final BlobContainer blobContainer = getBlobContainer(dataNodeName, repository);
+
+        // Create a blob
+        final String blobName = "index-" + randomIdentifier();
+        final OperationPurpose purpose = randomFrom(OperationPurpose.values());
+        blobContainer.writeBlob(purpose, blobName, BytesReference.fromByteBuffer(ByteBuffer.wrap(randomBlobContent())), false);
+        clearMetrics(dataNodeName);
+
+        // Queue up a range-not-satisfied error
+        requestHandlers.offer(new FixedRequestHandler(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, null, GET_BLOB_REQUEST_PREDICATE));
+
+        // Attempt to read the blob
+        assertThrows(RequestedRangeNotSatisfiedException.class, () -> blobContainer.readBlob(purpose, blobName));
+
+        // Correct metrics are recorded
+        metricsAsserter(dataNodeName, purpose, AzureBlobStore.Operation.GET_BLOB, repository).expectMetrics()
+            .withRequests(1)
+            .withThrottles(0)
+            .withExceptions(1)
+            .forResult(MetricsAsserter.Result.RangeNotSatisfied);
+    }
+
+    public void testErrorResponsesAreCountedInMetrics() throws IOException {
+        final String repository = createRepository(randomRepositoryName());
+        final String dataNodeName = internalCluster().getNodeNameThat(DiscoveryNode::canContainData);
+        final BlobContainer blobContainer = getBlobContainer(dataNodeName, repository);
+
+        // Create a blob
+        final String blobName = "index-" + randomIdentifier();
+        final OperationPurpose purpose = randomFrom(OperationPurpose.values());
+        blobContainer.writeBlob(purpose, blobName, BytesReference.fromByteBuffer(ByteBuffer.wrap(randomBlobContent())), false);
+        clearMetrics(dataNodeName);
+
+        // Queue some retry-able error responses
+        final int numErrors = randomIntBetween(1, MAX_RETRIES);
+        final AtomicInteger throttles = new AtomicInteger();
+        IntStream.range(0, numErrors).forEach(i -> {
+            RestStatus status = randomFrom(RestStatus.INTERNAL_SERVER_ERROR, RestStatus.TOO_MANY_REQUESTS, RestStatus.SERVICE_UNAVAILABLE);
+            if (status == RestStatus.TOO_MANY_REQUESTS) {
+                throttles.incrementAndGet();
+            }
+            requestHandlers.offer(new FixedRequestHandler(status));
+        });
+
+        // Check that the blob exists
+        blobContainer.blobExists(purpose, blobName);
+
+        // Correct metrics are recorded
+        metricsAsserter(dataNodeName, purpose, AzureBlobStore.Operation.GET_BLOB_PROPERTIES, repository).expectMetrics()
+            .withRequests(numErrors + 1)
+            .withThrottles(throttles.get())
+            .withExceptions(numErrors)
+            .forResult(MetricsAsserter.Result.Success);
+    }
+
+    public void testRequestFailuresAreCountedInMetrics() {
+        final String repository = createRepository(randomRepositoryName());
+        final String dataNodeName = internalCluster().getNodeNameThat(DiscoveryNode::canContainData);
+        final BlobContainer blobContainer = getBlobContainer(dataNodeName, repository);
+        clearMetrics(dataNodeName);
+
+        // Repeatedly cause a connection error to exhaust retries
+        IntStream.range(0, MAX_RETRIES + 1).forEach(i -> requestHandlers.offer((exchange, delegate) -> exchange.close()));
+
+        // Hit the API
+        OperationPurpose purpose = randomFrom(OperationPurpose.values());
+        assertThrows(IOException.class, () -> blobContainer.listBlobs(purpose));
+
+        // Correct metrics are recorded
+        metricsAsserter(dataNodeName, purpose, AzureBlobStore.Operation.LIST_BLOBS, repository).expectMetrics()
+            .withRequests(4)
+            .withThrottles(0)
+            .withExceptions(4)
+            .forResult(MetricsAsserter.Result.Exception);
+    }
+
+    public void testRequestTimeIsAccurate() throws IOException {
+        final String repository = createRepository(randomRepositoryName());
+        final String dataNodeName = internalCluster().getNodeNameThat(DiscoveryNode::canContainData);
+        final BlobContainer blobContainer = getBlobContainer(dataNodeName, repository);
+        clearMetrics(dataNodeName);
+
+        AtomicLong totalDelayMillis = new AtomicLong(0);
+        // Add some artificial delays
+        IntStream.range(0, randomIntBetween(1, MAX_RETRIES)).forEach(i -> {
+            long thisDelay = randomLongBetween(10, 100);
+            totalDelayMillis.addAndGet(thisDelay);
+            requestHandlers.offer((exchange, delegate) -> {
+                safeSleep(thisDelay);
+                // return a retry-able error
+                exchange.sendResponseHeaders(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), -1);
+            });
+        });
+
+        // Hit the API
+        final long startTimeMillis = System.currentTimeMillis();
+        blobContainer.listBlobs(randomFrom(OperationPurpose.values()));
+        final long elapsedTimeMillis = System.currentTimeMillis() - startTimeMillis;
+
+        List<Measurement> longHistogramMeasurement = getTelemetryPlugin(dataNodeName).getLongHistogramMeasurement(
+            RepositoriesMetrics.HTTP_REQUEST_TIME_IN_MILLIS_HISTOGRAM
+        );
+        long recordedRequestTime = longHistogramMeasurement.get(0).getLong();
+        // Request time should be >= the delays we simulated
+        assertThat(recordedRequestTime, greaterThanOrEqualTo(totalDelayMillis.get()));
+        // And <= the elapsed time for the request
+        assertThat(recordedRequestTime, lessThanOrEqualTo(elapsedTimeMillis));
+    }
+
+    private void clearMetrics(String discoveryNode) {
+        internalCluster().getInstance(PluginsService.class, discoveryNode)
+            .filterPlugins(TestTelemetryPlugin.class)
+            .forEach(TestTelemetryPlugin::resetMeter);
+    }
+
+    private MetricsAsserter metricsAsserter(
+        String dataNodeName,
+        OperationPurpose operationPurpose,
+        AzureBlobStore.Operation operation,
+        String repository
+    ) {
+        return new MetricsAsserter(dataNodeName, operationPurpose, operation, repository);
+    }
+
+    private class MetricsAsserter {
+        private final String dataNodeName;
+        private final OperationPurpose purpose;
+        private final AzureBlobStore.Operation operation;
+        private final String repository;
+
+        enum Result {
+            Success,
+            Failure,
+            RangeNotSatisfied,
+            Exception
+        }
+
+        enum MetricType {
+            LongHistogram {
+                @Override
+                List<Measurement> getMeasurements(TestTelemetryPlugin testTelemetryPlugin, String name) {
+                    return testTelemetryPlugin.getLongHistogramMeasurement(name);
+                }
+            },
+            LongCounter {
+                @Override
+                List<Measurement> getMeasurements(TestTelemetryPlugin testTelemetryPlugin, String name) {
+                    return testTelemetryPlugin.getLongCounterMeasurement(name);
+                }
+            };
+
+            abstract List<Measurement> getMeasurements(TestTelemetryPlugin testTelemetryPlugin, String name);
+        }
+
+        private MetricsAsserter(String dataNodeName, OperationPurpose purpose, AzureBlobStore.Operation operation, String repository) {
+            this.dataNodeName = dataNodeName;
+            this.purpose = purpose;
+            this.operation = operation;
+            this.repository = repository;
+        }
+
+        private class Expectations {
+            private int expectedRequests;
+            private int expectedThrottles;
+            private int expectedExceptions;
+
+            public Expectations withRequests(int expectedRequests) {
+                this.expectedRequests = expectedRequests;
+                return this;
+            }
+
+            public Expectations withThrottles(int expectedThrottles) {
+                this.expectedThrottles = expectedThrottles;
+                return this;
+            }
+
+            public Expectations withExceptions(int expectedExceptions) {
+                this.expectedExceptions = expectedExceptions;
+                return this;
+            }
+
+            public void forResult(Result result) {
+                assertMetricsRecorded(expectedRequests, expectedThrottles, expectedExceptions, result);
+            }
+        }
+
+        Expectations expectMetrics() {
+            return new Expectations();
+        }
+
+        private void assertMetricsRecorded(int expectedRequests, int expectedThrottles, int expectedExceptions, Result result) {
+            assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_OPERATIONS_TOTAL, 1);
+            assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_REQUESTS_TOTAL, expectedRequests);
+
+            if (expectedThrottles > 0) {
+                assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_THROTTLES_TOTAL, expectedThrottles);
+                assertIntMetricRecorded(MetricType.LongHistogram, RepositoriesMetrics.METRIC_THROTTLES_HISTOGRAM, expectedThrottles);
+            } else {
+                assertNoMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_THROTTLES_TOTAL);
+                assertNoMetricRecorded(MetricType.LongHistogram, RepositoriesMetrics.METRIC_THROTTLES_HISTOGRAM);
+            }
+
+            if (expectedExceptions > 0) {
+                assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_EXCEPTIONS_TOTAL, expectedExceptions);
+                assertIntMetricRecorded(MetricType.LongHistogram, RepositoriesMetrics.METRIC_EXCEPTIONS_HISTOGRAM, expectedExceptions);
+            } else {
+                assertNoMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_EXCEPTIONS_TOTAL);
+                assertNoMetricRecorded(MetricType.LongHistogram, RepositoriesMetrics.METRIC_EXCEPTIONS_HISTOGRAM);
+            }
+
+            if (result == Result.RangeNotSatisfied || result == Result.Failure || result == Result.Exception) {
+                assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_UNSUCCESSFUL_OPERATIONS_TOTAL, 1);
+            } else {
+                assertNoMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_UNSUCCESSFUL_OPERATIONS_TOTAL);
+            }
+
+            if (result == Result.RangeNotSatisfied) {
+                assertIntMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_EXCEPTIONS_REQUEST_RANGE_NOT_SATISFIED_TOTAL, 1);
+            } else {
+                assertNoMetricRecorded(MetricType.LongCounter, RepositoriesMetrics.METRIC_EXCEPTIONS_REQUEST_RANGE_NOT_SATISFIED_TOTAL);
+            }
+
+            assertMatchingMetricRecorded(
+                MetricType.LongHistogram,
+                RepositoriesMetrics.HTTP_REQUEST_TIME_IN_MILLIS_HISTOGRAM,
+                m -> assertThat("No request time metric found", m.getLong(), greaterThanOrEqualTo(0L))
+            );
+        }
+
+        private void assertIntMetricRecorded(MetricType metricType, String metricName, int expectedValue) {
+            assertMatchingMetricRecorded(
+                metricType,
+                metricName,
+                measurement -> assertEquals("Unexpected value for " + metricType + " " + metricName, expectedValue, measurement.getLong())
+            );
+        }
+
+        private void assertNoMetricRecorded(MetricType metricType, String metricName) {
+            assertThat(
+                "Expected no values for " + metricType + " " + metricName,
+                metricType.getMeasurements(getTelemetryPlugin(dataNodeName), metricName),
+                hasSize(0)
+            );
+        }
+
+        private void assertMatchingMetricRecorded(MetricType metricType, String metricName, Consumer<Measurement> assertion) {
+            List<Measurement> measurements = metricType.getMeasurements(getTelemetryPlugin(dataNodeName), metricName);
+            Measurement measurement = measurements.stream()
+                .filter(
+                    m -> m.attributes().get("operation").equals(operation.getKey())
+                        && m.attributes().get("purpose").equals(purpose.getKey())
+                        && m.attributes().get("repo_name").equals(repository)
+                        && m.attributes().get("repo_type").equals("azure")
+                )
+                .findFirst()
+                .orElseThrow(
+                    () -> new IllegalStateException(
+                        "No metric found with name="
+                            + metricName
+                            + " and operation="
+                            + operation.getKey()
+                            + " and purpose="
+                            + purpose.getKey()
+                            + " and repo_name="
+                            + repository
+                            + " in "
+                            + measurements
+                    )
+                );
+
+            assertion.accept(measurement);
+        }
+    }
+
+    @SuppressForbidden(reason = "we use a HttpServer to emulate Azure")
+    private static class ResponseInjectingAzureHttpHandler implements DelegatingHttpHandler {
+
+        private final HttpHandler delegate;
+        private final Queue<RequestHandler> requestHandlerQueue;
+
+        ResponseInjectingAzureHttpHandler(Queue<RequestHandler> requestHandlerQueue, HttpHandler delegate) {
+            this.delegate = delegate;
+            this.requestHandlerQueue = requestHandlerQueue;
+        }
+
+        @Override
+        public void handle(HttpExchange exchange) throws IOException {
+            RequestHandler nextHandler = requestHandlerQueue.peek();
+            if (nextHandler != null && nextHandler.matchesRequest(exchange)) {
+                requestHandlerQueue.poll().writeResponse(exchange, delegate);
+            } else {
+                delegate.handle(exchange);
+            }
+        }
+
+        @Override
+        public HttpHandler getDelegate() {
+            return delegate;
+        }
+    }
+
+    @SuppressForbidden(reason = "we use a HttpServer to emulate Azure")
+    @FunctionalInterface
+    private interface RequestHandler {
+        void writeResponse(HttpExchange exchange, HttpHandler delegate) throws IOException;
+
+        default boolean matchesRequest(HttpExchange exchange) {
+            return true;
+        }
+    }
+
+    @SuppressForbidden(reason = "we use a HttpServer to emulate Azure")
+    private static class FixedRequestHandler implements RequestHandler {
+
+        private final RestStatus status;
+        private final String responseBody;
+        private final Predicate<HttpExchange> requestMatcher;
+
+        FixedRequestHandler(RestStatus status) {
+            this(status, null, req -> true);
+        }
+
+        /**
+         * Create a handler that only gets executed for requests that match the supplied predicate. Note
+         * that because the errors are stored in a queue this will prevent any subsequently queued errors from
+         * being returned until after it returns.
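+         * For example, a handler queued via {@code new FixedRequestHandler(RestStatus.TOO_MANY_REQUESTS, null, GET_BLOB_REQUEST_PREDICATE)}
+         * answers only the next matching GET-blob request with a bare 429, while non-matching requests fall through to the delegate.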
+         */
+        FixedRequestHandler(RestStatus status, String responseBody, Predicate<HttpExchange> requestMatcher) {
+            this.status = status;
+            this.responseBody = responseBody;
+            this.requestMatcher = requestMatcher;
+        }
+
+        @Override
+        public boolean matchesRequest(HttpExchange exchange) {
+            return requestMatcher.test(exchange);
+        }
+
+        @Override
+        public void writeResponse(HttpExchange exchange, HttpHandler delegateHandler) throws IOException {
+            if (responseBody != null) {
+                byte[] responseBytes = responseBody.getBytes(StandardCharsets.UTF_8);
+                exchange.sendResponseHeaders(status.getStatus(), responseBytes.length);
+                exchange.getResponseBody().write(responseBytes);
+            } else {
+                exchange.sendResponseHeaders(status.getStatus(), -1);
+            }
+        }
+    }
+}
diff --git a/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java
index 1b7628cc0ad8e..473d91da6e34c 100644
--- a/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java
+++ b/modules/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java
@@ -16,11 +16,13 @@
 import com.sun.net.httpserver.HttpExchange;
 import com.sun.net.httpserver.HttpHandler;
 
+import org.elasticsearch.action.support.broadcast.BroadcastResponse;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.blobstore.BlobPath;
 import org.elasticsearch.common.blobstore.BlobStore;
+import org.elasticsearch.common.blobstore.OperationPurpose;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.settings.MockSecureSettings;
@@ -30,8 +32,15 @@
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.PluginsService;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.RepositoryMissingException;
+import org.elasticsearch.repositories.blobstore.BlobStoreRepository;
 import org.elasticsearch.repositories.blobstore.ESMockAPIBasedRepositoryIntegTestCase;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.telemetry.Measurement;
+import org.elasticsearch.telemetry.TestTelemetryPlugin;
+import org.elasticsearch.test.BackgroundIndexer;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -41,22 +50,33 @@
 import java.util.Base64;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.LongAdder;
 import java.util.function.Predicate;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
+import static org.elasticsearch.repositories.RepositoriesMetrics.METRIC_OPERATIONS_TOTAL;
 import static org.elasticsearch.repositories.blobstore.BlobStoreTestUtil.randomPurpose;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
+import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.is;
 
 @SuppressForbidden(reason = "this test uses a HttpServer to emulate an Azure endpoint")
 public class AzureBlobStoreRepositoryTests extends ESMockAPIBasedRepositoryIntegTestCase {
 
-    private static final String DEFAULT_ACCOUNT_NAME = "account";
+    protected static final String DEFAULT_ACCOUNT_NAME = "account";
+    protected static final Predicate<String> LIST_PATTERN = Pattern.compile("GET /[a-zA-Z0-9]+/[a-zA-Z0-9]+\\?.+").asMatchPredicate();
+    protected static final Predicate<String> GET_BLOB_PATTERN = Pattern.compile("GET /[a-zA-Z0-9]+/[a-zA-Z0-9]+/.+").asMatchPredicate();
 
     @Override
     protected String repositoryType() {
@@ -78,7 +98,7 @@ protected Settings repositorySettings(String repoName) {
 
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
-        return Collections.singletonList(TestAzureRepositoryPlugin.class);
+        return List.of(TestAzureRepositoryPlugin.class, TestTelemetryPlugin.class);
     }
 
     @Override
@@ -91,7 +111,7 @@ protected Map<String, HttpHandler> createHttpHandlers() {
 
     @Override
     protected HttpHandler createErroneousHttpHandler(final HttpHandler delegate) {
-        return new AzureErroneousHttpHandler(delegate, AzureStorageSettings.DEFAULT_MAX_RETRIES);
+        return new AzureHTTPStatsCollectorHandler(new AzureErroneousHttpHandler(delegate, AzureStorageSettings.DEFAULT_MAX_RETRIES));
     }
 
     @Override
@@ -119,6 +139,13 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
             .build();
     }
 
+    protected TestTelemetryPlugin getTelemetryPlugin(String dataNodeName) {
+        return internalCluster().getInstance(PluginsService.class, dataNodeName)
+            .filterPlugins(TestTelemetryPlugin.class)
+            .findFirst()
+            .orElseThrow();
+    }
+
     /**
      * AzureRepositoryPlugin that allows to set low values for the Azure's client retry policy
      * and for BlobRequestOptions#getSingleBlobPutThresholdInBytes().
@@ -195,9 +222,6 @@ protected String requestUniqueId(final HttpExchange exchange) {
      */
     @SuppressForbidden(reason = "this test uses a HttpServer to emulate an Azure endpoint")
     private static class AzureHTTPStatsCollectorHandler extends HttpStatsCollectorHandler {
-        private static final Predicate<String> LIST_PATTERN = Pattern.compile("GET /[a-zA-Z0-9]+/[a-zA-Z0-9]+\\?.+").asMatchPredicate();
-        private static final Predicate<String> GET_BLOB_PATTERN = Pattern.compile("GET /[a-zA-Z0-9]+/[a-zA-Z0-9]+/.+").asMatchPredicate();
-
         private final Set<String> seenRequestIds = ConcurrentCollections.newConcurrentSet();
 
         private AzureHTTPStatsCollectorHandler(HttpHandler delegate) {
@@ -303,4 +327,87 @@ public void testReadByteByByte() throws Exception {
             container.delete(randomPurpose());
         }
     }
+
+    public void testMetrics() throws Exception {
+        // Reset all the metrics so there's none lingering from previous tests
+        internalCluster().getInstances(PluginsService.class)
+            .forEach(ps -> ps.filterPlugins(TestTelemetryPlugin.class).forEach(TestTelemetryPlugin::resetMeter));
+
+        // Create the repository and perform some activities
+        final String repository = createRepository(randomRepositoryName(), false);
+        final String index = "index-no-merges";
+        createIndex(index, 1, 0);
+
+        final long nbDocs = randomLongBetween(10_000L, 20_000L);
+        try (BackgroundIndexer indexer = new BackgroundIndexer(index, client(), (int) nbDocs)) {
+            waitForDocs(nbDocs, indexer);
+        }
+        flushAndRefresh(index);
+        BroadcastResponse forceMerge = client().admin().indices().prepareForceMerge(index).setFlush(true).setMaxNumSegments(1).get();
+        assertThat(forceMerge.getSuccessfulShards(), equalTo(1));
+        assertHitCount(prepareSearch(index).setSize(0).setTrackTotalHits(true), nbDocs);
+
+        final String snapshot = "snapshot";
+        assertSuccessfulSnapshot(
+            clusterAdmin().prepareCreateSnapshot(TEST_REQUEST_TIMEOUT, repository, snapshot).setWaitForCompletion(true).setIndices(index)
+        );
+        assertAcked(client().admin().indices().prepareDelete(index));
+        assertSuccessfulRestore(
+            clusterAdmin().prepareRestoreSnapshot(TEST_REQUEST_TIMEOUT, repository, snapshot).setWaitForCompletion(true)
+        );
+        ensureGreen(index);
+        assertHitCount(prepareSearch(index).setSize(0).setTrackTotalHits(true), nbDocs);
+        assertAcked(clusterAdmin().prepareDeleteSnapshot(TEST_REQUEST_TIMEOUT, repository, snapshot).get());
+
+        final Map<AzureBlobStore.Operation, Long> aggregatedMetrics = new HashMap<>();
+        // Compare collected stats and metrics for each node and they should be the same
+        for (var nodeName : internalCluster().getNodeNames()) {
+            final BlobStoreRepository blobStoreRepository;
+            try {
+                blobStoreRepository = (BlobStoreRepository) internalCluster().getInstance(RepositoriesService.class, nodeName)
+                    .repository(repository);
+            } catch (RepositoryMissingException e) {
+                continue;
+            }
+
+            final AzureBlobStore blobStore = (AzureBlobStore) blobStoreRepository.blobStore();
+            final Map<AzureBlobStore.StatsKey, LongAdder> statsCollectors = blobStore.getMetricsRecorder().opsCounters;
+
+            final List<Measurement> metrics = Measurement.combine(
+                getTelemetryPlugin(nodeName).getLongCounterMeasurement(METRIC_OPERATIONS_TOTAL)
+            );
+
+            assertThat(
+                statsCollectors.keySet().stream().map(AzureBlobStore.StatsKey::operation).collect(Collectors.toSet()),
+                equalTo(
+                    metrics.stream()
+                        .map(m -> AzureBlobStore.Operation.fromKey((String) m.attributes().get("operation")))
+                        .collect(Collectors.toSet())
+                )
+            );
+            metrics.forEach(metric -> {
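+                // every measurement of the operations-total counter must carry the standard repository attributes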
+                assertThat(
+                    metric.attributes(),
+                    allOf(hasEntry("repo_type", AzureRepository.TYPE), hasKey("repo_name"), hasKey("operation"), hasKey("purpose"))
+                );
+                final AzureBlobStore.Operation operation = AzureBlobStore.Operation.fromKey((String) metric.attributes().get("operation"));
+                final AzureBlobStore.StatsKey statsKey = new AzureBlobStore.StatsKey(
+                    operation,
+                    OperationPurpose.parse((String) metric.attributes().get("purpose"))
+                );
+                assertThat(nodeName + "/" + statsKey + " exists", statsCollectors, hasKey(statsKey));
+                assertThat(nodeName + "/" + statsKey + " has correct sum", metric.getLong(), equalTo(statsCollectors.get(statsKey).sum()));
+                aggregatedMetrics.compute(statsKey.operation(), (k, v) -> v == null ? metric.getLong() : v + metric.getLong());
+            });
+        }
+
+        // The aggregated metrics should also be consistent with the server-side request counts.
+        assertThat(aggregatedMetrics, equalTo(getServerMetrics()));
+    }
+
+    private Map<AzureBlobStore.Operation, Long> getServerMetrics() {
+        return getMockRequestCounts().entrySet()
+            .stream()
+            .collect(Collectors.toMap(e -> AzureBlobStore.Operation.fromKey(e.getKey()), Map.Entry::getValue));
+    }
 }
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
index 5466989082129..d520d30f2bac6 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
@@ -60,6 +60,7 @@
 import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
+import org.elasticsearch.repositories.RepositoriesMetrics;
 import org.elasticsearch.repositories.azure.AzureRepository.Repository;
 import org.elasticsearch.repositories.blobstore.ChunkedBlobOutputStream;
 import org.elasticsearch.rest.RestStatus;
@@ -86,11 +87,11 @@
 import java.util.Spliterator;
 import java.util.Spliterators;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.LongAdder;
 import java.util.function.BiPredicate;
-import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 
@@ -102,59 +103,54 @@ public class AzureBlobStore implements BlobStore {
     private static final int DEFAULT_UPLOAD_BUFFERS_SIZE = (int) new ByteSizeValue(64, ByteSizeUnit.KB).getBytes();
 
     private final AzureStorageService service;
-    private final BigArrays bigArrays;
+    private final RepositoryMetadata repositoryMetadata;
     private final String clientName;
     private final String container;
     private final LocationMode locationMode;
     private final ByteSizeValue maxSinglePartUploadSize;
-    private final StatsCollectors statsCollectors = new StatsCollectors();
-    private final AzureClientProvider.SuccessfulRequestHandler statsConsumer;
+    private final RequestMetricsRecorder requestMetricsRecorder;
+    private final AzureClientProvider.RequestMetricsHandler requestMetricsHandler;
 
-    public AzureBlobStore(RepositoryMetadata metadata, AzureStorageService service, BigArrays bigArrays) {
+    public AzureBlobStore(
+        RepositoryMetadata metadata,
+        AzureStorageService service,
+        BigArrays bigArrays,
+        RepositoriesMetrics repositoriesMetrics
+    ) {
         this.container = Repository.CONTAINER_SETTING.get(metadata.settings());
         this.clientName = Repository.CLIENT_NAME.get(metadata.settings());
         this.service = service;
         this.bigArrays = bigArrays;
+        this.requestMetricsRecorder = new RequestMetricsRecorder(repositoriesMetrics);
+        this.repositoryMetadata = metadata;
         // locationMode is set per repository, not per client
         this.locationMode = Repository.LOCATION_MODE_SETTING.get(metadata.settings());
         this.maxSinglePartUploadSize = Repository.MAX_SINGLE_PART_UPLOAD_SIZE_SETTING.get(metadata.settings());
 
-        List<RequestStatsCollector> requestStatsCollectors = List.of(
-            RequestStatsCollector.create(
-                (httpMethod, url) -> httpMethod == HttpMethod.HEAD,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.GET_BLOB_PROPERTIES, purpose)
-            ),
-            RequestStatsCollector.create(
+        List<RequestMatcher> requestMatchers = List.of(
+            new RequestMatcher((httpMethod, url) -> httpMethod == HttpMethod.HEAD, Operation.GET_BLOB_PROPERTIES),
+            new RequestMatcher(
                 (httpMethod, url) -> httpMethod == HttpMethod.GET && isListRequest(httpMethod, url) == false,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.GET_BLOB, purpose)
-            ),
-            RequestStatsCollector.create(
-                AzureBlobStore::isListRequest,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.LIST_BLOBS, purpose)
-            ),
-            RequestStatsCollector.create(
-                AzureBlobStore::isPutBlockRequest,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.PUT_BLOCK, purpose)
+                Operation.GET_BLOB
             ),
-            RequestStatsCollector.create(
-                AzureBlobStore::isPutBlockListRequest,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.PUT_BLOCK_LIST, purpose)
-            ),
-            RequestStatsCollector.create(
+            new RequestMatcher(AzureBlobStore::isListRequest, Operation.LIST_BLOBS),
+            new RequestMatcher(AzureBlobStore::isPutBlockRequest, Operation.PUT_BLOCK),
+            new RequestMatcher(AzureBlobStore::isPutBlockListRequest, Operation.PUT_BLOCK_LIST),
+            new RequestMatcher(
                 // https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob#uri-parameters
                 // The only URI parameter allowed for put-blob operation is "timeout", but if a sas token is used,
                 // it's possible that the URI parameters contain additional parameters unrelated to the upload type.
                 (httpMethod, url) -> httpMethod == HttpMethod.PUT
                     && isPutBlockRequest(httpMethod, url) == false
                    && isPutBlockListRequest(httpMethod, url) == false,
-                purpose -> statsCollectors.onSuccessfulRequest(Operation.PUT_BLOB, purpose)
+                Operation.PUT_BLOB
             )
         );
 
-        this.statsConsumer = (purpose, httpMethod, url) -> {
+        this.requestMetricsHandler = (purpose, method, url, metrics) -> {
             try {
                 URI uri = url.toURI();
                 String path = uri.getPath() == null ? "" : uri.getPath();
"" : uri.getPath(); @@ -167,9 +163,9 @@ && isPutBlockListRequest(httpMethod, url) == false, return; } - for (RequestStatsCollector requestStatsCollector : requestStatsCollectors) { - if (requestStatsCollector.shouldConsumeRequestInfo(httpMethod, url)) { - requestStatsCollector.consumeHttpRequestInfo(purpose); + for (RequestMatcher requestMatcher : requestMatchers) { + if (requestMatcher.matches(method, url)) { + requestMetricsRecorder.onRequestComplete(requestMatcher.operation, purpose, metrics); return; } } @@ -665,12 +661,12 @@ private BlobServiceAsyncClient asyncClient(OperationPurpose purpose) { } private AzureBlobServiceClient getAzureBlobServiceClientClient(OperationPurpose purpose) { - return service.client(clientName, locationMode, purpose, statsConsumer); + return service.client(clientName, locationMode, purpose, requestMetricsHandler); } @Override public Map stats() { - return statsCollectors.statsMap(service.isStateless()); + return requestMetricsRecorder.statsMap(service.isStateless()); } // visible for testing @@ -691,26 +687,43 @@ public String getKey() { Operation(String key) { this.key = key; } + + public static Operation fromKey(String key) { + for (Operation operation : Operation.values()) { + if (operation.key.equals(key)) { + return operation; + } + } + throw new IllegalArgumentException("No matching key: " + key); + } } - private record StatsKey(Operation operation, OperationPurpose purpose) { + // visible for testing + record StatsKey(Operation operation, OperationPurpose purpose) { @Override public String toString() { return purpose.getKey() + "_" + operation.getKey(); } } - private static class StatsCollectors { - final Map collectors = new ConcurrentHashMap<>(); + // visible for testing + class RequestMetricsRecorder { + private final RepositoriesMetrics repositoriesMetrics; + final Map opsCounters = new ConcurrentHashMap<>(); + final Map> opsAttributes = new ConcurrentHashMap<>(); + + RequestMetricsRecorder(RepositoriesMetrics repositoriesMetrics) { + this.repositoriesMetrics = repositoriesMetrics; + } Map statsMap(boolean stateless) { if (stateless) { - return collectors.entrySet() + return opsCounters.entrySet() .stream() .collect(Collectors.toUnmodifiableMap(e -> e.getKey().toString(), e -> e.getValue().sum())); } else { Map normalisedStats = Arrays.stream(Operation.values()).collect(Collectors.toMap(Operation::getKey, o -> 0L)); - collectors.forEach( + opsCounters.forEach( (key, value) -> normalisedStats.compute( key.operation.getKey(), (k, current) -> Objects.requireNonNull(current) + value.sum() @@ -720,11 +733,50 @@ Map statsMap(boolean stateless) { } } - public void onSuccessfulRequest(Operation operation, OperationPurpose purpose) { - collectors.computeIfAbsent(new StatsKey(operation, purpose), k -> new LongAdder()).increment(); + public void onRequestComplete(Operation operation, OperationPurpose purpose, AzureClientProvider.RequestMetrics requestMetrics) { + final StatsKey statsKey = new StatsKey(operation, purpose); + final LongAdder counter = opsCounters.computeIfAbsent(statsKey, k -> new LongAdder()); + final Map attributes = opsAttributes.computeIfAbsent( + statsKey, + k -> RepositoriesMetrics.createAttributesMap(repositoryMetadata, purpose, operation.getKey()) + ); + + counter.add(1); + + // range not satisfied is not retried, so we count them by checking the final response + if (requestMetrics.getStatusCode() == RestStatus.REQUESTED_RANGE_NOT_SATISFIED.getStatus()) { + repositoriesMetrics.requestRangeNotSatisfiedExceptionCounter().incrementBy(1, 
+            repositoriesMetrics.operationCounter().incrementBy(1, attributes);
+            if (RestStatus.isSuccessful(requestMetrics.getStatusCode()) == false) {
+                repositoriesMetrics.unsuccessfulOperationCounter().incrementBy(1, attributes);
+            }
+
+            repositoriesMetrics.requestCounter().incrementBy(requestMetrics.getRequestCount(), attributes);
+            if (requestMetrics.getErrorCount() > 0) {
+                repositoriesMetrics.exceptionCounter().incrementBy(requestMetrics.getErrorCount(), attributes);
+                repositoriesMetrics.exceptionHistogram().record(requestMetrics.getErrorCount(), attributes);
+            }
+
+            if (requestMetrics.getThrottleCount() > 0) {
+                repositoriesMetrics.throttleCounter().incrementBy(requestMetrics.getThrottleCount(), attributes);
+                repositoriesMetrics.throttleHistogram().record(requestMetrics.getThrottleCount(), attributes);
+            }
+
+            // We use nanosecond precision, so a zero value indicates that no requests were executed
+            if (requestMetrics.getTotalRequestTimeNanos() > 0) {
+                repositoriesMetrics.httpRequestTimeInMillisHistogram()
+                    .record(TimeUnit.NANOSECONDS.toMillis(requestMetrics.getTotalRequestTimeNanos()), attributes);
+            }
         }
     }
 
+    // visible for testing
+    RequestMetricsRecorder getMetricsRecorder() {
+        return requestMetricsRecorder;
+    }
+
     private static class AzureInputStream extends InputStream {
         private final CancellableRateLimitedFluxIterator<ByteBuf> cancellableRateLimitedFluxIterator;
         private ByteBuf byteBuf;
@@ -846,26 +898,11 @@ private ByteBuf getNextByteBuf() throws IOException {
         }
     }
 
-    private static class RequestStatsCollector {
-        private final BiPredicate<HttpMethod, URL> filter;
-        private final Consumer<OperationPurpose> onHttpRequest;
-
-        private RequestStatsCollector(BiPredicate<HttpMethod, URL> filter, Consumer<OperationPurpose> onHttpRequest) {
-            this.filter = filter;
-            this.onHttpRequest = onHttpRequest;
-        }
-
-        static RequestStatsCollector create(BiPredicate<HttpMethod, URL> filter, Consumer<OperationPurpose> consumer) {
-            return new RequestStatsCollector(filter, consumer);
-        }
+    private record RequestMatcher(BiPredicate<HttpMethod, URL> filter, Operation operation) {
 
-        private boolean shouldConsumeRequestInfo(HttpMethod httpMethod, URL url) {
+        private boolean matches(HttpMethod httpMethod, URL url) {
             return filter.test(httpMethod, url);
         }
-
-        private void consumeHttpRequestInfo(OperationPurpose operationPurpose) {
-            onHttpRequest.accept(operationPurpose);
-        }
     }
 
     OptionalBytesReference getRegister(OperationPurpose purpose, String blobPath, String containerPath, String blobKey) {
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
index ae497ff159576..654742c980268 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
@@ -24,6 +24,7 @@
 import com.azure.core.http.HttpMethod;
 import com.azure.core.http.HttpPipelineCallContext;
 import com.azure.core.http.HttpPipelineNextPolicy;
+import com.azure.core.http.HttpPipelinePosition;
 import com.azure.core.http.HttpRequest;
 import com.azure.core.http.HttpResponse;
 import com.azure.core.http.ProxyOptions;
@@ -44,11 +45,13 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.repositories.azure.executors.PrivilegedExecutor;
 import org.elasticsearch.repositories.azure.executors.ReactorScheduledExecutorService;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.netty4.NettyAllocator;
 
 import java.net.URL;
 import java.time.Duration;
+import java.util.Optional;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ThreadFactory;
@@ -57,6 +60,8 @@
 import static org.elasticsearch.repositories.azure.AzureRepositoryPlugin.REPOSITORY_THREAD_POOL_NAME;
 
 class AzureClientProvider extends AbstractLifecycleComponent {
+    private static final Logger logger = LogManager.getLogger(AzureClientProvider.class);
+
     private static final TimeValue DEFAULT_CONNECTION_TIMEOUT = TimeValue.timeValueSeconds(30);
     private static final TimeValue DEFAULT_MAX_CONNECTION_IDLE_TIME = TimeValue.timeValueSeconds(60);
     private static final int DEFAULT_MAX_CONNECTIONS = 50;
@@ -160,7 +165,7 @@ AzureBlobServiceClient createClient(
         LocationMode locationMode,
         RequestRetryOptions retryOptions,
         ProxyOptions proxyOptions,
-        SuccessfulRequestHandler successfulRequestHandler,
+        RequestMetricsHandler requestMetricsHandler,
         OperationPurpose purpose
     ) {
         if (closed) {
@@ -189,8 +194,9 @@ AzureBlobServiceClient createClient(
             builder.credential(credentialBuilder.build());
         }
 
-        if (successfulRequestHandler != null) {
-            builder.addPolicy(new SuccessfulRequestTracker(purpose, successfulRequestHandler));
+        if (requestMetricsHandler != null) {
+            builder.addPolicy(new RequestMetricsTracker(purpose, requestMetricsHandler));
+            builder.addPolicy(RetryMetricsTracker.INSTANCE);
         }
 
         if (locationMode.isSecondary()) {
@@ -259,38 +265,135 @@ protected void doStop() {
     @Override
     protected void doClose() {}
 
-    private static final class SuccessfulRequestTracker implements HttpPipelinePolicy {
-        private static final Logger logger = LogManager.getLogger(SuccessfulRequestTracker.class);
+    static class RequestMetrics {
+        private volatile long totalRequestTimeNanos = 0;
+        private volatile int requestCount;
+        private volatile int errorCount;
+        private volatile int throttleCount;
+        private volatile int statusCode;
+
+        int getRequestCount() {
+            return requestCount;
+        }
+
+        int getErrorCount() {
+            return errorCount;
+        }
+
+        int getStatusCode() {
+            return statusCode;
+        }
+
+        int getThrottleCount() {
+            return throttleCount;
+        }
+
+        /**
+         * Total time spent executing requests to complete operation in nanoseconds
+         */
+        long getTotalRequestTimeNanos() {
+            return totalRequestTimeNanos;
+        }
+
+        @Override
+        public String toString() {
+            return "RequestMetrics{"
+                + "totalRequestTimeNanos="
+                + totalRequestTimeNanos
+                + ", requestCount="
+                + requestCount
+                + ", errorCount="
+                + errorCount
+                + ", throttleCount="
+                + throttleCount
+                + ", statusCode="
+                + statusCode
+                + '}';
+        }
+    }
+
+    private enum RetryMetricsTracker implements HttpPipelinePolicy {
+        INSTANCE;
+
+        @Override
+        public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
+            Optional<Object> metricsData = context.getData(RequestMetricsTracker.ES_REQUEST_METRICS_CONTEXT_KEY);
+            if (metricsData.isPresent() == false) {
+                assert false : "No metrics object associated with request " + context.getHttpRequest();
+                return next.process();
+            }
+            RequestMetrics metrics = (RequestMetrics) metricsData.get();
+            metrics.requestCount++;
+            long requestStartTimeNanos = System.nanoTime();
+            return next.process().doOnError(throwable -> {
+                metrics.totalRequestTimeNanos += System.nanoTime() - requestStartTimeNanos;
+                logger.debug("Detected error in RetryMetricsTracker", throwable);
+                metrics.errorCount++;
+            }).doOnSuccess(response -> {
+                metrics.totalRequestTimeNanos += System.nanoTime() - requestStartTimeNanos;
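+                // a non-2xx response still completes the call; count it as an error, and a 429 additionally as a throttle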
+                if (RestStatus.isSuccessful(response.getStatusCode()) == false) {
+                    metrics.errorCount++;
+                    // Azure always throttles with a 429 response, see
+                    // https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/request-limits-and-throttling#error-code
+                    if (response.getStatusCode() == RestStatus.TOO_MANY_REQUESTS.getStatus()) {
+                        metrics.throttleCount++;
+                    }
+                }
+            });
+        }
+
+        @Override
+        public HttpPipelinePosition getPipelinePosition() {
+            return HttpPipelinePosition.PER_RETRY;
+        }
+    }
+
+    private static final class RequestMetricsTracker implements HttpPipelinePolicy {
+        private static final String ES_REQUEST_METRICS_CONTEXT_KEY = "_es_azure_repo_request_stats";
+        private static final Logger logger = LogManager.getLogger(RequestMetricsTracker.class);
         private final OperationPurpose purpose;
-        private final SuccessfulRequestHandler onSuccessfulRequest;
+        private final RequestMetricsHandler requestMetricsHandler;
 
-        private SuccessfulRequestTracker(OperationPurpose purpose, SuccessfulRequestHandler onSuccessfulRequest) {
+        private RequestMetricsTracker(OperationPurpose purpose, RequestMetricsHandler requestMetricsHandler) {
             this.purpose = purpose;
-            this.onSuccessfulRequest = onSuccessfulRequest;
+            this.requestMetricsHandler = requestMetricsHandler;
         }
 
         @Override
         public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
-            return next.process().doOnSuccess(httpResponse -> trackSuccessfulRequest(context.getHttpRequest(), httpResponse));
+            final RequestMetrics requestMetrics = new RequestMetrics();
+            context.setData(ES_REQUEST_METRICS_CONTEXT_KEY, requestMetrics);
+            return next.process().doOnSuccess((httpResponse) -> {
+                requestMetrics.statusCode = httpResponse.getStatusCode();
+                trackCompletedRequest(context.getHttpRequest(), requestMetrics);
+            }).doOnError(throwable -> {
+                logger.debug("Detected error in RequestMetricsTracker", throwable);
+                trackCompletedRequest(context.getHttpRequest(), requestMetrics);
+            });
         }
 
-        private void trackSuccessfulRequest(HttpRequest httpRequest, HttpResponse httpResponse) {
+        private void trackCompletedRequest(HttpRequest httpRequest, RequestMetrics requestMetrics) {
             HttpMethod method = httpRequest.getHttpMethod();
-            if (httpResponse != null && method != null && httpResponse.getStatusCode() > 199 && httpResponse.getStatusCode() <= 299) {
+            if (method != null) {
                 try {
-                    onSuccessfulRequest.onSuccessfulRequest(purpose, method, httpRequest.getUrl());
+                    requestMetricsHandler.requestCompleted(purpose, method, httpRequest.getUrl(), requestMetrics);
                 } catch (Exception e) {
                     logger.warn("Unable to notify a successful request", e);
                 }
             }
         }
+
+        @Override
+        public HttpPipelinePosition getPipelinePosition() {
+            return HttpPipelinePosition.PER_CALL;
+        }
     }
 
     /**
-     * The {@link SuccessfulRequestTracker} calls this when a request completes successfully
+     * The {@link RequestMetricsTracker} calls this when a request completes
      */
-    interface SuccessfulRequestHandler {
+    interface RequestMetricsHandler {
 
-        void onSuccessfulRequest(OperationPurpose purpose, HttpMethod method, URL url);
+        void requestCompleted(OperationPurpose purpose, HttpMethod method, URL url, RequestMetrics metrics);
     }
 }
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
index aec148adf9aa8..80e662343baee 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepository.java
@@ -22,6 +22,7 @@
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.indices.recovery.RecoverySettings;
+import org.elasticsearch.repositories.RepositoriesMetrics;
 import org.elasticsearch.repositories.blobstore.MeteredBlobStoreRepository;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -91,6 +92,7 @@ public static final class Repository {
     private final ByteSizeValue chunkSize;
     private final AzureStorageService storageService;
     private final boolean readonly;
+    private final RepositoriesMetrics repositoriesMetrics;
 
     public AzureRepository(
         final RepositoryMetadata metadata,
@@ -98,7 +100,8 @@ public AzureRepository(
         final AzureStorageService storageService,
         final ClusterService clusterService,
         final BigArrays bigArrays,
-        final RecoverySettings recoverySettings
+        final RecoverySettings recoverySettings,
+        final RepositoriesMetrics repositoriesMetrics
     ) {
         super(
             metadata,
@@ -111,6 +114,7 @@ public AzureRepository(
         );
         this.chunkSize = Repository.CHUNK_SIZE_SETTING.get(metadata.settings());
         this.storageService = storageService;
+        this.repositoriesMetrics = repositoriesMetrics;
 
         // If the user explicitly did not define a readonly value, we set it by ourselves depending on the location mode setting.
         // For secondary_only setting, the repository should be read only
@@ -152,7 +156,7 @@ protected BlobStore getBlobStore() {
 
     @Override
     protected AzureBlobStore createBlobStore() {
-        final AzureBlobStore blobStore = new AzureBlobStore(metadata, storageService, bigArrays);
+        final AzureBlobStore blobStore = new AzureBlobStore(metadata, storageService, bigArrays, repositoriesMetrics);
 
         logger.debug(
             () -> format(
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepositoryPlugin.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepositoryPlugin.java
index c3cd5e78c5dbe..4556e63378fea 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepositoryPlugin.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureRepositoryPlugin.java
@@ -71,7 +71,15 @@ public Map<String, Repository.Factory> getRepositories(
         return Collections.singletonMap(AzureRepository.TYPE, metadata -> {
             AzureStorageService storageService = azureStoreService.get();
             assert storageService != null;
-            return new AzureRepository(metadata, namedXContentRegistry, storageService, clusterService, bigArrays, recoverySettings);
+            return new AzureRepository(
+                metadata,
+                namedXContentRegistry,
+                storageService,
+                clusterService,
+                bigArrays,
+                recoverySettings,
+                repositoriesMetrics
+            );
         });
     }
 
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
index c6e85e44d24dd..7373ed9485784 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
@@ -91,7 +91,7 @@ public AzureBlobServiceClient client(
         String clientName,
         LocationMode locationMode,
         OperationPurpose purpose,
-        AzureClientProvider.SuccessfulRequestHandler successfulRequestHandler
+        AzureClientProvider.RequestMetricsHandler requestMetricsHandler
     ) {
         final AzureStorageSettings azureStorageSettings = getClientSettings(clientName);
@@ -102,7 +102,7 @@ public AzureBlobServiceClient client(
             locationMode,
             retryOptions,
             proxyOptions,
-            successfulRequestHandler,
+            requestMetricsHandler,
             purpose
         );
     }
diff --git a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AbstractAzureServerTestCase.java b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AbstractAzureServerTestCase.java
index 1962bddd8fdb3..cb9facc061a28 100644
--- a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AbstractAzureServerTestCase.java
+++ b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AbstractAzureServerTestCase.java
@@ -29,6 +29,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.mocksocket.MockHttpServer;
+import org.elasticsearch.repositories.RepositoriesMetrics;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -168,7 +169,10 @@ int getMaxReadRetries(String clientName) {
             .build()
         );
 
-        return new AzureBlobContainer(BlobPath.EMPTY, new AzureBlobStore(repositoryMetadata, service, BigArrays.NON_RECYCLING_INSTANCE));
+        return new AzureBlobContainer(
+            BlobPath.EMPTY,
+            new AzureBlobStore(repositoryMetadata, service, BigArrays.NON_RECYCLING_INSTANCE, RepositoriesMetrics.NOOP)
+        );
     }
 
     protected static byte[] randomBlobContent() {
diff --git a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureClientProviderTests.java b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureClientProviderTests.java
index 7d82f2d5029f6..2699438de8ac6 100644
--- a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureClientProviderTests.java
+++ b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureClientProviderTests.java
@@ -26,7 +26,7 @@
 import java.util.concurrent.TimeUnit;
 
 public class AzureClientProviderTests extends ESTestCase {
-    private static final AzureClientProvider.SuccessfulRequestHandler EMPTY_CONSUMER = (purpose, method, url) -> {};
+    private static final AzureClientProvider.RequestMetricsHandler NOOP_HANDLER = (purpose, method, url, metrics) -> {};
 
     private ThreadPool threadPool;
     private AzureClientProvider azureClientProvider;
@@ -76,7 +76,7 @@ public void testCanCreateAClientWithSecondaryLocation() {
             locationMode,
             requestRetryOptions,
             null,
-            EMPTY_CONSUMER,
+            NOOP_HANDLER,
             randomFrom(OperationPurpose.values())
         );
     }
@@ -106,7 +106,7 @@ public void testCanNotCreateAClientWithSecondaryLocationWithoutAProperEndpoint()
                 locationMode,
                 requestRetryOptions,
                 null,
-                EMPTY_CONSUMER,
+                NOOP_HANDLER,
                 randomFrom(OperationPurpose.values())
             )
         );
diff --git a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureRepositorySettingsTests.java b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureRepositorySettingsTests.java
index 7037dd4eaf111..3afacb5b7426e 100644
--- a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureRepositorySettingsTests.java
+++ b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureRepositorySettingsTests.java
@@ -17,6 +17,7 @@
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.indices.recovery.RecoverySettings;
+import org.elasticsearch.repositories.RepositoriesMetrics;
 import org.elasticsearch.repositories.blobstore.BlobStoreTestUtil;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -40,7 +41,8 @@ private AzureRepository azureRepository(Settings settings) {
             mock(AzureStorageService.class),
             BlobStoreTestUtil.mockClusterService(),
             MockBigArrays.NON_RECYCLING_INSTANCE,
-            new RecoverySettings(settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS))
+            new RecoverySettings(settings, new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
+            RepositoriesMetrics.NOOP
         );
         assertThat(azureRepository.getBlobStore(), is(nullValue()));
         return azureRepository;
diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java
index bd5723b4dbcc4..3e6b7c356cb11 100644
--- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java
+++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java
@@ -34,6 +34,7 @@
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.repositories.RepositoriesMetrics;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.io.IOException;
@@ -144,16 +145,7 @@ class IgnoreNoResponseMetricsCollector extends RequestMetricCollector {
 
         private IgnoreNoResponseMetricsCollector(Operation operation, OperationPurpose purpose) {
             this.operation = operation;
-            this.attributes = Map.of(
-                "repo_type",
-                S3Repository.TYPE,
-                "repo_name",
-                repositoryMetadata.name(),
-                "operation",
-                operation.getKey(),
-                "purpose",
-                purpose.getKey()
-            );
+            this.attributes = RepositoriesMetrics.createAttributesMap(repositoryMetadata, purpose, operation.getKey());
         }
 
         @Override
diff --git a/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java b/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java
index 15f23f7511445..2598f943ea27c 100644
--- a/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java
+++ b/modules/rest-root/src/main/java/org/elasticsearch/rest/root/TransportMainAction.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.TransportAction;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
@@ -48,7 +49,7 @@ protected void doExecute(Task task, MainRequest request, ActionListener<MainResponse>
-        AtomicReference<MainResponse> responseRef = new AtomicReference<>();
-        action.doExecute(mock(Task.class), new MainRequest(), new ActionListener<>() {
-            @Override
-            public void onResponse(MainResponse mainResponse) {
-                responseRef.set(mainResponse);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                logger.error("unexpected error", e);
-            }
-        });
+        final AtomicBoolean listenerCalled = new AtomicBoolean();
+        new TransportMainAction(settings, transportService, mock(ActionFilters.class), clusterService).doExecute(
+            mock(Task.class),
+            new MainRequest(),
+            ActionTestUtils.assertNoFailureListener(mainResponse
-> { + assertNotNull(mainResponse); + assertEquals( + state.metadata().clusterUUIDCommitted() ? state.metadata().clusterUUID() : Metadata.UNKNOWN_CLUSTER_UUID, + mainResponse.getClusterUuid() + ); + assertFalse(listenerCalled.getAndSet(true)); + }) + ); - assertNotNull(responseRef.get()); + assertTrue(listenerCalled.get()); verify(clusterService, times(1)).state(); } } diff --git a/muted-tests.yml b/muted-tests.yml index 7764c0f8865d4..93893d7103afb 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -109,9 +109,6 @@ tests: - class: org.elasticsearch.xpack.ml.integration.MlJobIT method: testDeleteJobAsync issue: https://github.com/elastic/elasticsearch/issues/112212 -- class: org.elasticsearch.search.retriever.RankDocRetrieverBuilderIT - method: testRankDocsRetrieverWithCollapse - issue: https://github.com/elastic/elasticsearch/issues/112254 - class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT method: test {yaml=reference/rest-api/watcher/put-watch/line_120} issue: https://github.com/elastic/elasticsearch/issues/99517 @@ -336,15 +333,9 @@ tests: - class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT method: testPutE5Small_withPlatformAgnosticVariant issue: https://github.com/elastic/elasticsearch/issues/113983 -- class: org.elasticsearch.xpack.rank.rrf.RRFRankClientYamlTestSuiteIT - method: test {yaml=rrf/700_rrf_retriever_search_api_compatibility/rrf retriever with top-level collapse} - issue: https://github.com/elastic/elasticsearch/issues/114019 - class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT method: testPutE5WithTrainedModelAndInference issue: https://github.com/elastic/elasticsearch/issues/114023 -- class: org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilderIT - method: testRRFWithCollapse - issue: https://github.com/elastic/elasticsearch/issues/114074 - class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT method: testPutE5Small_withPlatformSpecificVariant issue: https://github.com/elastic/elasticsearch/issues/113950 @@ -354,9 +345,6 @@ tests: - class: org.elasticsearch.xpack.inference.InferenceCrudIT method: testGet issue: https://github.com/elastic/elasticsearch/issues/114135 -- class: org.elasticsearch.action.bulk.IncrementalBulkIT - method: testIncrementalBulkHighWatermarkBackOff - issue: https://github.com/elastic/elasticsearch/issues/114073 - class: org.elasticsearch.xpack.esql.expression.function.aggregate.AvgTests method: "testFold {TestCase= #7}" issue: https://github.com/elastic/elasticsearch/issues/114175 @@ -371,15 +359,17 @@ tests: issue: https://github.com/elastic/elasticsearch/issues/114187 - class: org.elasticsearch.xpack.esql.action.EsqlActionBreakerIT issue: https://github.com/elastic/elasticsearch/issues/114194 -- class: org.elasticsearch.index.query.SpanGapQueryBuilderTests - method: testToQuery - issue: https://github.com/elastic/elasticsearch/issues/114218 - class: org.elasticsearch.xpack.ilm.ExplainLifecycleIT method: testStepInfoPreservedOnAutoRetry issue: https://github.com/elastic/elasticsearch/issues/114220 - class: org.elasticsearch.xpack.inference.services.openai.OpenAiServiceTests method: testInfer_StreamRequest issue: https://github.com/elastic/elasticsearch/issues/114232 +- class: org.elasticsearch.logsdb.datageneration.DataGeneratorTests + method: testDataGeneratorProducesValidMappingAndDocument + issue: https://github.com/elastic/elasticsearch/issues/114188 +- class: org.elasticsearch.ingest.geoip.IpinfoIpDataLookupsTests + issue: https://github.com/elastic/elasticsearch/issues/114266 # Examples: # diff --git 
a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/SnapshotBasedRecoveryIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/SnapshotBasedRecoveryIT.java index 6f4c37f9e56a7..3343a683bbd11 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/SnapshotBasedRecoveryIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/SnapshotBasedRecoveryIT.java @@ -203,7 +203,7 @@ private void cancelShard(String indexName, int shard, String nodeName) throws IO } builder.endObject(); - Request request = new Request(HttpPost.METHOD_NAME, "/_cluster/reroute?pretty&metric=none"); + Request request = new Request(HttpPost.METHOD_NAME, "/_cluster/reroute?pretty"); request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); logger.info("--> Relocated primary to an older version {}", EntityUtils.toString(response.getEntity())); diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json new file mode 100644 index 0000000000000..32b4b2f311837 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/inference.stream_inference.json @@ -0,0 +1,49 @@ +{ + "inference.stream_inference":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/post-stream-inference-api.html", + "description":"Perform streaming inference" + }, + "stability":"experimental", + "visibility":"public", + "headers":{ + "accept": [ "text/event-stream"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_inference/{inference_id}/_stream", + "methods":[ + "POST" + ], + "parts":{ + "inference_id":{ + "type":"string", + "description":"The inference Id" + } + } + }, + { + "path":"/_inference/{task_type}/{inference_id}/_stream", + "methods":[ + "POST" + ], + "parts":{ + "task_type":{ + "type":"string", + "description":"The task type" + }, + "inference_id":{ + "type":"string", + "description":"The inference Id" + } + } + } + ] + }, + "body":{ + "description":"The inference payload" + } + } +} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/10_basic.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/10_basic.yml index f7378cc01dc0a..d73efed1f7571 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/10_basic.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/10_basic.yml @@ -1,5 +1,8 @@ --- "Basic sanity check": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature + - do: - cluster.reroute: - metric: none + cluster.reroute: {} diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/11_explain.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/11_explain.yml index 7543c96b232dc..3584ce9666705 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/11_explain.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/11_explain.yml @@ -13,12 +13,14 @@ setup: --- "Explain API with empty command list": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature - do: cluster.reroute: explain: true dry_run: true - metric: none body: commands: [] @@ -26,6 +28,10 
@@ setup: --- "Explain API for non-existent node & shard": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature + - skip: features: [arbitrary_key] @@ -39,7 +45,6 @@ setup: cluster.reroute: explain: true dry_run: true - metric: none body: commands: - cancel: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/20_deprecated_response_filtering.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/20_deprecated_response_filtering.yml index 3bc27f53ad679..9775fbd1bea83 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/20_deprecated_response_filtering.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.reroute/20_deprecated_response_filtering.yml @@ -1,21 +1,45 @@ --- -"Do not return metadata by default and produce deprecation warning": +"Do not return metadata by default and emit no warning": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature + + - do: + cluster.reroute: {} + - is_false: state + +--- +"Do not return metadata with ?metric=none and produce deprecation warning": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature + - skip: features: [ "allowed_warnings" ] + - do: - cluster.reroute: {} + cluster.reroute: + metric: none allowed_warnings: - - "The [state] field in the response to the reroute API is deprecated and will be removed in a future version. Specify ?metric=none to adopt the future behaviour." - - is_false: state.metadata + - >- + the [?metric] query parameter to the [POST /_cluster/reroute] API has no effect; + its use will be forbidden in a future version + - is_false: state + --- -"If requested return metadata and produce deprecation warning": +"Do not return metadata with ?metric=metadata and produce deprecation warning": + - requires: + cluster_features: ["cluster.reroute.ignores_metric_param"] + reason: requires this feature + - skip: features: [ "allowed_warnings" ] + - do: cluster.reroute: metric: metadata allowed_warnings: - - "The [state] field in the response to the reroute API is deprecated and will be removed in a future version. Specify ?metric=none to adopt the future behaviour." - - is_true: state.metadata - - is_false: state.nodes - + - >- + the [?metric] query parameter to the [POST /_cluster/reroute] API has no effect; + its use will be forbidden in a future version + - is_false: state diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java deleted file mode 100644 index b78448bfd873f..0000000000000 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RankDocRetrieverBuilderIT.java +++ /dev/null @@ -1,756 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". 
- */ - -package org.elasticsearch.search.retriever; - -import org.apache.lucene.search.TotalHits; -import org.apache.lucene.search.join.ScoreMode; -import org.apache.lucene.util.SetOnce; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.MultiSearchRequest; -import org.elasticsearch.action.search.MultiSearchResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchRequestBuilder; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.search.TransportMultiSearchAction; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.Maps; -import org.elasticsearch.index.query.InnerHitBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.MockSearchService; -import org.elasticsearch.search.aggregations.bucket.terms.Terms; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.builder.PointInTimeBuilder; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.collapse.CollapseBuilder; -import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.NestedSortBuilder; -import org.elasticsearch.search.sort.ScoreSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; -import org.elasticsearch.search.sort.SortBuilder; -import org.elasticsearch.search.sort.SortOrder; -import org.elasticsearch.test.ESIntegTestCase; -import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentType; -import org.junit.Before; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; -import static org.hamcrest.Matchers.equalTo; - -public class RankDocRetrieverBuilderIT extends ESIntegTestCase { - - @Override - protected Collection<Class<? extends Plugin>> nodePlugins() { - return List.of(MockSearchService.TestPlugin.class); - } - - public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} - - private static String INDEX = "test_index"; - private static final String ID_FIELD = "_id"; - private static final String DOC_FIELD = "doc"; - private static final String TEXT_FIELD = "text"; - private static final String VECTOR_FIELD = "vector"; - private static final String TOPIC_FIELD = "topic"; - private static final String LAST_30D_FIELD = "views.last30d"; - private static final String ALL_TIME_FIELD = "views.all"; - - @Before - public void setup() throws Exception { - String mapping = """ - { - "properties": { - "vector": { - "type": "dense_vector", - "dims": 3, - "element_type": "float", - "index": true, - "similarity": "l2_norm", - "index_options": { - "type": "hnsw" - } - }, - "text": { - "type": "text" - }, - "doc": { - "type": "keyword" - }, - "topic": { - "type": "keyword" - }, - "views": { - "type": "nested", - "properties": { - "last30d": { - "type": "integer" - }, - "all":
{ - "type": "integer" - } - } - } - } - } - """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); - admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); - indexDoc( - INDEX, - "doc_1", - DOC_FIELD, - "doc_1", - TOPIC_FIELD, - "technology", - TEXT_FIELD, - "the quick brown fox jumps over the lazy dog", - LAST_30D_FIELD, - 100 - ); - indexDoc( - INDEX, - "doc_2", - DOC_FIELD, - "doc_2", - TOPIC_FIELD, - "astronomy", - TEXT_FIELD, - "you know, for Search!", - VECTOR_FIELD, - new float[] { 1.0f, 2.0f, 3.0f }, - LAST_30D_FIELD, - 3 - ); - indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 6.0f, 6.0f, 6.0f }); - indexDoc( - INDEX, - "doc_4", - DOC_FIELD, - "doc_4", - TOPIC_FIELD, - "technology", - TEXT_FIELD, - "aardvark is a really awesome animal, but not very quick", - ALL_TIME_FIELD, - 100, - LAST_30D_FIELD, - 40 - ); - indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); - indexDoc( - INDEX, - "doc_6", - DOC_FIELD, - "doc_6", - TEXT_FIELD, - "quick quick quick quick search", - VECTOR_FIELD, - new float[] { 10.0f, 30.0f, 100.0f }, - LAST_30D_FIELD, - 15 - ); - indexDoc( - INDEX, - "doc_7", - DOC_FIELD, - "doc_7", - TOPIC_FIELD, - "biology", - TEXT_FIELD, - "dog", - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - ALL_TIME_FIELD, - 1000 - ); - refresh(INDEX); - } - - public void testRankDocsRetrieverBasicWithPagination() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 and with pagination, we'd just omit the first result - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - // include some pagination as well - source.from(1); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, 
equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverWithAggs() { - // same as above, but we only want to bring back the top result from each subsearch - // so that would be 1, 2, and 7 - // and final rank would be (based on score): 2, 1, 7 - // aggs should still account for the same docs as the testRankDocsRetriever test, i.e. all but doc_5 - final int rankWindowSize = 1; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.size(1); - source.aggregation(new TermsAggregationBuilder("topic").field(TOPIC_FIELD)); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getHits().length, equalTo(1)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); - assertNotNull(resp.getAggregations()); - assertNotNull(resp.getAggregations().get("topic")); - Terms terms = resp.getAggregations().get("topic"); - // doc_3 is not part of the final aggs computation as it is only retrieved through the knn retriever - // and is outside of the rank window - assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(2L)); - assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); - assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); - }); - } - - public void testRankDocsRetrieverWithCollapse() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - // with collapsing on topic field we would have 6, 2, 1, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.collapse( - new CollapseBuilder(TOPIC_FIELD).setInnerHits( - new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) - ) - ); - source.fetchField(TOPIC_FIELD); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getHits().length, equalTo(4)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(1).field(TOPIC_FIELD).getValue().toString(), equalTo("astronomy")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).field(TOPIC_FIELD).getValue().toString(), equalTo("technology")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(3).field(TOPIC_FIELD).getValue().toString(), equalTo("biology")); - }); - } - - public void testRankDocsRetrieverWithNestedCollapseAndAggs() { - final int rankWindowSize = 10; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1 and 6 as doc_4 is collapsed to doc_1 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - standard0.collapseBuilder = new CollapseBuilder(TOPIC_FIELD).setInnerHits( - new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) - ); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.aggregation(new TermsAggregationBuilder("topic").field(TOPIC_FIELD)); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertNotNull(resp.getAggregations()); - assertNotNull(resp.getAggregations().get("topic")); - Terms terms = resp.getAggregations().get("topic"); - // doc_3 is not part of the final aggs computation as it is only retrieved through the knn retriever - // and is outside of the rank window - assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); - assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); - assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); - }); - } - - public void testRankDocsRetrieverWithNestedQuery() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gt(10L), ScoreMode.Avg); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, and then the doc (we're forcing 
1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 3, 4, 7 - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ) - ); - source.fetchField(TOPIC_FIELD); - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverMultipleCompoundRetrievers() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, and 6 - standard0.queryBuilder = QueryBuilders.boolQuery() - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(9L)) - .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(8L)); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 2 and 6 due to prefilter - standard1.queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(ID_FIELD, "doc_2", "doc_3", "doc_6")).boost(20L); - standard1.preFilterQueryBuilders.add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 7, 2, 3, and 6 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder( - VECTOR_FIELD, - new float[] { 3.0f, 3.0f, 3.0f }, - null, - 10, - 100, - null - ); - // the compound retriever here produces a score for a doc based on the percentage of the queries that it was matched on and - // resolves ties based on actual score, rank, and then the doc (we're forcing 1 shard for consistent results) - // so ideal rank would be: 6, 2, 1, 4, 7, 3 - CompoundRetrieverWithRankDocs compoundRetriever1 = new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList( - new RetrieverSource(standard0, null), - new RetrieverSource(standard1, null), - new RetrieverSource(knnRetrieverBuilder, null) - ) - ); - // simple standard retriever that would have the doc_4 as its first (and only) result - StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder(); - standard2.queryBuilder = QueryBuilders.queryStringQuery("aardvark").defaultField(TEXT_FIELD); - - // combining the two retrievers would bring doc_4 at the top as it would be the only one present in both doc sets - // the rest of the docs would be sorted based on their ranks as they have the same score (1/2) - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList(new RetrieverSource(compoundRetriever1, null), new RetrieverSource(standard2, null)) - ) - ); - - SearchRequestBuilder req = 
client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(6L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_7")); - }); - } - - public void testRankDocsRetrieverDifferentNestedSorting() { - final int rankWindowSize = 100; - SearchSourceBuilder source = new SearchSourceBuilder(); - StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(); - // this one retrieves docs 1, 4, 6, 2 - standard0.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gt(0), ScoreMode.Avg); - standard0.sortBuilders = List.of( - new FieldSortBuilder(LAST_30D_FIELD).setNestedSort(new NestedSortBuilder("views")).order(SortOrder.DESC) - ); - StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder(); - // this one retrieves docs 4, 7 - standard1.queryBuilder = QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(ALL_TIME_FIELD).gt(0), ScoreMode.Avg); - standard1.sortBuilders = List.of( - new FieldSortBuilder(ALL_TIME_FIELD).setNestedSort(new NestedSortBuilder("views")).order(SortOrder.ASC) - ); - - source.retriever( - new CompoundRetrieverWithRankDocs( - rankWindowSize, - Arrays.asList(new RetrieverSource(standard0, null), new RetrieverSource(standard1, null)) - ) - ); - - SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); - ElasticsearchAssertions.assertResponse(req, resp -> { - assertNull(resp.pointInTimeId()); - assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); - assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); - assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_7")); - }); - } - - class CompoundRetrieverWithRankDocs extends RetrieverBuilder { - - private final List<RetrieverSource> sources; - private final int rankWindowSize; - - private CompoundRetrieverWithRankDocs(int rankWindowSize, List<RetrieverSource> sources) { - this.rankWindowSize = rankWindowSize; - this.sources = Collections.unmodifiableList(sources); - } - - @Override - public boolean isCompound() { - return true; - } - - @Override - public QueryBuilder topDocsQuery() { - throw new UnsupportedOperationException("should not be called"); - } - - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (ctx.getPointInTimeBuilder() == null) { - throw new IllegalStateException("PIT is required"); - } - - // Rewrite prefilters - boolean hasChanged = false; - var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; - - // Rewrite retriever sources - List<RetrieverSource> newRetrievers = new ArrayList<>(); - for (var entry : sources) { - RetrieverBuilder newRetriever =
entry.retriever.rewrite(ctx); - if (newRetriever != entry.retriever) { - newRetrievers.add(new RetrieverSource(newRetriever, null)); - hasChanged |= newRetriever != entry.retriever; - } else if (newRetriever == entry.retriever) { - var sourceBuilder = entry.source != null - ? entry.source - : createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever); - var rewrittenSource = sourceBuilder.rewrite(ctx); - newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource)); - hasChanged |= rewrittenSource != entry.source; - } - } - if (hasChanged) { - return new CompoundRetrieverWithRankDocs(rankWindowSize, newRetrievers); - } - - // execute searches - final SetOnce<RankDoc[]> results = new SetOnce<>(); - final MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - for (var entry : sources) { - SearchRequest searchRequest = new SearchRequest().source(entry.source); - // The can match phase can reorder shards, so we disable it to ensure the stable ordering - searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); - multiSearchRequest.add(searchRequest); - } - ctx.registerAsyncAction((client, listener) -> { - client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<>() { - @Override - public void onResponse(MultiSearchResponse items) { - List<RankDoc[]> topDocs = new ArrayList<>(); - for (int i = 0; i < items.getResponses().length; i++) { - var item = items.getResponses()[i]; - var rankDocs = getRankDocs(item.getResponse()); - sources.get(i).retriever().setRankDocs(rankDocs); - topDocs.add(rankDocs); - } - results.set(combineResults(topDocs)); - listener.onResponse(null); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - }); - - return new RankDocsRetrieverBuilder( - rankWindowSize, - newRetrievers.stream().map(s -> s.retriever).toList(), - results::get, - newPreFilters - ); - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - throw new UnsupportedOperationException("should not be called"); - } - - @Override - public String getName() { - return "compound_retriever"; - } - - @Override - protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - - } - - @Override - protected boolean doEquals(Object o) { - return false; - } - - @Override - protected int doHashCode() { - return 0; - } - - private RankDoc[] getRankDocs(SearchResponse searchResponse) { - assert searchResponse != null; - int size = Math.min(rankWindowSize, searchResponse.getHits().getHits().length); - RankDoc[] docs = new RankDoc[size]; - for (int i = 0; i < size; i++) { - var hit = searchResponse.getHits().getAt(i); - long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; - int doc = ShardDocSortField.decodeDoc(sortValue); - int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); - docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); - docs[i].rank = i + 1; - } - return docs; - } - - record RankDocAndHitRatio(RankDoc rankDoc, float hitRatio) {} - - /** - * Combines the provided {@code rankResults} to return the final top documents.
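- * For example, with three sub-retrievers each match contributes a step of 1/3, so a document returned by two of them ends up with a hitRatio of 2/3 and outranks any document returned by only one regardless of per-query scores; score, rank, and finally the smaller doc id only come into play as tiebreakers.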
- */ - public RankDoc[] combineResults(List<RankDoc[]> rankResults) { - int totalQueries = rankResults.size(); - final float step = 1.0f / totalQueries; - Map<RankDoc.RankKey, RankDocAndHitRatio> docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - for (var rankResult : rankResults) { - for (RankDoc scoreDoc : rankResult) { - docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> { - if (value == null) { - RankDoc res = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); - res.rank = scoreDoc.rank; - return new RankDocAndHitRatio(res, step); - } else { - RankDoc res = new RankDoc(scoreDoc.doc, Math.max(scoreDoc.score, value.rankDoc.score), scoreDoc.shardIndex); - res.rank = Math.min(scoreDoc.rank, value.rankDoc.rank); - return new RankDocAndHitRatio(res, value.hitRatio + step); - } - }); - } - } - // sort the results based on hit ratio, then score, then rank, with the smaller doc id as the final tiebreaker - RankDocAndHitRatio[] sortedResults = docsToRankResults.values().toArray(RankDocAndHitRatio[]::new); - Arrays.sort(sortedResults, (RankDocAndHitRatio doc1, RankDocAndHitRatio doc2) -> { - if (doc1.hitRatio != doc2.hitRatio) { - return doc1.hitRatio < doc2.hitRatio ? 1 : -1; - } - if (false == (Float.isNaN(doc1.rankDoc.score) || Float.isNaN(doc2.rankDoc.score)) - && (doc1.rankDoc.score != doc2.rankDoc.score)) { - return doc1.rankDoc.score < doc2.rankDoc.score ? 1 : -1; - } - if (doc1.rankDoc.rank != doc2.rankDoc.rank) { - return doc1.rankDoc.rank < doc2.rankDoc.rank ? -1 : 1; - } - return doc1.rankDoc.doc < doc2.rankDoc.doc ? -1 : 1; - }); - // trim the results if needed, otherwise each shard will always return `rank_window_size` results. - // pagination and all else will happen on the coordinator when combining the shard responses - RankDoc[] topResults = new RankDoc[Math.min(rankWindowSize, sortedResults.length)]; - for (int rank = 0; rank < topResults.length; ++rank) { - topResults[rank] = sortedResults[rank].rankDoc; - topResults[rank].rank = rank + 1; - topResults[rank].score = sortedResults[rank].hitRatio; - } - return topResults; - } - } - - private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit).trackTotalHits(false).size(100); - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, false); - - // Record the shard id in the sort result - List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ?
new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); - if (sortBuilders.isEmpty()) { - sortBuilders.add(new ScoreSortBuilder()); - } - sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); - sourceBuilder.sort(sortBuilders); - return sourceBuilder; - } -} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 56672957dd571..11965abf1dcd2 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -429,6 +429,7 @@ org.elasticsearch.indices.IndicesFeatures, org.elasticsearch.repositories.RepositoriesFeatures, org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures, + org.elasticsearch.rest.action.admin.cluster.ClusterRerouteFeatures, org.elasticsearch.index.mapper.MapperFeatures, org.elasticsearch.ingest.IngestGeoIpFeatures, org.elasticsearch.search.SearchFeatures, diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f6e4649aa4807..2095ba47ee377 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -235,6 +235,8 @@ static TransportVersion def(int id) { public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0); public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0); public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0); + public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0); + public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponse.java index c10b9dc5c9a01..7b344a4c25a1b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponse.java @@ -21,7 +21,6 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentObject; import org.elasticsearch.core.RestApiVersion; import org.elasticsearch.core.UpdateForV10; -import org.elasticsearch.core.UpdateForV9; import org.elasticsearch.rest.action.search.RestSearchAction; import org.elasticsearch.xcontent.ToXContent; @@ -43,7 +42,6 @@ public class ClusterRerouteResponse extends ActionResponse implements IsAcknowle /** * To be removed when REST compatibility with {@link org.elasticsearch.Version#V_8_6_0} / {@link RestApiVersion#V_8} no longer needed */ - @UpdateForV9(owner = UpdateForV9.Owner.DISTRIBUTED_COORDINATION) // to remove from the v9 API only @UpdateForV10(owner = UpdateForV10.Owner.DISTRIBUTED_COORDINATION) // to remove entirely private final ClusterState state; private final RoutingExplanations explanations; diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java index 7857e9a22e9b9..cb667400240f0 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java @@ -23,7 +23,6 @@ import 
org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.indices.IndicesService; import org.elasticsearch.injection.guice.Inject; @@ -120,27 +119,18 @@ public void onPrimaryOperationComplete( ActionListener<Void> listener ) { assert replicaRequest.primaryRefreshResult.refreshed() : "primary has not refreshed"; - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get( - clusterService.state().metadata().index(indexShardRoutingTable.shardId().getIndex()).getSettings() + UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest( + indexShardRoutingTable, + replicaRequest.primaryRefreshResult.primaryTerm(), + replicaRequest.primaryRefreshResult.generation(), + false + ); + transportService.sendRequest( + transportService.getLocalNode(), + TransportUnpromotableShardRefreshAction.NAME, + unpromotableReplicaRequest, + new ActionListenerResponseHandler<>(listener.safeMap(r -> null), in -> ActionResponse.Empty.INSTANCE, refreshExecutor) ); - - // Indices marked with fast refresh do not rely on refreshing the unpromotables - if (fastRefresh) { - listener.onResponse(null); - } else { - UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest( - indexShardRoutingTable, - replicaRequest.primaryRefreshResult.primaryTerm(), - replicaRequest.primaryRefreshResult.generation(), - false - ); - transportService.sendRequest( - transportService.getLocalNode(), - TransportUnpromotableShardRefreshAction.NAME, - unpromotableReplicaRequest, - new ActionListenerResponseHandler<>(listener.safeMap(r -> null), in -> ActionResponse.Empty.INSTANCE, refreshExecutor) - ); - } } } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java index 6c24ec2d17604..f91a983d47885 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java @@ -24,6 +24,9 @@ import java.util.List; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; +import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; + public class TransportUnpromotableShardRefreshAction extends TransportBroadcastUnpromotableAction< UnpromotableShardRefreshRequest, ActionResponse.Empty> { @@ -73,6 +76,18 @@ protected void unpromotableShardOperation( return; } + // During an upgrade to FAST_REFRESH_RCO, we expect search shards to be upgraded before the primary. Thus, + // when the primary is upgraded and starts to deliver unpromotable refreshes, we expect the search shards to already be upgraded. + // Note that the fast refresh setting is final.
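+ // (For example, an unpromotable refresh arriving at a fast refresh search shard whose node is still on a transport version + // before FAST_REFRESH_RCO would mean a primary started delivering these refreshes before the search shards were upgraded, + // which the assertion below rejects.)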
+ // TODO: remove assertion (ES-9563) + assert INDEX_FAST_REFRESH_SETTING.get(shard.indexSettings().getSettings()) == false + || transportService.getLocalNodeConnection().getTransportVersion().onOrAfter(FAST_REFRESH_RCO) + : "attempted to refresh a fast refresh search shard " + + shard + + " on transport version " + + transportService.getLocalNodeConnection().getTransportVersion() + + " (before FAST_REFRESH_RCO)"; + ActionListener.run(responseListener, listener -> { shard.waitForPrimaryTermAndGeneration( request.getPrimaryTerm(), diff --git a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java index d5ad3aa2d29a1..58ffe25e08e49 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/IncrementalBulkService.java @@ -194,7 +194,7 @@ public void lastItems(List<DocWriteRequest<?>> items, Releasable releasable, Act releasables.clear(); // We do not need to set this back to false as this will be the last request. bulkInProgress = true; - client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() { + client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() { private final boolean isFirstRequest = incrementalRequestSubmitted == false; diff --git a/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java b/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java index 189aa1c95d865..99eac250641ae 100644 --- a/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java +++ b/server/src/main/java/org/elasticsearch/action/get/TransportGetAction.java @@ -125,11 +125,10 @@ protected void asyncShardOperation(GetRequest request, ShardId shardId, ActionLi IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); if (indexShard.routingEntry().isPromotableToPrimary() == false) { - assert indexShard.indexSettings().isFastRefresh() == false - : "a search shard should not receive a TransportGetAction for an index with fast refresh"; handleGetOnUnpromotableShard(request, indexShard, listener); return; } + // TODO: adapt assertion to assert only that it is not stateless (ES-9563) assert DiscoveryNode.isStateless(clusterService.getSettings()) == false || indexShard.indexSettings().isFastRefresh() : "in Stateless a promotable to primary shard can receive a TransportGetAction only if an index has the fast refresh setting"; if (request.realtime()) { // we are not tied to a refresh cycle here anyway diff --git a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java index 8d5760307c3fe..633e7ef6793ab 100644 --- a/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java +++ b/server/src/main/java/org/elasticsearch/action/get/TransportShardMultiGetAction.java @@ -124,11 +124,10 @@ protected void asyncShardOperation(MultiGetShardRequest request, ShardId shardId IndexService indexService = indicesService.indexServiceSafe(shardId.getIndex()); IndexShard indexShard = indexService.getShard(shardId.id()); if (indexShard.routingEntry().isPromotableToPrimary() == false) { - assert indexShard.indexSettings().isFastRefresh() == false - : "a search shard should not receive a TransportShardMultiGetAction for an index with fast refresh";
handleMultiGetOnUnpromotableShard(request, indexShard, listener); return; } + // TODO: adapt assertion to assert only that it is not stateless (ES-9563) assert DiscoveryNode.isStateless(clusterService.getSettings()) == false || indexShard.indexSettings().isFastRefresh() : "in Stateless a promotable to primary shard can receive a TransportShardMultiGetAction only if an index has " + "the fast refresh setting"; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index a6acb3ee2a52e..1c4eb1c191370 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -36,6 +36,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.SearchSortValues; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -51,6 +52,7 @@ import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.search.suggest.Suggest.Suggestion; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; @@ -464,6 +466,13 @@ private static SearchHits getHits( assert shardDoc instanceof RankDoc; searchHit.setRank(((RankDoc) shardDoc).rank); searchHit.score(shardDoc.score); + long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc); + searchHit.sortValues( + new SearchSortValues( + new Object[] { shardDoc.score, shardAndDoc }, + new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW } + ) + ); } else if (sortedTopDocs.isSortedByField) { FieldDoc fieldDoc = (FieldDoc) shardDoc; searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats); diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java b/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java index 683c3589c893d..7414aeeb2c405 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.translog.Translog; @@ -53,9 +52,7 @@ public void refreshShard( case WAIT_UNTIL -> waitUntil(indexShard, location, new ActionListener<>() { @Override public void onResponse(Boolean forced) { - // Fast refresh indices do not depend on the unpromotables being refreshed - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get(indexShard.indexSettings().getSettings()); - if (location != null && (indexShard.routingEntry().isSearchable() == false && fastRefresh == false)) { + if (location != null && 
indexShard.routingEntry().isSearchable() == false) { refreshUnpromotables(indexShard, location, listener, forced, postWriteRefreshTimeout); } else { listener.onResponse(forced); @@ -68,9 +65,7 @@ public void onFailure(Exception e) { } }); case IMMEDIATE -> immediate(indexShard, listener.delegateFailureAndWrap((l, r) -> { - // Fast refresh indices do not depend on the unpromotables being refreshed - boolean fastRefresh = IndexSettings.INDEX_FAST_REFRESH_SETTING.get(indexShard.indexSettings().getSettings()); - if (indexShard.getReplicationGroup().getRoutingTable().unpromotableShards().size() > 0 && fastRefresh == false) { + if (indexShard.getReplicationGroup().getRoutingTable().unpromotableShards().size() > 0) { sendUnpromotableRequests(indexShard, r.generation(), true, l, postWriteRefreshTimeout); } else { l.onResponse(true); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index 566571d82c8ab..0756080c16d00 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -695,6 +695,11 @@ public long version() { return this.version; } + /** + * @return A UUID which identifies this cluster. Nodes record on disk the UUID of the first cluster they join, and will then refuse + * to join clusters with different UUIDs. Note that when the cluster is forming for the first time this value may not yet be committed, + * and therefore it may change. Check {@link #clusterUUIDCommitted()} to verify that the value is committed if needed. + */ public String clusterUUID() { return this.clusterUUID; } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java b/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java index f7812d284f2af..9120e25b443d7 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/OperationRouting.java @@ -32,6 +32,7 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; public class OperationRouting { @@ -305,8 +306,14 @@ public ShardId shardId(ClusterState clusterState, String index, String id, @Null } public static boolean canSearchShard(ShardRouting shardRouting, ClusterState clusterState) { + // TODO: remove if and always return isSearchable (ES-9563) if (INDEX_FAST_REFRESH_SETTING.get(clusterState.metadata().index(shardRouting.index()).getSettings())) { - return shardRouting.isPromotableToPrimary(); + // Until the whole cluster is upgraded, we send searches/gets to the primary (even if it has been upgraded) to execute locally.
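+ // In other words: reads on a fast refresh index go to the search shard only once every node is on FAST_REFRESH_RCO or later; + // in a mixed-version cluster they keep going to the promotable (primary) shard, and indices without the setting always use the + // search shard.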
+ if (clusterState.getMinTransportVersion().onOrAfter(FAST_REFRESH_RCO)) { + return shardRouting.isSearchable(); + } else { + return shardRouting.isPromotableToPrimary(); + } } else { return shardRouting.isSearchable(); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java index b20cd3ecaf992..a55522ff14c83 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/MaxRetryAllocationDecider.java @@ -35,7 +35,7 @@ public class MaxRetryAllocationDecider extends AllocationDecider { Setting.Property.NotCopyableOnResize ); - private static final String RETRY_FAILED_API = "POST /_cluster/reroute?retry_failed&metric=none"; + private static final String RETRY_FAILED_API = "POST /_cluster/reroute?retry_failed"; public static final String NAME = "max_retry"; diff --git a/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java b/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java index c19e3ca353569..3b37afc3b297b 100644 --- a/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java +++ b/server/src/main/java/org/elasticsearch/index/cache/bitset/BitsetFilterCache.java @@ -105,7 +105,7 @@ static boolean shouldLoadRandomAccessFiltersEagerly(IndexSettings settings) { boolean loadFiltersEagerlySetting = settings.getValue(INDEX_LOAD_RANDOM_ACCESS_FILTERS_EAGERLY_SETTING); boolean isStateless = DiscoveryNode.isStateless(settings.getNodeSettings()); if (isStateless) { - return DiscoveryNode.hasRole(settings.getNodeSettings(), DiscoveryNodeRole.INDEX_ROLE) + return DiscoveryNode.hasRole(settings.getNodeSettings(), DiscoveryNodeRole.SEARCH_ROLE) && loadFiltersEagerlySetting && INDEX_FAST_REFRESH_SETTING.get(settings.getSettings()); } else { diff --git a/server/src/main/java/org/elasticsearch/repositories/RepositoriesMetrics.java b/server/src/main/java/org/elasticsearch/repositories/RepositoriesMetrics.java index cce3c764fe7a4..2cd6e2b11ef7a 100644 --- a/server/src/main/java/org/elasticsearch/repositories/RepositoriesMetrics.java +++ b/server/src/main/java/org/elasticsearch/repositories/RepositoriesMetrics.java @@ -9,10 +9,17 @@ package org.elasticsearch.repositories; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; +import java.util.Map; + +/** + * The common set of metrics that we publish for {@link org.elasticsearch.repositories.blobstore.BlobStoreRepository} implementations. 
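+ * <p> + * Broadly, the request-level counters ({@link #requestCounter()}, {@link #exceptionCounter()}, {@link #throttleCounter()}) count + * every HTTP request including retries, while the operation-level counters ({@link #operationCounter()}, + * {@link #unsuccessfulOperationCounter()}) count each logical blob store operation once, however many requests it took.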
+ */ public record RepositoriesMetrics( MeterRegistry meterRegistry, LongCounter requestCounter, @@ -28,15 +35,65 @@ public record RepositoriesMetrics( public static RepositoriesMetrics NOOP = new RepositoriesMetrics(MeterRegistry.NOOP); + /** + * Is incremented for each request sent to the blob store (including retries) + * + * Exposed as {@link #requestCounter()} + */ public static final String METRIC_REQUESTS_TOTAL = "es.repositories.requests.total"; + /** + * Is incremented for each request which returns a non-2xx response OR fails to return a response + * (includes throttling and retryable errors) + * + * Exposed as {@link #exceptionCounter()} + */ public static final String METRIC_EXCEPTIONS_TOTAL = "es.repositories.exceptions.total"; + /** + * Is incremented each time an operation ends with a 416 response + * + * Exposed as {@link #requestRangeNotSatisfiedExceptionCounter()} + */ public static final String METRIC_EXCEPTIONS_REQUEST_RANGE_NOT_SATISFIED_TOTAL = "es.repositories.exceptions.request_range_not_satisfied.total"; + /** + * Is incremented each time we are throttled by the blob store, e.g. upon receiving an HTTP 429 response + * + * Exposed as {@link #throttleCounter()} + */ public static final String METRIC_THROTTLES_TOTAL = "es.repositories.throttles.total"; + /** + * Is incremented for each operation we attempt, whether it succeeds or fails; this doesn't include retries + * + * Exposed via {@link #operationCounter()} + */ public static final String METRIC_OPERATIONS_TOTAL = "es.repositories.operations.total"; + /** + * Is incremented for each operation that ends with a non-2xx response or throws an exception + * + * Exposed via {@link #unsuccessfulOperationCounter()} + */ public static final String METRIC_UNSUCCESSFUL_OPERATIONS_TOTAL = "es.repositories.operations.unsuccessful.total"; + /** + * Each time an operation has one or more failed requests (from a non-2xx response or exception), the + * count of those is sampled + * + * Exposed via {@link #exceptionHistogram()} + */ public static final String METRIC_EXCEPTIONS_HISTOGRAM = "es.repositories.exceptions.histogram"; + /** + * Each time an operation has one or more throttled requests, the count of those is sampled + * + * Exposed via {@link #throttleHistogram()} + */ public static final String METRIC_THROTTLES_HISTOGRAM = "es.repositories.throttles.histogram"; + /** + * Every operation that is attempted will record a time. The value recorded here is the sum of the duration of + * each of the requests executed to try and complete the operation. The duration of each request is the time + * between sending the request and either a response being received, or the request failing. Does not include + * the consumption of the body of the response or any time spent pausing between retries.
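+ * For example, an operation that fails twice and succeeds on the third attempt records the sum of the three request + * durations, excluding the backoff pauses between the attempts.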
+ * + * Exposed via {@link #httpRequestTimeInMillisHistogram()} + */ public static final String HTTP_REQUEST_TIME_IN_MILLIS_HISTOGRAM = "es.repositories.requests.http_request_time.histogram"; public RepositoriesMetrics(MeterRegistry meterRegistry) { @@ -61,4 +118,25 @@ public RepositoriesMetrics(MeterRegistry meterRegistry) { ) ); } + + /** + * Create the map of attributes we expect to see on repository metrics + */ + public static Map<String, Object> createAttributesMap( + RepositoryMetadata repositoryMetadata, + OperationPurpose purpose, + String operation + ) { + return Map.of( + "repo_type", + repositoryMetadata.type(), + "repo_name", + repositoryMetadata.name(), + "operation", + operation, + "purpose", + purpose.getKey() + ); + } + } diff --git a/server/src/main/java/org/elasticsearch/rest/RestStatus.java b/server/src/main/java/org/elasticsearch/rest/RestStatus.java index 72227b2d26ec0..569b63edda00b 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestStatus.java +++ b/server/src/main/java/org/elasticsearch/rest/RestStatus.java @@ -571,4 +571,16 @@ public static RestStatus status(int successfulShards, int totalShards, ShardOper public static RestStatus fromCode(int code) { return CODE_TO_STATUS.get(code); } + + /** + * Utility method to determine if an HTTP status code is "Successful" + * + * as defined by RFC 9110 + * + * @param code An HTTP status code + * @return true if it is a 2xx code, false otherwise + */ + public static boolean isSuccessful(int code) { + return code >= 200 && code < 300; + } } diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/ClusterRerouteFeatures.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/ClusterRerouteFeatures.java new file mode 100644 index 0000000000000..c6582cab4a2da --- /dev/null +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/ClusterRerouteFeatures.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1".
+ */ + +package org.elasticsearch.rest.action.admin.cluster; + +import org.elasticsearch.features.FeatureSpecification; +import org.elasticsearch.features.NodeFeature; + +import java.util.Set; + +public class ClusterRerouteFeatures implements FeatureSpecification { + public static final NodeFeature CLUSTER_REROUTE_IGNORES_METRIC_PARAM = new NodeFeature("cluster.reroute.ignores_metric_param"); + + @Override + public Set<NodeFeature> getFeatures() { + return Set.of(CLUSTER_REROUTE_IGNORES_METRIC_PARAM); + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterRerouteAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterRerouteAction.java index 66d6aee30d00a..fada07d60b74e 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterRerouteAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterRerouteAction.java @@ -15,8 +15,11 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.routing.allocation.command.AllocationCommands; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.logging.DeprecationCategory; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; @@ -39,6 +42,8 @@ @ServerlessScope(Scope.INTERNAL) public class RestClusterRerouteAction extends BaseRestHandler { + private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestClusterRerouteAction.class); + private static final Set<String> RESPONSE_PARAMS = addToCopy(Settings.FORMAT_PARAMS, "metric"); private static final ObjectParser<ClusterRerouteRequest, Void> PARSER = new ObjectParser<>("cluster_reroute"); @@ -51,7 +56,8 @@ public class RestClusterRerouteAction extends BaseRestHandler { PARSER.declareBoolean(ClusterRerouteRequest::dryRun, new ParseField("dry_run")); } - private static final String DEFAULT_METRICS = Strings.arrayToCommaDelimitedString( + @UpdateForV10(owner = UpdateForV10.Owner.DISTRIBUTED_COORDINATION) // no longer used, so can be removed + private static final String V8_DEFAULT_METRICS = Strings.arrayToCommaDelimitedString( EnumSet.complementOf(EnumSet.of(ClusterState.Metric.METADATA)).toArray() ); @@ -76,6 +82,11 @@ public boolean allowSystemIndexAccessByDefault() { return true; } + @UpdateForV10(owner = UpdateForV10.Owner.DISTRIBUTED_COORDINATION) + // actually UpdateForV11 because V10 still supports the V9 API including this deprecation message + private static final String METRIC_DEPRECATION_MESSAGE = """ + the [?metric] query parameter to the [POST /_cluster/reroute] API has no effect; its use will be forbidden in a future version"""; + @Override public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { ClusterRerouteRequest clusterRerouteRequest = createRequest(request); @@ -83,11 +94,24 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC if (clusterRerouteRequest.explain()) { request.params().put("explain", Boolean.TRUE.toString()); } - // by default, return everything but metadata - final String metric = request.param("metric"); - if (metric == null) { - request.params().put("metric", DEFAULT_METRICS); + + switch (request.getRestApiVersion()) { + case V_9 -> { + // 
always avoid returning the cluster state by forcing `?metric=none`; emit a warning if `?metric` is even present + if (request.hasParam("metric")) { + deprecationLogger.critical(DeprecationCategory.API, "cluster-reroute-metric-param", METRIC_DEPRECATION_MESSAGE); + } + request.params().put("metric", "none"); + } + case V_8, V_7 -> { + // by default, return everything but metadata + final String metric = request.param("metric"); + if (metric == null) { + request.params().put("metric", V8_DEFAULT_METRICS); + } + } } + return channel -> client.execute( TransportClusterRerouteAction.TYPE, clusterRerouteRequest, diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java index b16a234931115..9ab14aa9362b5 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java @@ -11,9 +11,11 @@ import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreDoc; -import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentFragment; import org.elasticsearch.xcontent.XContentBuilder; @@ -24,7 +26,7 @@ * {@code RankDoc} is the base class for all ranked results. * Subclasses should extend this with additional information required for their global ranking method. */ -public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable { +public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable { public static final String NAME = "rank_doc"; @@ -40,6 +42,11 @@ public String getWriteableName() { return NAME; } + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RANK_DOCS_RETRIEVER; + } + @Override public final int compareTo(RankDoc other) { if (score != other.score) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 1962145d7336d..22bef026523e9 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -160,7 +160,7 @@ public void onFailure(Exception e) { @Override public final QueryBuilder topDocsQuery() { - throw new IllegalStateException(getName() + " cannot be nested"); + throw new IllegalStateException("Should not be called, missing a rewrite?"); } @Override @@ -208,7 +208,7 @@ public int doHashCode() { return Objects.hash(innerRetrievers); } - private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) .storedFields(new StoredFieldsContext(false)) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 1328106896bcb..1c6f8c4a7ce44 100644 --- 
a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -251,11 +251,19 @@ public ActionRequestValidationException validate( @Override public final XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); + builder.startObject(getName()); if (preFilterQueryBuilders.isEmpty() == false) { builder.field(PRE_FILTER_FIELD.getPreferredName(), preFilterQueryBuilders); } + if (minScore != null) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } + if (retrieverName != null) { + builder.field(NAME_FIELD.getPreferredName(), retrieverName); + } doToXContent(builder, params); builder.endObject(); + builder.endObject(); return builder; } diff --git a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java index a38a24eb75fca..be1ce9c925037 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java +++ b/server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java @@ -64,7 +64,7 @@ public void setTopValue(Long value) { @Override public Long value(int slot) { - return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL); + return encodeShardAndDoc(shardRequestIndex, delegate.value(slot)); } @Override @@ -87,4 +87,8 @@ public static int decodeDoc(long value) { public static int decodeShardRequestIndex(long value) { return (int) (value >> 32); } + + public static long encodeShardAndDoc(int shardIndex, int doc) { + return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL); + } } diff --git a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification index 143c0293c5ab7..5cd8935f72403 100644 --- a/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification +++ b/server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification @@ -16,6 +16,7 @@ org.elasticsearch.rest.RestFeatures org.elasticsearch.indices.IndicesFeatures org.elasticsearch.repositories.RepositoriesFeatures org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures +org.elasticsearch.rest.action.admin.cluster.ClusterRerouteFeatures org.elasticsearch.index.mapper.MapperFeatures org.elasticsearch.ingest.IngestGeoIpFeatures org.elasticsearch.search.SearchFeatures diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java index 21b30557cafea..6a7f4bb27a324 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/IndexRoutingTableTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.routing; +import org.elasticsearch.TransportVersion; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; @@ -19,6 +20,7 @@ import java.util.List; +import static org.elasticsearch.TransportVersions.FAST_REFRESH_RCO; import static org.elasticsearch.index.IndexSettings.INDEX_FAST_REFRESH_SETTING; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -27,16 +29,22 @@ public class 
IndexRoutingTableTests extends ESTestCase { public void testReadyForSearch() { - innerReadyForSearch(false); - innerReadyForSearch(true); + innerReadyForSearch(false, false); + innerReadyForSearch(false, true); + innerReadyForSearch(true, false); + innerReadyForSearch(true, true); } - private void innerReadyForSearch(boolean fastRefresh) { + // TODO: remove if (fastRefresh && beforeFastRefreshRCO) branches (ES-9563) + private void innerReadyForSearch(boolean fastRefresh, boolean beforeFastRefreshRCO) { Index index = new Index(randomIdentifier(), UUIDs.randomBase64UUID()); ClusterState clusterState = mock(ClusterState.class, Mockito.RETURNS_DEEP_STUBS); when(clusterState.metadata().index(any(Index.class)).getSettings()).thenReturn( Settings.builder().put(INDEX_FAST_REFRESH_SETTING.getKey(), fastRefresh).build() ); + when(clusterState.getMinTransportVersion()).thenReturn( + beforeFastRefreshRCO ? TransportVersion.fromId(FAST_REFRESH_RCO.id() - 1_00_0) : TransportVersion.current() + ); // 2 primaries that are search and index ShardId p1 = new ShardId(index, 0); IndexShardRoutingTable shardTable1 = new IndexShardRoutingTable( @@ -55,7 +63,7 @@ private void innerReadyForSearch(boolean fastRefresh) { shardTable1 = new IndexShardRoutingTable(p1, List.of(getShard(p1, true, ShardRoutingState.STARTED, ShardRouting.Role.INDEX_ONLY))); shardTable2 = new IndexShardRoutingTable(p2, List.of(getShard(p2, true, ShardRoutingState.STARTED, ShardRouting.Role.INDEX_ONLY))); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { + if (fastRefresh && beforeFastRefreshRCO) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } else { assertFalse(indexRoutingTable.readyForSearch(clusterState)); @@ -91,7 +99,7 @@ private void innerReadyForSearch(boolean fastRefresh) { ) ); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { + if (fastRefresh && beforeFastRefreshRCO) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } else { assertFalse(indexRoutingTable.readyForSearch(clusterState)); @@ -118,8 +126,6 @@ private void innerReadyForSearch(boolean fastRefresh) { assertTrue(indexRoutingTable.readyForSearch(clusterState)); // 2 unassigned primaries that are index only with some replicas that are all available - // Fast refresh indices do not support replicas so this can not practically happen. If we add support we will want to ensure - // that readyForSearch allows for searching replicas when the index shard is not available. 
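Returning briefly to the ShardDocSortField refactor above: the extracted encodeShardAndDoc helper packs the shard request index into the high 32 bits and the doc id into the low 32 bits, so it must round-trip with the decoders already in that class. A small sketch under that assumption (values arbitrary):

import org.elasticsearch.search.sort.ShardDocSortField;

public class ShardDocEncodingSketch {
    public static void main(String[] args) {
        int shardRequestIndex = 7;
        int doc = 123_456;
        long packed = ShardDocSortField.encodeShardAndDoc(shardRequestIndex, doc);
        // decodeShardRequestIndex reads back the high 32 bits, decodeDoc the low 32 bits
        assert ShardDocSortField.decodeShardRequestIndex(packed) == shardRequestIndex;
        assert ShardDocSortField.decodeDoc(packed) == doc;
    }
}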
shardTable1 = new IndexShardRoutingTable( p1, List.of( @@ -137,8 +143,8 @@ private void innerReadyForSearch(boolean fastRefresh) { ) ); indexRoutingTable = new IndexRoutingTable(index, new IndexShardRoutingTable[] { shardTable1, shardTable2 }); - if (fastRefresh) { - assertFalse(indexRoutingTable.readyForSearch(clusterState)); // if we support replicas for fast refreshes this needs to change + if (fastRefresh && beforeFastRefreshRCO) { + assertFalse(indexRoutingTable.readyForSearch(clusterState)); } else { assertTrue(indexRoutingTable.readyForSearch(clusterState)); } diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java index 9d889f24acb6c..c20d84fcf4b10 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java @@ -181,7 +181,7 @@ public void testFailedAllocation() { decision.getExplanation(), allOf( containsString("shard has exceeded the maximum number of retries"), - containsString("POST /_cluster/reroute?retry_failed&metric=none") + containsString("POST /_cluster/reroute?retry_failed") ) ); } @@ -280,7 +280,7 @@ public void testFailedRelocation() { decision.getExplanation(), allOf( containsString("shard has exceeded the maximum number of retries"), - containsString("POST /_cluster/reroute?retry_failed&metric=none") + containsString("POST /_cluster/reroute?retry_failed") ) ); }); diff --git a/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java b/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java index 77635fd0312f8..4cb3ce418f761 100644 --- a/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java +++ b/server/src/test/java/org/elasticsearch/index/cache/bitset/BitSetFilterCacheTests.java @@ -276,7 +276,7 @@ public void testShouldLoadRandomAccessFiltersEagerly() { for (var isStateless : values) { if (isStateless) { assertEquals( - loadFiltersEagerly && indexFastRefresh && hasIndexRole, + loadFiltersEagerly && indexFastRefresh && hasIndexRole == false, BitsetFilterCache.shouldLoadRandomAccessFiltersEagerly( bitsetFilterCacheSettings(isStateless, hasIndexRole, loadFiltersEagerly, indexFastRefresh) ) diff --git a/server/src/test/java/org/elasticsearch/index/query/SpanGapQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/SpanGapQueryBuilderTests.java index cef43a635541e..5adca6d562dca 100644 --- a/server/src/test/java/org/elasticsearch/index/query/SpanGapQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/SpanGapQueryBuilderTests.java @@ -13,6 +13,7 @@ import org.apache.lucene.queries.spans.SpanQuery; import org.apache.lucene.queries.spans.SpanTermQuery; import org.apache.lucene.search.Query; +import org.elasticsearch.lucene.queries.SpanMatchNoDocsQuery; import org.elasticsearch.test.AbstractQueryTestCase; import java.io.IOException; @@ -50,7 +51,9 @@ protected SpanNearQueryBuilder doCreateTestQueryBuilder() { protected void doAssertLuceneQuery(SpanNearQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException { assertThat( query, - either(instanceOf(SpanNearQuery.class)).or(instanceOf(SpanTermQuery.class)).or(instanceOf(MatchAllQueryBuilder.class)) + 
either(instanceOf(SpanNearQuery.class)).or(instanceOf(SpanTermQuery.class)) + .or(instanceOf(MatchAllQueryBuilder.class)) + .or(instanceOf(SpanMatchNoDocsQuery.class)) ); if (query instanceof SpanNearQuery spanNearQuery) { assertThat(spanNearQuery.getSlop(), equalTo(queryBuilder.slop())); diff --git a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java index 3f33bbfe6f6cb..240a677f4cbfd 100644 --- a/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java @@ -41,6 +41,8 @@ import org.elasticsearch.search.collapse.CollapseBuilderTests; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; @@ -600,6 +602,75 @@ public void testNegativeTrackTotalHits() throws IOException { } } + public void testStandardRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"standard\": {" + + " \"query\": {" + + " \"match_all\": {}" + + " }," + + " \"min_score\": 10," + + " \"_name\": \"foo_standard\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(StandardRetrieverBuilder.class)); + StandardRetrieverBuilder parsed = (StandardRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(10f)); + assertThat(parsed.retrieverName(), equalTo("foo_standard")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(StandardRetrieverBuilder.class)); + StandardRetrieverBuilder deserialized = (StandardRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + + public void testKnnRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"knn\": {" + + " \"query_vector\": [" + + " 3" + + " ]," + + " \"field\": \"vector\"," + + " \"k\": 10," + + " \"num_candidates\": 15," + + " \"min_score\": 10," + + " \"_name\": \"foo_knn\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(KnnRetrieverBuilder.class)); + KnnRetrieverBuilder parsed = (KnnRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(10f)); + assertThat(parsed.retrieverName(), equalTo("foo_knn")); + try (XContentParser 
parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(KnnRetrieverBuilder.class)); + KnnRetrieverBuilder deserialized = (KnnRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + public void testStoredFieldsUsage() throws IOException { Set storedFieldRestVariations = Set.of( "{\"stored_fields\" : [\"_none_\"]}", diff --git a/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java new file mode 100644 index 0000000000000..d0c85a33acf09 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class AbstractRankDocWireSerializingTestCase extends AbstractWireSerializingTestCase { + + protected abstract T createTestRankDoc(); + + @Override + protected NamedWriteableRegistry writableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); + List entries = searchModule.getNamedWriteables(); + entries.addAll(getAdditionalNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + protected abstract List getAdditionalNamedWriteables(); + + @Override + protected T createTestInstance() { + return createTestRankDoc(); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testRankDocSerialization() throws IOException { + int totalDocs = randomIntBetween(10, 100); + Set docs = new HashSet<>(); + for (int i = 0; i < totalDocs; i++) { + docs.add(createTestRankDoc()); + } + RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean()); + RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class); + assertThat(rankDocsQueryBuilder, equalTo(copy)); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java index d190139309c31..21101b2bc7db1 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java +++ 
b/server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java @@ -9,27 +9,29 @@ package org.elasticsearch.search.rank; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import java.io.IOException; +import java.util.Collections; +import java.util.List; -public class RankDocTests extends AbstractWireSerializingTestCase { +public class RankDocTests extends AbstractRankDocWireSerializingTestCase { - static RankDoc createTestRankDoc() { + protected RankDoc createTestRankDoc() { RankDoc rankDoc = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1)); rankDoc.rank = randomNonNegativeInt(); return rankDoc; } @Override - protected Writeable.Reader instanceReader() { - return RankDoc::new; + protected List getAdditionalNamedWriteables() { + return Collections.emptyList(); } @Override - protected RankDoc createTestInstance() { - return createTestRankDoc(); + protected Writeable.Reader instanceReader() { + return RankDoc::new; } @Override diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index f3dd86e0b1fa2..b0bf7e6636498 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -74,7 +74,7 @@ protected KnnRetrieverBuilder createTestInstance() { @Override protected KnnRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return KnnRetrieverBuilder.fromXContent( + return (KnnRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), diff --git a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java index d2a1cac43c154..eacd949077bc4 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java @@ -98,7 +98,7 @@ protected StandardRetrieverBuilder createTestInstance() { @Override protected StandardRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return StandardRetrieverBuilder.fromXContent( + return (StandardRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index e45ac8a9e0f70..a24bd91206ac0 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -355,7 +355,6 @@ public void testManyEval() throws IOException { assertMap(map, mapMatcher.entry("columns", columns).entry("values", hasSize(10_000)).entry("took", greaterThanOrEqualTo(0))); } - @AwaitsFix(bugUrl = 
"https://github.com/elastic/elasticsearch-serverless/issues/1874") public void testTooManyEval() throws IOException { initManyLongs(); assertCircuitBreaks(() -> manyEval(490)); @@ -616,14 +615,13 @@ private void initMvLongsIndex(int docs, int fields, int fieldValues) throws IOEx private void bulk(String name, String bulk) throws IOException { Request request = new Request("POST", "/" + name + "/_bulk"); - request.addParameter("filter_path", "errors"); request.setJsonEntity(bulk); request.setOptions( RequestOptions.DEFAULT.toBuilder() .setRequestConfig(RequestConfig.custom().setSocketTimeout(Math.toIntExact(TimeValue.timeValueMinutes(5).millis())).build()) ); Response response = client().performRequest(request); - assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("{\"errors\":false}")); + assertThat(entityAsMap(response), matchesMap().entry("errors", false).extraOk()); } private void initIndex(String name, String bulk) throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultMappingParametersHandler.java b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultMappingParametersHandler.java index 1046e22e65caa..81bd80f464525 100644 --- a/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultMappingParametersHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/logsdb/datageneration/datasource/DefaultMappingParametersHandler.java @@ -98,7 +98,22 @@ public DataSourceResponse.ObjectMappingParametersGenerator handle(DataSourceRequ return new DataSourceResponse.ObjectMappingParametersGenerator(() -> { var parameters = new HashMap(); - if (request.parentSubobjects() == ObjectMapper.Subobjects.DISABLED) { + // Changing subobjects from subobjects: false is not supported, but we can f.e. go from "true" to "false". + // TODO enable subobjects: auto + // It is disabled because it currently does not have auto flattening and that results in asserts being triggered when using + // copy_to. + if (ESTestCase.randomBoolean()) { + parameters.put( + "subobjects", + ESTestCase.randomValueOtherThan( + ObjectMapper.Subobjects.AUTO, + () -> ESTestCase.randomFrom(ObjectMapper.Subobjects.values()) + ).toString() + ); + } + + if (request.parentSubobjects() == ObjectMapper.Subobjects.DISABLED + || parameters.getOrDefault("subobjects", "true").equals("false")) { // "enabled: false" is not compatible with subobjects: false // changing "dynamic" from parent context is not compatible with subobjects: false // changing subobjects value is not compatible with subobjects: false @@ -115,19 +130,6 @@ public DataSourceResponse.ObjectMappingParametersGenerator handle(DataSourceRequ if (ESTestCase.randomBoolean()) { parameters.put("enabled", ESTestCase.randomFrom("true", "false")); } - // Changing subobjects from subobjects: false is not supported, but we can f.e. go from "true" to "false". - // TODO enable subobjects: auto - // It is disabled because it currently does not have auto flattening and that results in asserts being triggered when using - // copy_to. - if (ESTestCase.randomBoolean()) { - parameters.put( - "subobjects", - ESTestCase.randomValueOtherThan( - ObjectMapper.Subobjects.AUTO, - () -> ESTestCase.randomFrom(ObjectMapper.Subobjects.values()) - ).toString() - ); - } if (ESTestCase.randomBoolean()) { var value = request.isRoot() ? 
ESTestCase.randomFrom("none", "arrays") : ESTestCase.randomFrom("none", "arrays", "all"); diff --git a/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/SearchableSnapshotActionIT.java b/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/SearchableSnapshotActionIT.java index fefeaa95319ed..f00b5b566c156 100644 --- a/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/SearchableSnapshotActionIT.java +++ b/x-pack/plugin/ilm/qa/multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ilm/actions/SearchableSnapshotActionIT.java @@ -982,7 +982,7 @@ public void testSearchableSnapshotTotalShardsPerNode() throws Exception { * notification that partial-index is now GREEN. */ private void triggerStateChange() throws IOException { - Request rerouteRequest = new Request("POST", "/_cluster/reroute?metric=none"); + Request rerouteRequest = new Request("POST", "/_cluster/reroute"); client().performRequest(rerouteRequest); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 30ccb48d5c709..a3f2105054639 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -27,7 +27,8 @@ public Set getFeatures() { TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED, RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED, SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID, - SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS + SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS, + TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index dbb9130ab91e1..927fd94809886 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -36,6 +36,7 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xcontent.ParseField; @@ -66,6 +67,7 @@ import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; +import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction; @@ -253,6 +255,7 @@ public List getNamedWriteables() { var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables()); entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, 
TextSimilarityRankBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); return entries; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 81ebebdb47e4f..3ae8dc0550391 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -16,10 +16,13 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import java.util.ArrayList; import java.util.List; @@ -42,7 +45,8 @@ public class EmbeddingRequestChunker { public enum EmbeddingType { FLOAT, - BYTE; + BYTE, + SPARSE; public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) { return switch (elementType) { @@ -67,6 +71,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private List> chunkedInputs; private List>> floatResults; private List>> byteResults; + private List>> sparseResults; private AtomicArray errors; private ActionListener> finalListener; @@ -117,6 +122,7 @@ private void splitIntoBatchedRequests(List inputs) { switch (embeddingType) { case FLOAT -> floatResults = new ArrayList<>(inputs.size()); case BYTE -> byteResults = new ArrayList<>(inputs.size()); + case SPARSE -> sparseResults = new ArrayList<>(inputs.size()); } errors = new AtomicArray<>(inputs.size()); @@ -127,6 +133,7 @@ private void splitIntoBatchedRequests(List inputs) { switch (embeddingType) { case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches)); case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches)); + case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches)); } chunkedInputs.add(chunks); } @@ -217,6 +224,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) { switch (embeddingType) { case FLOAT -> handleFloatResults(inferenceServiceResults); case BYTE -> handleByteResults(inferenceServiceResults); + case SPARSE -> handleSparseResults(inferenceServiceResults); } } @@ -266,6 +274,29 @@ private void handleByteResults(InferenceServiceResults inferenceServiceResults) } } + private void handleSparseResults(InferenceServiceResults inferenceServiceResults) { + if (inferenceServiceResults instanceof SparseEmbeddingResults sparseEmbeddings) { + if (failIfNumRequestsDoNotMatch(sparseEmbeddings.embeddings().size())) { + 
return; + } + + int start = 0; + for (var pos : positions) { + sparseResults.get(pos.inputIndex()) + .setOnce(pos.chunkIndex(), sparseEmbeddings.embeddings().subList(start, start + pos.embeddingCount())); + start += pos.embeddingCount(); + } + + if (resultCount.incrementAndGet() == totalNumberOfRequests) { + sendResponse(); + } + } else { + onFailure( + unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), InferenceTextEmbeddingByteResults.NAME) + ); + } + } + private boolean failIfNumRequestsDoNotMatch(int numberOfResults) { int numberOfRequests = positions.stream().mapToInt(SubBatchPositionsAndCount::embeddingCount).sum(); if (numberOfRequests != numberOfResults) { @@ -319,6 +350,7 @@ private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex)); case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex)); }; } @@ -366,6 +398,26 @@ private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs( return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false); } + private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs( + List chunks, + AtomicArray> debatchedResults + ) { + var all = new ArrayList(); + for (int i = 0; i < debatchedResults.length(); i++) { + var subBatch = debatchedResults.get(i); + all.addAll(subBatch); + } + + assert chunks.size() == all.size(); + + var embeddingChunks = new ArrayList(); + for (int i = 0; i < chunks.size(); i++) { + embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), all.get(i).tokens())); + } + + return new InferenceChunkedSparseEmbeddingResults(embeddingChunks); + } + public record BatchRequest(List subBatches) { public int size() { return subBatches.stream().mapToInt(SubBatch::size).sum(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 9af5668ecf75b..fc2d890dd89e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -21,7 +21,7 @@ public abstract class DelegatingProcessor implements Flow.Processor { private static final Logger log = LogManager.getLogger(DelegatingProcessor.class); private final AtomicLong pendingRequests = new AtomicLong(); - private final AtomicBoolean isClosed = new AtomicBoolean(false); + protected final AtomicBoolean isClosed = new AtomicBoolean(false); private Flow.Subscriber downstream; private Flow.Subscription upstream; @@ -49,7 +49,7 @@ private Flow.Subscription forwardingSubscription() { @Override public void request(long n) { if (isClosed.get()) { - downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening + downstream.onComplete(); } else if (upstream != null) { upstream.request(n); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java index 
b5af0b474834f..3579cd4100bbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandler.java @@ -8,14 +8,19 @@ package org.elasticsearch.xpack.inference.external.cohere; import org.apache.logging.log4j.Logger; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.http.retry.RetryException; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.cohere.CohereErrorResponseEntity; +import org.elasticsearch.xpack.inference.external.response.streaming.NewlineDelimitedByteProcessor; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.concurrent.Flow; + import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; /** @@ -33,9 +38,11 @@ public class CohereResponseHandler extends BaseResponseHandler { static final String TEXTS_ARRAY_TOO_LARGE_MESSAGE_MATCHER = "invalid request: total number of texts must be at most"; static final String TEXTS_ARRAY_ERROR_MESSAGE = "Received a texts array too large response"; + private final boolean canHandleStreamingResponse; - public CohereResponseHandler(String requestType, ResponseParser parseFunction) { + public CohereResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponse) { super(requestType, parseFunction, CohereErrorResponseEntity::fromResponse); + this.canHandleStreamingResponse = canHandleStreamingResponse; } @Override @@ -45,6 +52,20 @@ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, R checkForEmptyBody(throttlerManager, logger, request, result); } + @Override + public boolean canHandleStreamingResponses() { + return canHandleStreamingResponse; + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var ndProcessor = new NewlineDelimitedByteProcessor(); + var cohereProcessor = new CohereStreamingProcessor(); + flow.subscribe(ndProcessor); + ndProcessor.subscribe(cohereProcessor); + return new StreamingChatCompletionResults(cohereProcessor); + } + /** * Validates the status code throws an RetryException if not in the range [200, 300). * diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java new file mode 100644 index 0000000000000..2516a647a91fd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessor.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
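For orientation before the processor implementation below: parseResult in CohereResponseHandler chains two Flow processors between the raw byte publisher and the returned results. A condensed view of that wiring, with the element types (inferred from the processor signatures in this PR) noted as comments:

// HttpResult bytes -> newline-separated lines -> parsed chat completion chunks
var ndProcessor = new NewlineDelimitedByteProcessor();  // Flow.Processor<HttpResult, Deque<String>>
var cohereProcessor = new CohereStreamingProcessor();   // Flow.Processor<Deque<String>, StreamingChatCompletionResults.Results>
flow.subscribe(ndProcessor);            // raw HTTP chunks feed the line splitter
ndProcessor.subscribe(cohereProcessor); // complete lines feed the event parser
return new StreamingChatCompletionResults(cohereProcessor);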
+ */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Map; +import java.util.Optional; + +class CohereStreamingProcessor extends DelegatingProcessor, StreamingChatCompletionResults.Results> { + private static final Logger log = LogManager.getLogger(CohereStreamingProcessor.class); + + @Override + protected void next(Deque item) throws Exception { + if (item.isEmpty()) { + // discard empty result and go to the next + upstream().request(1); + return; + } + + var results = new ArrayDeque(item.size()); + for (String json : item) { + try (var jsonParser = jsonParser(json)) { + var responseMap = jsonParser.map(); + var eventType = (String) responseMap.get("event_type"); + switch (eventType) { + case "text-generation" -> parseText(responseMap).ifPresent(results::offer); + case "stream-end" -> validateResponse(responseMap); + case "stream-start", "search-queries-generation", "search-results", "citation-generation", "tool-calls-generation", + "tool-calls-chunk" -> { + log.debug("Skipping event type [{}] for line [{}].", eventType, item); + } + default -> throw new IOException("Unknown eventType found: " + eventType); + } + } catch (ElasticsearchStatusException e) { + throw e; + } catch (Exception e) { + log.warn("Failed to parse json from cohere: {}", json); + throw e; + } + } + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(new StreamingChatCompletionResults.Results(results)); + } + } + + private static XContentParser jsonParser(String line) throws IOException { + return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line); + } + + private Optional parseText(Map responseMap) throws IOException { + var text = (String) responseMap.get("text"); + if (text != null) { + return Optional.of(new StreamingChatCompletionResults.Result(text)); + } else { + throw new IOException("Null text found in text-generation cohere event"); + } + } + + private void validateResponse(Map responseMap) { + var finishReason = (String) responseMap.get("finish_reason"); + switch (finishReason) { + case "ERROR", "ERROR_TOXIC" -> throw new ElasticsearchStatusException( + "Cohere stopped the stream due to an error: {}", + RestStatus.INTERNAL_SERVER_ERROR, + parseErrorMessage(responseMap) + ); + case "ERROR_LIMIT" -> throw new ElasticsearchStatusException( + "Cohere stopped the stream due to an error: {}", + RestStatus.TOO_MANY_REQUESTS, + parseErrorMessage(responseMap) + ); + } + } + + @SuppressWarnings("unchecked") + private String parseErrorMessage(Map responseMap) { + var innerResponseMap = (Map) responseMap.get("response"); + return (String) innerResponseMap.get("text"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 423093a14a9f0..ae46fbe0fef87 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.cohere.CohereCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -30,7 +29,7 @@ public class CohereCompletionRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createCompletionHandler(); private static ResponseHandler createCompletionHandler() { - return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse); + return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true); } public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) { @@ -51,8 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model); + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index 402f91a0838dc..80617ea56e63c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -28,7 +28,7 @@ public class CohereEmbeddingsRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createEmbeddingsHandler(); private static ResponseHandler createEmbeddingsHandler() { - return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse); + return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false); } public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 9d565e7124b03..d27812b17399b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -27,7 +27,7 @@ public class CohereRerankRequestManager extends CohereRequestManager { private static final ResponseHandler HANDLER = createCohereResponseHandler(); private static ResponseHandler createCohereResponseHandler() { - return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response)); + return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false); } public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java index f68f919a7d85b..2172dcd4d791f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequest.java @@ -25,22 +25,20 @@ import java.util.Objects; public class CohereCompletionRequest extends CohereRequest { - private final CohereAccount account; - private final List input; - private final String modelId; - private final String inferenceEntityId; + private final boolean stream; - public CohereCompletionRequest(List input, CohereCompletionModel model) { + public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { Objects.requireNonNull(model); this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.modelId = model.getServiceSettings().modelId(); this.inferenceEntityId = model.getInferenceEntityId(); + this.stream = stream; } @Override @@ -48,7 +46,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereCompletionRequestEntity(input, modelId)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -62,6 +60,11 @@ public String getInferenceEntityId() { return inferenceEntityId; } + @Override + public boolean isStreaming() { + return stream; + } + @Override public URI getURI() { return account.uri(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java index 8cb3dc6e3c8e8..b834e4335d73c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/completion/CohereCompletionRequestEntity.java @@ -15,11 +15,11 @@ import java.util.List; import java.util.Objects; -public record CohereCompletionRequestEntity(List input, @Nullable String model) implements ToXContentObject { +public record 
CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { private static final String MESSAGE_FIELD = "message"; - private static final String MODEL = "model"; + private static final String STREAM = "stream"; public CohereCompletionRequestEntity { Objects.requireNonNull(input); @@ -36,6 +36,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL, model); } + if (stream) { + builder.field(STREAM, true); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java new file mode 100644 index 0000000000000..7c44b202a816b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessor.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.regex.Pattern; + +/** + * Processes HttpResult bytes into lines separated by newlines, delimited by either line-feed or carriage-return line-feed. + * Downstream is responsible for validating the structure of the lines after they have been separated. + * Because Upstream (Apache) can send us a single line split between two HttpResults, this processor will aggregate bytes from the last + * HttpResult and append them to the front of the next HttpResult. + * When onComplete is called, the last batch is always flushed to the downstream onNext. 
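To make the split-and-carry behavior concrete, here is a self-contained sketch of the same logic (the JSON fragments are illustrative):

import java.util.regex.Pattern;

public class NdSplitSketch {
    public static void main(String[] args) {
        Pattern endOfLine = Pattern.compile("\\n|\\r\\n");
        // an HTTP chunk that ends mid-line
        String body = "{\"a\":1}\n{\"b\":";
        // limit -1 keeps trailing empty strings, so a body ending in "\n"
        // yields an empty remainder rather than losing track of a line boundary
        String[] lines = endOfLine.split(body, -1);
        String remainder = lines[lines.length - 1];
        assert lines[0].equals("{\"a\":1}"); // complete line, sent downstream
        assert remainder.equals("{\"b\":");  // buffered and prepended to the next chunk
    }
}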
+ */ +public class NewlineDelimitedByteProcessor extends DelegatingProcessor> { + private static final Pattern END_OF_LINE_REGEX = Pattern.compile("\\n|\\r\\n"); + private volatile String previousTokens = ""; + + @Override + protected void next(HttpResult item) { + // discard empty result and go to the next + if (item.isBodyEmpty()) { + upstream().request(1); + return; + } + + var body = previousTokens + new String(item.body(), StandardCharsets.UTF_8); + var lines = END_OF_LINE_REGEX.split(body, -1); // -1 because we actually want trailing empty strings + + var results = new ArrayDeque(lines.length); + for (var i = 0; i < lines.length - 1; i++) { + var line = lines[i].trim(); + if (line.isBlank() == false) { + results.offer(line); + } + } + + previousTokens = lines[lines.length - 1].trim(); + + if (results.isEmpty()) { + upstream().request(1); + } else { + downstream().onNext(results); + } + } + + @Override + public void onComplete() { + if (previousTokens.isBlank()) { + super.onComplete(); + } else if (isClosed.compareAndSet(false, true)) { + var results = new ArrayDeque(1); + results.offer(previousTokens); + downstream().onNext(results); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java index eb36c445506a7..134f8af0e083d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilder.java @@ -103,10 +103,7 @@ public int rankWindowSize() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVER_FIELD.getPreferredName()); - builder.startObject(); - builder.field(retrieverBuilder.getName(), retrieverBuilder); - builder.endObject(); + builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); if (seed != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java new file mode 100644 index 0000000000000..d208623e53324 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class TextSimilarityRankDoc extends RankDoc { + + public static final String NAME = "text_similarity_rank_doc"; + + public final String inferenceId; + public final String field; + + public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) { + super(doc, score, shardIndex); + this.inferenceId = inferenceId; + this.field = field; + } + + public TextSimilarityRankDoc(StreamInput in) throws IOException { + super(in); + inferenceId = in.readString(); + field = in.readString(); + } + + @Override + public Explanation explain(Explanation[] sources, String[] queryNames) { + final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]"; + return Explanation.match( + score, + "text_similarity_reranker match using inference endpoint: [" + + inferenceId + + "] on document field: [" + + field + + "] matching on source query " + + queryAlias, + sources + ); + } + + @Override + public void doWriteTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + out.writeString(field); + } + + @Override + public boolean doEquals(RankDoc rd) { + TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd; + return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field); + } + + @Override + public int doHashCode() { + return Objects.hash(inferenceId, field); + } + + @Override + public String toString() { + return "TextSimilarityRankDoc{" + + "doc=" + + doc + + ", shardIndex=" + + shardIndex + + ", score=" + + score + + ", inferenceId=" + + inferenceId + + ", field=" + + field + + '}'; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("inferenceId", inferenceId); + builder.field("field", field); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index ab013e0275a69..3ddaab12eca14 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -7,14 +7,20 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; +import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.license.LicenseUtils; +import 
org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.StoredFieldsContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -32,11 +38,14 @@ /** * A {@code RetrieverBuilder} for parsing and constructing a text similarity reranker retriever. */ -public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { +public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder { public static final NodeFeature TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED = new NodeFeature( "text_similarity_reranker_retriever_supported" ); + public static final NodeFeature TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED = new NodeFeature( + "text_similarity_reranker_retriever_composition_supported" + ); public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id"); @@ -51,7 +60,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder { String inferenceText = (String) args[2]; String field = (String) args[3]; int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4]; - return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize); }); @@ -70,17 +78,20 @@ public static TextSimilarityRankRetrieverBuilder fromXContent(XContentParser par if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED) == false) { throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + TextSimilarityRankBuilder.NAME + "]"); } + if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED) == false) { + throw new UnsupportedOperationException( + "[text_similarity_reranker] retriever composition feature is not supported by all nodes in the cluster" + ); + } if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) { throw LicenseUtils.newComplianceException(TextSimilarityRankBuilder.NAME); } return PARSER.apply(parser, context); } - private final RetrieverBuilder retrieverBuilder; private final String inferenceId; private final String inferenceText; private final String field; - private final int rankWindowSize; public TextSimilarityRankRetrieverBuilder( RetrieverBuilder retrieverBuilder, @@ -89,15 +100,14 @@ public TextSimilarityRankRetrieverBuilder( String field, int rankWindowSize ) { - this.retrieverBuilder = retrieverBuilder; + super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize); this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; - this.rankWindowSize = rankWindowSize; } public TextSimilarityRankRetrieverBuilder( - RetrieverBuilder retrieverBuilder, + List retrieverSource, String inferenceId, String inferenceText, String field, @@ -106,66 +116,75 @@ public TextSimilarityRankRetrieverBuilder( String retrieverName, List preFilterQueryBuilders ) { - this.retrieverBuilder = retrieverBuilder; + super(retrieverSource, 
rankWindowSize); + if (retrieverSource.size() != 1) { + throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever"); + } this.inferenceId = inferenceId; this.inferenceText = inferenceText; this.field = field; - this.rankWindowSize = rankWindowSize; this.minScore = minScore; this.retrieverName = retrieverName; this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override - public QueryBuilder topDocsQuery() { - // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever - return retrieverBuilder.topDocsQuery(); + protected TextSimilarityRankRetrieverBuilder clone(List newChildRetrievers) { + return new TextSimilarityRankRetrieverBuilder( + newChildRetrievers, + inferenceId, + inferenceText, + field, + rankWindowSize, + minScore, + retrieverName, + preFilterQueryBuilders + ); } @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - // rewrite prefilters - boolean hasChanged = false; - var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; - - // rewrite nested retriever - RetrieverBuilder newRetriever = retrieverBuilder.rewrite(ctx); - hasChanged |= newRetriever != retrieverBuilder; - if (hasChanged) { - return new TextSimilarityRankRetrieverBuilder( - newRetriever, - field, - inferenceText, - inferenceId, - rankWindowSize, - minScore, - this.retrieverName, - newPreFilters - ); + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + assert rankResults.size() == 1; + ScoreDoc[] scoreDocs = rankResults.getFirst(); + TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = scoreDocs[i]; + textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field); } - return this; + return textSimilarityRankDocs; } @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed); - // Combining with other rank builder (such as RRF) is not supported yet - if (searchSourceBuilder.rankBuilder() != null) { - throw new IllegalArgumentException("text similarity rank builder cannot be combined with other rank builders"); - } + public QueryBuilder explainQuery() { + // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever + return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true); + } - searchSourceBuilder.rankBuilder( + @Override + protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { + var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) + .trackTotalHits(false) + .storedFields(new StoredFieldsContext(false)) + .size(rankWindowSize); + if (preFilterQueryBuilders.isEmpty() == false) { + retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } + retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); + + // apply the pre-filters + if (preFilterQueryBuilders.size() > 0) { + QueryBuilder query = sourceBuilder.query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + if (query != null) { + newQuery.must(query); + } + 
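+            // Illustrative note (not part of the original change): with an inner query Q and pre-filters F1, F2,
+            // the wrapped query has the shape { "bool": { "must": [ Q ], "filter": [ F1, F2 ] } }, so the
+            // pre-filters narrow the candidate set without contributing to the relevance score.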
preFilterQueryBuilders.forEach(newQuery::filter); + sourceBuilder.query(newQuery); + } + sourceBuilder.rankBuilder( new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) ); - } - - /** - * Determines if this retriever contains sub-retrievers that need to be executed prior to search. - */ - public boolean isCompound() { - return retrieverBuilder.isCompound(); + return sourceBuilder; } @Override @@ -179,23 +198,17 @@ public int rankWindowSize() { @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVER_FIELD.getPreferredName()); - builder.startObject(); - builder.field(retrieverBuilder.getName(), retrieverBuilder); - builder.endObject(); + builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId); builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText); builder.field(FIELD_FIELD.getPreferredName(), field); builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); - if (minScore != null) { - builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); - } } @Override - protected boolean doEquals(Object other) { + public boolean doEquals(Object other) { TextSimilarityRankRetrieverBuilder that = (TextSimilarityRankRetrieverBuilder) other; - return Objects.equals(retrieverBuilder, that.retrieverBuilder) + return super.doEquals(other) && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(inferenceText, that.inferenceText) && Objects.equals(field, that.field) @@ -204,7 +217,7 @@ protected boolean doEquals(Object other) { } @Override - protected int doHashCode() { - return Objects.hash(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, minScore); + public int doHashCode() { + return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 728a4ac137dff..3ba93dd8d1b66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -39,6 +39,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; @@ -288,4 +289,9 @@ static SimilarityMeasure defaultSimilarity() { public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ML_INFERENCE_RATE_LIMIT_SETTINGS_ADDED; } + + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 0dd41db2f016c..881e2e82b766a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -248,15 +248,14 @@ public static InferModelAction.Request buildInferenceRequest( InferenceConfigUpdate update, List inputs, InputType inputType, - TimeValue timeout, - boolean chunk + TimeValue timeout ) { var request = InferModelAction.Request.forTextInput(id, update, inputs, true, timeout); request.setPrefixType( InputType.SEARCH == inputType ? TrainedModelPrefixStrings.PrefixType.SEARCH : TrainedModelPrefixStrings.PrefixType.INGEST ); request.setHighPriority(InputType.SEARCH == inputType); - request.setChunked(chunk); + request.setChunked(false); return request; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index 07d0cc14b2ac8..a593e1dfb6d9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -58,6 +58,11 @@ public abstract ActionListener getC ActionListener listener ); + @Override + public ElasticsearchInternalServiceSettings getServiceSettings() { + return (ElasticsearchInternalServiceSettings) super.getServiceSettings(); + } + @Override public String toString() { return Strings.toString(this.getConfigurations()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 9b4c0e50bdebe..739f514bee1c9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -28,21 +28,19 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -74,6 +72,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); + public static final int EMBEDDING_MAX_BATCH_SIZE = 10; public static final String DEFAULT_ELSER_ID = ".elser-2"; private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); @@ -501,8 +500,7 @@ public void inferTextEmbedding( TextEmbeddingConfigUpdate.EMPTY_INSTANCE, inputs, inputType, - timeout, - false + timeout ); ActionListener mlResultsListener = listener.delegateFailureAndWrap( @@ -528,8 +526,7 @@ public void inferSparseEmbedding( TextExpansionConfigUpdate.EMPTY_UPDATE, inputs, inputType, - timeout, - false + timeout ); ActionListener mlResultsListener = listener.delegateFailureAndWrap( @@ -557,8 +554,7 @@ public void inferRerank( new TextSimilarityConfigUpdate(query), inputs, inputType, - timeout, - false + timeout ); var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings(); @@ -610,52 +606,80 @@ public void chunkedInfer( if (model instanceof ElasticsearchInternalModel esModel) { - var configUpdate = chunkingOptions != null - ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : new TokenizationConfigUpdate(null, null); - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - configUpdate, + var batchedRequests = new EmbeddingRequestChunker( input, - inputType, - timeout, - true - ); + EMBEDDING_MAX_BATCH_SIZE, + embeddingTypeFromTaskTypeAndSettings(model.getTaskType(), esModel.internalServiceSettings) + ).batchRequestsWithListeners(listener); + + for (var batch : batchedRequests) { + var inferenceRequest = buildInferenceRequest( + model.getConfigurations().getInferenceEntityId(), + EmptyConfigUpdate.INSTANCE, + batch.batch().inputs(), + inputType, + timeout + ); - ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getInferenceResults())) - ); + ActionListener mlResultsListener = batch.listener() + .delegateFailureAndWrap( + (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l) + ); - var maybeDeployListener = mlResultsListener.delegateResponse( - (l, exception) -> maybeStartDeployment(esModel, exception, request, mlResultsListener) - ); + var maybeDeployListener = mlResultsListener.delegateResponse( + (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener) + ); - client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); + client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener); + } } else { listener.onFailure(notElasticsearchModelException(model)); } } - private static List translateToChunkedResults(List inferenceResults) { - var translated = new ArrayList(); - - for (var inferenceResult : inferenceResults) { - 
translated.add(translateToChunkedResult(inferenceResult));
-        }
-
-        return translated;
-    }
+    private static void translateToChunkedResult(
+        TaskType taskType,
+        List<InferenceResults> inferenceResults,
+        ActionListener<ChunkedInferenceServiceResults> chunkPartListener
+    ) {
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            var translated = new ArrayList<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>();
 
-    private static ChunkedInferenceServiceResults translateToChunkedResult(InferenceResults inferenceResult) {
-        if (inferenceResult instanceof MlChunkedTextEmbeddingFloatResults mlChunkedResult) {
-            return InferenceChunkedTextEmbeddingFloatResults.ofMlResults(mlChunkedResult);
-        } else if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) {
-            return InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult);
-        } else if (inferenceResult instanceof ErrorInferenceResults error) {
-            return new ErrorChunkedInferenceResults(error.getException());
-        } else {
-            throw createInvalidChunkedResultException(MlChunkedTextEmbeddingFloatResults.NAME, inferenceResult.getWriteableName());
+            for (var inferenceResult : inferenceResults) {
+                if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) {
+                    translated.add(
+                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(mlTextEmbeddingResult.getInferenceAsFloat())
+                    );
+                } else if (inferenceResult instanceof ErrorInferenceResults error) {
+                    chunkPartListener.onFailure(error.getException());
+                    return;
+                } else {
+                    chunkPartListener.onFailure(
+                        createInvalidChunkedResultException(MlTextEmbeddingResults.NAME, inferenceResult.getWriteableName())
+                    );
+                    return;
+                }
+            }
+            chunkPartListener.onResponse(new InferenceTextEmbeddingFloatResults(translated));
+        } else { // sparse
+            var translated = new ArrayList<SparseEmbeddingResults.Embedding>();
+
+            for (var inferenceResult : inferenceResults) {
+                if (inferenceResult instanceof TextExpansionResults textExpansionResult) {
+                    translated.add(
+                        new SparseEmbeddingResults.Embedding(textExpansionResult.getWeightedTokens(), textExpansionResult.isTruncated())
+                    );
+                } else if (inferenceResult instanceof ErrorInferenceResults error) {
+                    chunkPartListener.onFailure(error.getException());
+                    return;
+                } else {
+                    chunkPartListener.onFailure(
+                        createInvalidChunkedResultException(TextExpansionResults.NAME, inferenceResult.getWriteableName())
+                    );
+                    return;
+                }
+            }
+            chunkPartListener.onResponse(new SparseEmbeddingResults(translated));
         }
     }
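The reworked chunkedInfer above fans a single logical request out into batches of at most EMBEDDING_MAX_BATCH_SIZE inputs, each dispatched as its own InferModelAction request with a per-batch listener that the EmbeddingRequestChunker merges back into one response. A minimal standalone sketch of that fan-out idea (hypothetical names, plain JDK, no Elasticsearch types; the real chunker also splits long inputs into chunks before batching):

import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical sketch: split inputs into batches of at most maxBatchSize and dispatch each batch.
class BatchFanOut {
    static void run(List<String> inputs, int maxBatchSize, Consumer<List<String>> dispatch) {
        for (int from = 0; from < inputs.size(); from += maxBatchSize) {
            dispatch.accept(inputs.subList(from, Math.min(from + maxBatchSize, inputs.size())));
        }
    }

    public static void main(String[] args) {
        var batches = new ArrayList<List<String>>();
        run(List.of("a", "b", "c", "d", "e"), 2, batches::add);
        System.out.println(batches); // [[a, b], [c, d], [e]]
    }
}

@@ -738,4 +762,21 @@ public List defaultConfigs() {
     protected boolean isDefaultId(String inferenceId) {
         return DEFAULT_ELSER_ID.equals(inferenceId);
     }
+
+    static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
+        TaskType taskType,
+        ElasticsearchInternalServiceSettings serviceSettings
+    ) {
+        return switch (taskType) {
+            case SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE;
+            case TEXT_EMBEDDING -> serviceSettings.elementType() == null
+                ?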
EmbeddingRequestChunker.EmbeddingType.FLOAT + : EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(serviceSettings.elementType()); + default -> throw new ElasticsearchStatusException( + "Chunking is not supported for task type [{}]", + RestStatus.BAD_REQUEST, + taskType + ); + }; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index cf862ee6fb4b8..c1be537a6b0a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -11,10 +11,13 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.util.ArrayList; import java.util.List; @@ -357,6 +360,83 @@ public void testMergingListener_Byte() { } } + public void testMergingListener_Sparse() { + int batchSize = 4; + int chunkSize = 10; + int overlap = 0; + // passage will be chunked into 2.1 batches + // and spread over 3 batch requests + int numberOfWordsInPassage = (chunkSize * batchSize * 2) + 5; + + var passageBuilder = new StringBuilder(); + for (int i = 0; i < numberOfWordsInPassage; i++) { + passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace + } + List inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString()); + + var finalListener = testListener(); + var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap, EmbeddingRequestChunker.EmbeddingType.SPARSE) + .batchRequestsWithListeners(finalListener); + assertThat(batches, hasSize(3)); + + // 4 inputs in 3 batches + { + var embeddings = new ArrayList(); + for (int i = 0; i < batchSize; i++) { + embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(randomAlphaOfLength(4), 1.0f)), false)); + } + batches.get(0).listener().onResponse(new SparseEmbeddingResults(embeddings)); + } + { + var embeddings = new ArrayList(); + for (int i = 0; i < batchSize; i++) { + embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(randomAlphaOfLength(4), 1.0f)), false)); + } + batches.get(1).listener().onResponse(new SparseEmbeddingResults(embeddings)); + } + { + var embeddings = new ArrayList(); + for (int i = 0; i < 4; i++) { // 4 chunks in the final batch + embeddings.add(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(randomAlphaOfLength(4), 1.0f)), false)); + } + batches.get(2).listener().onResponse(new SparseEmbeddingResults(embeddings)); + } + + assertNotNull(finalListener.results); + assertThat(finalListener.results, 
hasSize(4)); + { + var chunkedResult = finalListener.results.get(0); + assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); + assertEquals("1st small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + } + { + var chunkedResult = finalListener.results.get(1); + assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); + assertEquals("2nd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + } + { + var chunkedResult = finalListener.results.get(2); + assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(1)); + assertEquals("3rd small", chunkedSparseResult.getChunkedResults().get(0).matchedText()); + } + { + // this is the large input split in multiple chunks + var chunkedResult = finalListener.results.get(3); + assertThat(chunkedResult, instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var chunkedSparseResult = (InferenceChunkedSparseEmbeddingResults) chunkedResult; + assertThat(chunkedSparseResult.getChunkedResults(), hasSize(9)); // passage is split into 9 chunks, 10 words each + assertThat(chunkedSparseResult.getChunkedResults().get(0).matchedText(), startsWith("passage_input0 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(1).matchedText(), startsWith(" passage_input10 ")); + assertThat(chunkedSparseResult.getChunkedResults().get(8).matchedText(), startsWith(" passage_input80 ")); + } + } + public void testListenerErrorsWithWrongNumberOfResponses() { List inputs = List.of("1st small", "2nd small", "3rd small"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java index d64ac495c8c99..444415dfc8e48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereResponseHandlerTests.java @@ -132,7 +132,7 @@ private static void callCheckForFailureStatusCode(int statusCode, @Nullable Stri var mockRequest = mock(Request.class); when(mockRequest.getInferenceEntityId()).thenReturn(modelId); var httpResult = new HttpResult(httpResponse, errorMessage == null ? 
new byte[] {} : responseJson.getBytes(StandardCharsets.UTF_8)); - var handler = new CohereResponseHandler("", (request, result) -> null); + var handler = new CohereResponseHandler("", (request, result) -> null, false); handler.checkForFailureStatusCode(mockRequest, httpResult); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java new file mode 100644 index 0000000000000..87d6d63bb8c51 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/cohere/CohereStreamingProcessorTests.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.cohere; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onError; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class CohereStreamingProcessorTests extends ESTestCase { + + public void testParseErrorCallsOnError() { + var item = new ArrayDeque(); + item.offer("this is not json"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(XContentParseException.class)); + } + + public void testUnrecognizedEventCallsOnError() { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"test\"}"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(IOException.class)); + assertThat(exception.getMessage(), equalTo("Unknown eventType found: test")); + } + + public void testMissingTextCallsOnError() { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"text-generation\"}"); + + var exception = onError(new CohereStreamingProcessor(), item); + assertThat(exception, instanceOf(IOException.class)); + assertThat(exception.getMessage(), equalTo("Null text found in text-generation cohere event")); + } + + public void testEmptyResultsRequestsMoreData() throws Exception { + var emptyDeque = new ArrayDeque(); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(emptyDeque); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testNonDataEventsAreSkipped() throws Exception { + var item 
= new ArrayDeque(); + item.offer("{\"event_type\":\"stream-start\"}"); + item.offer("{\"event_type\":\"search-queries-generation\"}"); + item.offer("{\"event_type\":\"search-results\"}"); + item.offer("{\"event_type\":\"citation-generation\"}"); + item.offer("{\"event_type\":\"tool-calls-generation\"}"); + item.offer("{\"event_type\":\"tool-calls-chunk\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testParseError() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR\", \"response\":{ \"text\": \"a wild error appears\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), containsString("a wild error appears")); + }); + } + + private void testError(String json, Consumer test) { + var item = new ArrayDeque(); + item.offer(json); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + try { + processor.next(item); + fail("Expected an exception to be thrown"); + } catch (ElasticsearchStatusException e) { + test.accept(e); + } catch (Exception e) { + fail(e, "Expected an exception of type ElasticsearchStatusException to be thrown"); + } + } + + public void testParseToxic() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR_TOXIC\", \"response\":{ \"text\": \"by britney spears\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(500)); + assertThat(e.getMessage(), containsString("by britney spears")); + }); + } + + public void testParseLimit() { + var json = "{\"event_type\":\"stream-end\", \"finish_reason\":\"ERROR_LIMIT\", \"response\":{ \"text\": \"over the limit\" }}"; + testError(json, e -> { + assertThat(e.status().getStatus(), equalTo(429)); + assertThat(e.getMessage(), containsString("over the limit")); + }); + } + + public void testNonErrorFinishesAreSkipped() throws Exception { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"COMPLETE\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"STOP_SEQUENCE\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"USER_CANCEL\"}"); + item.offer("{\"event_type\":\"stream-end\", \"finish_reason\":\"MAX_TOKENS\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(1)).request(1); + verify(downstream, times(0)).onNext(any()); + } + + public void testParseCohereData() throws Exception { + var item = new ArrayDeque(); + item.offer("{\"event_type\":\"text-generation\", \"text\":\"hello there\"}"); + + var processor = new CohereStreamingProcessor(); + + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + + processor.next(item); + + verify(upstream, times(0)).request(1); + verify(downstream, times(1)).onNext(assertArg(chunks -> { + assertThat(chunks, isA(StreamingChatCompletionResults.Results.class)); + var 
results = (StreamingChatCompletionResults.Results) chunks; + assertThat(results.results().size(), equalTo(1)); + assertThat(results.results().getFirst().delta(), equalTo("hello there")); + })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java index dbe6a9438d884..c3b534f42e7ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java @@ -22,7 +22,7 @@ public class CohereCompletionRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), "model"); + var entity = new CohereCompletionRequestEntity(List.of("some input"), "model", false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -33,7 +33,7 @@ public void testXContent_WritesAllFields() throws IOException { } public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), null); + var entity = new CohereCompletionRequestEntity(List.of("some input"), null, false); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -44,10 +44,10 @@ public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { } public void testXContent_ThrowsIfInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null)); + expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null, false)); } public void testXContent_ThrowsIfMessageInInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null)); + expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null, false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java index d6d0d5c00eaf4..f2e6d4305f9e6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java @@ -26,7 +26,7 @@ public class CohereCompletionRequestTests extends ESTestCase { public void testCreateRequest_UrlDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null)); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -43,7 +43,7 @@ public void testCreateRequest_UrlDefined() throws IOException { } public void 
testCreateRequest_ModelDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -60,14 +60,14 @@ public void testCreateRequest_ModelDefined() throws IOException { } public void testTruncate_ReturnsSameInstance() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } public void testTruncationInfo_ReturnsNull() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model")); + var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); assertNull(request.getTruncationInfo()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java new file mode 100644 index 0000000000000..488cbccd0e7c3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/streaming/NewlineDelimitedByteProcessorTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.external.response.streaming;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.junit.Before;
+import org.mockito.ArgumentCaptor;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Deque;
+import java.util.concurrent.Flow;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.assertArg;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+public class NewlineDelimitedByteProcessorTests extends ESTestCase {
+    private Flow.Subscription upstream;
+    private Flow.Subscriber<Deque<String>> downstream;
+    private NewlineDelimitedByteProcessor processor;
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        upstream = mock();
+        downstream = mock();
+        processor = new NewlineDelimitedByteProcessor();
+        processor.onSubscribe(upstream);
+        processor.subscribe(downstream);
+    }
+
+    public void testEmptyBody() {
+        processor.next(result(null));
+        processor.onComplete();
+        verify(upstream, times(1)).request(1);
+        verify(downstream, times(0)).onNext(any());
+    }
+
+    private HttpResult result(String response) {
+        return new HttpResult(mock(), response == null ? new byte[0] : response.getBytes(StandardCharsets.UTF_8));
+    }
+
+    public void testEmptyParseResponse() {
+        processor.next(result(""));
+        verify(upstream, times(1)).request(1);
+        verify(downstream, times(0)).onNext(any());
+    }
+
+    public void testValidResponse() {
+        processor.next(result("{\"hello\":\"there\"}\n"));
+        verify(downstream, times(1)).onNext(assertArg(deque -> {
+            assertThat(deque, notNullValue());
+            assertThat(deque.size(), is(1));
+            assertThat(deque.getFirst(), is("{\"hello\":\"there\"}"));
+        }));
+    }
+
+    public void testMultipleValidResponse() {
+        processor.next(result("""
+            {"value": 1}
+            {"value": 2}
+            {"value": 3}
+            """));
+        verify(upstream, times(0)).request(1);
+        verify(downstream, times(1)).onNext(assertArg(deque -> {
+            assertThat(deque, notNullValue());
+            assertThat(deque.size(), is(3));
+            var items = deque.iterator();
+            IntStream.range(1, 4).forEach(i -> {
+                assertThat(items.hasNext(), is(true));
+                assertThat(items.next(), containsString(String.valueOf(i)));
+            });
+        }));
+    }
+
+    public void testOnCompleteFlushesResponse() {
+        processor.next(result("""
+            {"value": 1}"""));
+
+        // onNext should not be called with only one value
+        verify(downstream, times(0)).onNext(any());
+        verify(downstream, times(0)).onComplete();
+
+        // onComplete should flush the value pending, and onNext should be called
+        processor.onComplete();
+        verify(downstream, times(1)).onNext(assertArg(deque -> {
+            assertThat(deque, notNullValue());
+            assertThat(deque.size(), is(1));
+            var item = deque.getFirst();
+            assertThat(item, containsString(String.valueOf(1)));
+        }));
+        verify(downstream, times(0)).onComplete();
+
+        // next time the downstream requests data, onComplete is called
+        var downstreamSubscription = ArgumentCaptor.forClass(Flow.Subscription.class);
+        verify(downstream).onSubscribe(downstreamSubscription.capture());
+        downstreamSubscription.getValue().request(1);
+        verify(downstream, times(1)).onComplete();
+    }
+}
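To see the buffering behaviour these tests pin down without the plugin's HttpResult plumbing, here is a self-contained sketch (a hypothetical class, plain JDK only) that re-implements the same carry-over splitting idea; it is an illustration of the technique, not the production processor:

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.regex.Pattern;

// Hypothetical standalone re-implementation of the carry-over splitting used above.
class NdSplitSketch {
    private static final Pattern END_OF_LINE = Pattern.compile("\\n|\\r\\n");
    private String carry = "";

    // Returns the complete lines contained in carry + chunk; buffers any trailing partial line.
    Deque<String> next(String chunk) {
        var body = carry + chunk;
        var lines = END_OF_LINE.split(body, -1); // -1 keeps the trailing (possibly empty) fragment
        var out = new ArrayDeque<String>(lines.length);
        for (int i = 0; i < lines.length - 1; i++) {
            if (lines[i].isBlank() == false) {
                out.offer(lines[i].trim());
            }
        }
        carry = lines[lines.length - 1].trim();
        return out;
    }

    public static void main(String[] args) {
        var sketch = new NdSplitSketch();
        System.out.println(sketch.next("{\"value\": 1}\n{\"val")); // prints [{"value": 1}]
        System.out.println(sketch.next("ue\": 2}\n"));             // prints [{"value": 2}]
    }
}

diff --git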
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java index c33f30d461350..c0ef4e45f101f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/random/RandomRankRetrieverBuilderTests.java @@ -17,8 +17,6 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.json.JsonXContent; -import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder; -import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import java.io.IOException; import java.util.ArrayList; @@ -48,8 +46,8 @@ protected RandomRankRetrieverBuilder createTestInstance() { } @Override - protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) { - return RandomRankRetrieverBuilder.PARSER.apply( + protected RandomRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (RandomRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), @@ -77,8 +75,8 @@ protected NamedXContentRegistry xContentRegistry() { entries.add( new NamedXContentRegistry.Entry( RetrieverBuilder.class, - new ParseField(TextSimilarityRankBuilder.NAME), - (p, c) -> TextSimilarityRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) + new ParseField(RandomRankBuilder.NAME), + (p, c) -> RandomRankRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) ) ); return new NamedXContentRegistry(entries); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java new file mode 100644 index 0000000000000..fed4565c54bd4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.InferencePlugin; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.search.rank.RankDoc.NO_RANK; + +public class TextSimilarityRankDocTests extends AbstractRankDocWireSerializingTestCase { + + static TextSimilarityRankDoc createTestTextSimilarityRankDoc() { + TextSimilarityRankDoc instance = new TextSimilarityRankDoc( + randomNonNegativeInt(), + randomFloat(), + randomBoolean() ? 
-1 : randomNonNegativeInt(), + randomAlphaOfLength(randomIntBetween(2, 5)), + randomAlphaOfLength(randomIntBetween(2, 5)) + ); + instance.rank = randomBoolean() ? NO_RANK : randomIntBetween(1, 10000); + return instance; + } + + @Override + protected List getAdditionalNamedWriteables() { + try (InferencePlugin plugin = new InferencePlugin(Settings.EMPTY)) { + return plugin.getNamedWriteables(); + } + } + + @Override + protected Writeable.Reader instanceReader() { + return TextSimilarityRankDoc::new; + } + + @Override + protected TextSimilarityRankDoc createTestRankDoc() { + return createTestTextSimilarityRankDoc(); + } + + @Override + protected TextSimilarityRankDoc mutateInstance(TextSimilarityRankDoc instance) throws IOException { + int doc = instance.doc; + int shardIndex = instance.shardIndex; + float score = instance.score; + int rank = instance.rank; + String inferenceId = instance.inferenceId; + String field = instance.field; + + switch (randomInt(5)) { + case 0: + doc = randomValueOtherThan(doc, ESTestCase::randomNonNegativeInt); + break; + case 1: + shardIndex = shardIndex == -1 ? randomNonNegativeInt() : -1; + break; + case 2: + score = randomValueOtherThan(score, ESTestCase::randomFloat); + break; + case 3: + rank = rank == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK; + break; + case 4: + inferenceId = randomValueOtherThan(inferenceId, () -> randomAlphaOfLength(randomIntBetween(2, 5))); + break; + case 5: + field = randomValueOtherThan(field, () -> randomAlphaOfLength(randomIntBetween(2, 5))); + break; + default: + throw new AssertionError(); + } + TextSimilarityRankDoc mutated = new TextSimilarityRankDoc(doc, score, shardIndex, inferenceId, field); + mutated.rank = rank; + return mutated; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 1a72cb0da2899..32301bf9efea9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -8,23 +8,18 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.common.Strings; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryRewriteContext; -import org.elasticsearch.index.query.RandomQueryBuilder; -import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.builder.SubSearchSourceBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.usage.SearchUsageHolder; 
+import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; @@ -35,10 +30,8 @@ import java.util.List; import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; -import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.Mockito.mock; public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTestCase { @@ -72,13 +65,14 @@ protected TextSimilarityRankRetrieverBuilder createTestInstance() { } @Override - protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) { - return TextSimilarityRankRetrieverBuilder.PARSER.apply( + protected TextSimilarityRankRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (TextSimilarityRankRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( parser, new RetrieverParserContext( new SearchUsage(), nf -> nf == RetrieverBuilder.RETRIEVERS_SUPPORTED || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED + || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED ) ); } @@ -128,107 +122,43 @@ public void testParserDefaults() throws IOException { } } - public void testRewriteInnerRetriever() throws IOException { - final boolean[] rewritten = { false }; - List preFilterQueryBuilders = new ArrayList<>(); - if (randomBoolean()) { - for (int i = 0; i < randomIntBetween(1, 5); i++) { - preFilterQueryBuilders.add(RandomQueryBuilder.createQuery(random())); + public void testTextSimilarityRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"text_similarity_reranker\": {" + + " \"retriever\": {" + + " \"test\": {" + + " \"value\": \"my-test-retriever\"" + + " }" + + " }," + + " \"field\": \"my-field\"," + + " \"inference_id\": \"my-inference-id\"," + + " \"inference_text\": \"my-inference-text\"," + + " \"rank_window_size\": 100," + + " \"min_score\": 20.0," + + " \"_name\": \"foo_reranker\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class)); + TextSimilarityRankRetrieverBuilder parsed = (TextSimilarityRankRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(20f)); + assertThat(parsed.retrieverName(), equalTo("foo_reranker")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(TextSimilarityRankRetrieverBuilder.class)); + TextSimilarityRankRetrieverBuilder deserialized = (TextSimilarityRankRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); } } - RetrieverBuilder innerRetriever = new TestRetrieverBuilder("top-level-retriever") { - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (randomBoolean()) { - 
return this; - } - rewritten[0] = true; - return new TestRetrieverBuilder("nested-rewritten-retriever") { - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (preFilterQueryBuilders.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - - for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) { - boolQueryBuilder.filter(preFilterQueryBuilder); - } - boolQueryBuilder.must(new RangeQueryBuilder("some_field")); - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); - } else { - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new RangeQueryBuilder("some_field"))); - } - } - }; - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - if (preFilterQueryBuilders.isEmpty() == false) { - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - - for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) { - boolQueryBuilder.filter(preFilterQueryBuilder); - } - boolQueryBuilder.must(new TermQueryBuilder("field", "value")); - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder)); - } else { - searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new TermQueryBuilder("field", "value"))); - } - } - }; - TextSimilarityRankRetrieverBuilder textSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder( - innerRetriever - ); - textSimilarityRankRetrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); - SearchSourceBuilder source = new SearchSourceBuilder().retriever(textSimilarityRankRetrieverBuilder); - QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); - source = Rewriteable.rewrite(source, queryRewriteContext); - assertNull(source.retriever()); - if (false == preFilterQueryBuilders.isEmpty()) { - if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) { - assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); - BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); - assertFalse(bq.must().isEmpty()); - assertThat(bq.must().size(), equalTo(1)); - if (rewritten[0]) { - assertThat(bq.must().get(0), instanceOf(RangeQueryBuilder.class)); - } else { - assertThat(bq.must().get(0), instanceOf(TermQueryBuilder.class)); - } - for (int j = 0; j < bq.filter().size(); j++) { - assertEqualQueryOrMatchAllNone(bq.filter().get(j), preFilterQueryBuilders.get(j)); - } - } - } else { - if (rewritten[0]) { - assertThat(source.query(), instanceOf(RangeQueryBuilder.class)); - } else { - assertThat(source.query(), instanceOf(TermQueryBuilder.class)); - } - } - } - - public void testIsCompound() { - RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { - @Override - public boolean isCompound() { - return true; - } - }; - RetrieverBuilder nonCompoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) { - @Override - public boolean isCompound() { - return false; - } - }; - TextSimilarityRankRetrieverBuilder compoundTextSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder( - compoundInnerRetriever - ); - assertTrue(compoundTextSimilarityRankRetrieverBuilder.isCompound()); - TextSimilarityRankRetrieverBuilder nonCompoundTextSimilarityRankRetrieverBuilder = 
createRandomTextSimilarityRankRetrieverBuilder( - nonCompoundInnerRetriever - ); - assertFalse(nonCompoundTextSimilarityRankRetrieverBuilder.isCompound()); } public void testTopDocsQuery() { @@ -239,11 +169,6 @@ public QueryBuilder topDocsQuery() { } }; TextSimilarityRankRetrieverBuilder retriever = createRandomTextSimilarityRankRetrieverBuilder(innerRetriever); - assertThat(retriever.topDocsQuery(), instanceOf(TermQueryBuilder.class)); + expectThrows(IllegalStateException.class, "Should not be called, missing a rewrite?", retriever::topDocsQuery); } - - private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) { - assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected))); - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 22503108b5262..420a635963a29 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -38,6 +38,8 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; @@ -1349,6 +1351,54 @@ public void testDefaultSimilarity() { assertEquals(SimilarityMeasure.DOT_PRODUCT, CohereService.defaultSimilarity()); } + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + {"event_type":"text-generation", "text":"hello"} + {"event_type":"text-generation", "text":"there"} + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result = streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello"},{"delta":"there"}]}"""); + } + + private InferenceServiceResults streamChatCompletion() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return listener.actionGet(TIMEOUT); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception { + String responseJson = """ + { "event_type":"stream-end", "finish_reason":"ERROR", "response":{ "text": "how dare you" } } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var result 
= streamChatCompletion(); + + InferenceEventsAssertion.assertThat(result) + .hasFinishedStream() + .hasNoEvents() + .hasErrorWithStatusCode(500) + .hasErrorContaining("how dare you"); + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index cd6da4c0ad8d8..db7189dc1af17 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.core.action.util.QueryPage; @@ -44,15 +43,14 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.core.utils.FloatConversionUtils; import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.junit.After; import org.junit.Before; @@ -663,14 +661,12 @@ public void testParsePersistedConfig() { @SuppressWarnings("unchecked") public void testChunkInfer_e5() { var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(MlChunkedTextEmbeddingFloatResultsTests.createRandomResults()); - mlTrainedModelResults.add(MlChunkedTextEmbeddingFloatResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - ThreadPool threadpool = new TestThreadPool("test"); Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); + 
when(client.threadPool()).thenReturn(threadPool); doAnswer(invocationOnMock -> { var listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onResponse(response); @@ -687,47 +683,26 @@ public void testChunkInfer_e5() { var gotResults = new AtomicBoolean(); var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); + assertThat(chunkedResponse, hasSize(2)); assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); var result1 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0); - assertEquals( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(0)).getChunks().size(), - result1.getChunks().size() - ); - assertEquals( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(0)).getChunks().get(0).matchedText(), - result1.getChunks().get(0).matchedText() - ); + assertThat(result1.chunks(), hasSize(1)); assertArrayEquals( - (FloatConversionUtils.floatArrayOf( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(0)).getChunks().get(0).embedding() - )), + ((MlTextEmbeddingResults) mlTrainedModelResults.get(0)).getInferenceAsFloat(), result1.getChunks().get(0).embedding(), 0.0001f ); + assertEquals("foo", result1.getChunks().get(0).matchedText()); assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); var result2 = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(1); - // assertEquals(((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunks()); - - assertEquals( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(1)).getChunks().size(), - result2.getChunks().size() - ); - assertEquals( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(1)).getChunks().get(0).matchedText(), - result2.getChunks().get(0).matchedText() - ); + assertThat(result2.chunks(), hasSize(1)); assertArrayEquals( - (FloatConversionUtils.floatArrayOf( - ((MlChunkedTextEmbeddingFloatResults) mlTrainedModelResults.get(1)).getChunks().get(0).embedding() - )), + ((MlTextEmbeddingResults) mlTrainedModelResults.get(1)).getInferenceAsFloat(), result2.getChunks().get(0).embedding(), 0.0001f ); + assertEquals("bar", result2.getChunks().get(0).matchedText()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); gotResults.set(true); }, ESTestCase::fail); @@ -739,26 +714,21 @@ public void testChunkInfer_e5() { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) + ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) ); - if (gotResults.get() == false) { - terminate(threadpool); - } assertTrue("Listener not called", gotResults.get()); } @SuppressWarnings("unchecked") public void testChunkInfer_Sparse() { var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults()); 
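(Reviewer note.) The rewritten assertions above encode the new contract of these chunked-inference tests: when an input needs no splitting, each input yields exactly one chunk whose matchedText is the original input ("foo", "bar") and whose embedding is the model's result for the whole text. A sketch of that mapping under those assumptions, with hypothetical types rather than the PR's:

```java
import java.util.ArrayList;
import java.util.List;

final class SingleChunkMapping {
    record Chunk(String matchedText, float[] embedding) {}

    // Input i yields one chunk whose matchedText is the original input and
    // whose embedding is the model's whole-text result for that input --
    // the shape asserted by testChunkInfer_e5 and testChunkInfer_Sparse above.
    static List<Chunk> singleChunkPerInput(List<String> inputs, List<float[]> embeddings) {
        List<Chunk> chunks = new ArrayList<>(inputs.size());
        for (int i = 0; i < inputs.size(); i++) {
            chunks.add(new Chunk(inputs.get(i), embeddings.get(i)));
        }
        return chunks;
    }
}
```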
var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - ThreadPool threadpool = new TestThreadPool("test"); Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); + when(client.threadPool()).thenReturn(threadPool); doAnswer(invocationOnMock -> { var listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onResponse(response); @@ -775,16 +745,21 @@ public void testChunkInfer_Sparse() { var gotResults = new AtomicBoolean(); var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); + assertThat(chunkedResponse, hasSize(2)); assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); + assertEquals( + ((TextExpansionResults) mlTrainedModelResults.get(0)).getWeightedTokens(), + result1.getChunkedResults().get(0).weightedTokens() + ); + assertEquals("foo", result1.getChunkedResults().get(0).matchedText()); assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); + assertEquals( + ((TextExpansionResults) mlTrainedModelResults.get(1)).getWeightedTokens(), + result2.getChunkedResults().get(0).weightedTokens() + ); + assertEquals("bar", result2.getChunkedResults().get(0).matchedText()); gotResults.set(true); }, ESTestCase::fail); @@ -796,12 +771,9 @@ public void testChunkInfer_Sparse() { InputType.SEARCH, new ChunkingOptions(null, null), InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) + ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) ); - if (gotResults.get() == false) { - terminate(threadpool); - } assertTrue("Listener not called", gotResults.get()); } @@ -811,57 +783,103 @@ public void testChunkInferSetsTokenization() { var expectedWindowSize = new AtomicReference(); Client client = mock(Client.class); - ThreadPool threadpool = new TestThreadPool("test"); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new MultilingualE5SmallModel( - "foo", - TaskType.TEXT_EMBEDDING, - "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + var request = 
(InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; + assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); + var update = (TokenizationConfigUpdate) request.getUpdate(); + assertEquals(update.getSpanSettings().span(), expectedSpan.get()); + assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); + return null; + }).when(client) + .execute( + same(InferTrainedModelDeploymentAction.INSTANCE), + any(InferTrainedModelDeploymentAction.Request.class), + any(ActionListener.class) ); - var service = createService(client); - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) + ); + var service = createService(client); + + expectedSpan.set(-1); + expectedWindowSize.set(null); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + + expectedSpan.set(-1); + expectedWindowSize.set(256); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(256, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } + } + + @SuppressWarnings("unchecked") + public void testChunkInfer_FailsBatch() { + var mlTrainedModelResults = new ArrayList(); + mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults()); + mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) + ); + var service = createService(client); + + var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { + assertThat(chunkedResponse, hasSize(3)); + // a single failure fails the batch + for (var er : chunkedResponse) { + assertThat(er, instanceOf(ErrorChunkedInferenceResults.class)); + assertEquals("boom", ((ErrorChunkedInferenceResults) er).getException().getMessage()); + } + + gotResults.set(true); + }, ESTestCase::fail); + + service.chunkedInfer( + model, + 
null, + List.of("foo", "bar", "baz"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(resultsListener, () -> terminate(threadPool)) + ); + + assertTrue("Listener not called", gotResults.get()); } public void testParsePersistedConfig_Rerank() { @@ -992,14 +1010,12 @@ public void testBuildInferenceRequest() { var inputs = randomList(1, 3, () -> randomAlphaOfLength(4)); var inputType = randomFrom(InputType.SEARCH, InputType.INGEST); var timeout = randomTimeValue(); - var chunk = randomBoolean(); var request = ElasticsearchInternalService.buildInferenceRequest( id, TextEmbeddingConfigUpdate.EMPTY_INSTANCE, inputs, inputType, - timeout, - chunk + timeout ); assertEquals(id, request.getId()); @@ -1009,7 +1025,7 @@ public void testBuildInferenceRequest() { request.getPrefixType() ); assertEquals(timeout, request.getInferenceTimeout()); - assertEquals(chunk, request.isChunked()); + assertEquals(false, request.isChunked()); } @SuppressWarnings("unchecked") @@ -1132,6 +1148,32 @@ public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() } } + public void testEmbeddingTypeFromTaskTypeAndSettings() { + assertEquals( + EmbeddingRequestChunker.EmbeddingType.SPARSE, + ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( + TaskType.SPARSE_EMBEDDING, + new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + ) + ); + assertEquals( + EmbeddingRequestChunker.EmbeddingType.FLOAT, + ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( + TaskType.TEXT_EMBEDDING, + new MultilingualE5SmallInternalServiceSettings(1, 1, "foo", null) + ) + ); + + var e = expectThrows( + ElasticsearchStatusException.class, + () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( + TaskType.COMPLETION, + new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + ) + ); + assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]")); + } + private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); return new ElasticsearchInternalService(context); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml index e2c1417057578..9a4d7f4416164 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml @@ -87,11 +87,9 @@ setup: - length: { hits.hits: 2 } - match: { hits.hits.0._id: "doc_2" } - - match: { hits.hits.0._rank: 1 } - close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } } - match: { hits.hits.1._id: "doc_1" } - - match: { hits.hits.1._rank: 2 } - close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } } --- @@ -123,7 +121,6 @@ setup: - length: { hits.hits: 1 } - match: { hits.hits.0._id: "doc_1" } - - match: { hits.hits.0._rank: 1 } - close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } } @@ -178,3 +175,40 @@ setup: field: text size: 10 + +--- +"text similarity reranking with explain": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: { + text_similarity_reranker: 
{ + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + size: 10 + explain: true + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } + + - close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } } + - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } + - match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" } diff --git a/x-pack/plugin/rank-rrf/build.gradle b/x-pack/plugin/rank-rrf/build.gradle index 2db33fa0f2c8d..2c3f217243aa4 100644 --- a/x-pack/plugin/rank-rrf/build.gradle +++ b/x-pack/plugin/rank-rrf/build.gradle @@ -20,7 +20,11 @@ dependencies { compileOnly project(path: xpackModule('core')) testImplementation(testArtifact(project(xpackModule('core')))) + testImplementation(testArtifact(project(':server'))) clusterModules project(xpackModule('rank-rrf')) + clusterModules project(xpackModule('inference')) clusterModules project(':modules:lang-painless') + + clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') } diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 2e7bc44811bf6..be64d34dc8765 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -33,7 +33,6 @@ import java.util.Collection; import java.util.List; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -98,7 +97,7 @@ protected void setupIndex() { } } """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term"); indexDoc( @@ -167,8 +166,8 @@ public void testRRFPagination() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -214,8 +213,8 @@ public void testRRFWithAggs() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); 
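(Reviewer note.) setupIndex now creates the index with a randomized shard count instead of the previous single shard with zero replicas. That is also why the YAML tests later in this diff rewrite boosted term queries into constant_score wrappers: a term query's BM25 score depends on shard-local statistics, while constant_score pins the score to its boost, so the expected orderings stay stable however documents land on shards. A sketch of the two forms (hypothetical wrapper class, real QueryBuilders calls):

```java
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;

final class DeterministicScoring {
    // BM25-scored: the final score varies with shard-local term statistics.
    static QueryBuilder statisticsDependent() {
        return QueryBuilders.termQuery("char_val", "A").boost(10.0f);
    }

    // constant_score pins the score to the boost (exactly 10.0 on every shard).
    static QueryBuilder deterministic() {
        return QueryBuilders.constantScoreQuery(QueryBuilders.termQuery("char_val", "A")).boost(10.0f);
    }
}
```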
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -266,8 +265,8 @@ public void testRRFWithCollapse() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -320,8 +319,8 @@ public void testRRFRetrieverWithCollapseAndAggs() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -383,8 +382,8 @@ public void testMultipleRRFRetrievers() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -446,8 +445,8 @@ public void testRRFExplainWithNamedRetrievers() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); source.retriever( new RRFRetrieverBuilder( Arrays.asList( @@ -474,13 +473,12 @@ public void testRRFExplainWithNamedRetrievers() { assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; assertThat(rrfDetails.getDetails().length, equalTo(3)); - 
assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 2]")); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 1]")); - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [1] in query at index [1]")); - assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [2] in query at index [2]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [1] in query at index [2]")); }); } @@ -503,8 +501,8 @@ public void testRRFExplainWithAnotherNestedRRF() { QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2", "doc_3", "doc_6")).boost(20L) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 3, 2, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 4.0f }, null, 10, 100, null); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null); RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder( Arrays.asList( diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java index 512874e5009f3..ea251917cfae2 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderNestedDocsIT.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; @@ -21,8 +22,9 @@ import java.util.Arrays; -import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; public class RRFRetrieverBuilderNestedDocsIT extends RRFRetrieverBuilderIT { @@ -68,7 +70,7 @@ protected void setupIndex() { } } """; - createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, 1).put(SETTING_NUMBER_OF_REPLICAS, 0).build()); + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term", LAST_30D_FIELD, 100); indexDoc( @@ -134,9 +136,9 @@ public void testRRFRetrieverWithNestedQuery() { final int 
rankWindowSize = 100; final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); - // this one retrieves docs 1, 4 + // this one retrieves docs 1 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( - QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gte(30L), ScoreMode.Avg) + QueryBuilders.nestedQuery("views", QueryBuilders.rangeQuery(LAST_30D_FIELD).gte(50L), ScoreMode.Avg) ); // this one retrieves docs 2 and 6 due to prefilter StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( @@ -157,16 +159,21 @@ public void testRRFRetrieverWithNestedQuery() { ) ); source.fetchField(TOPIC_FIELD); + source.explain(true); SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); ElasticsearchAssertions.assertResponse(req, resp -> { assertNull(resp.pointInTimeId()); assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value, equalTo(4L)); + assertThat(resp.getHits().getTotalHits().value, equalTo(3L)); assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_4")); + assertThat((double) resp.getHits().getAt(0).getScore(), closeTo(0.1742, 1e-4)); + assertThat( + Arrays.stream(resp.getHits().getHits()).skip(1).map(SearchHit::getId).toList(), + containsInAnyOrder("doc_1", "doc_2") + ); + assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(0.0909, 1e-4)); + assertThat((double) resp.getHits().getAt(2).getScore(), closeTo(0.0909, 1e-4)); }); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java index 500ed17395127..272df248e53e9 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.rank.rrf; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -169,4 +170,9 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc builder.field("scores", scores); builder.field("rankConstant", rankConstant); } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.RRF_QUERY_REWRITE; + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 496af99574431..5f19e361d857d 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -180,10 +180,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.startArray(RETRIEVERS_FIELD.getPreferredName()); for (var entry : innerRetrievers) { - builder.startObject(); - builder.field(entry.retriever().getName()); entry.retriever().toXContent(builder, 
params); - builder.endObject(); } builder.endArray(); } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java index 4b64b6c173c92..5548392270a08 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java @@ -7,15 +7,17 @@ package org.elasticsearch.xpack.rank.rrf; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable.Reader; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.List; import static org.elasticsearch.xpack.rank.rrf.RRFRankDoc.NO_RANK; -public class RRFRankDocTests extends AbstractWireSerializingTestCase { +public class RRFRankDocTests extends AbstractRankDocWireSerializingTestCase { static RRFRankDoc createTestRRFRankDoc(int queryCount) { RRFRankDoc instance = new RRFRankDoc( @@ -35,9 +37,13 @@ static RRFRankDoc createTestRRFRankDoc(int queryCount) { return instance; } - static RRFRankDoc createTestRRFRankDoc() { - int queryCount = randomIntBetween(2, 20); - return createTestRRFRankDoc(queryCount); + @Override + protected List getAdditionalNamedWriteables() { + try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) { + return rrfRankPlugin.getNamedWriteables(); + } catch (IOException ex) { + throw new AssertionError("Failed to create RRFRankPlugin", ex); + } } @Override @@ -46,8 +52,9 @@ protected Reader instanceReader() { } @Override - protected RRFRankDoc createTestInstance() { - return createTestRRFRankDoc(); + protected RRFRankDoc createTestRankDoc() { + int queryCount = randomIntBetween(2, 20); + return createTestRRFRankDoc(queryCount); } @Override diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index e360237371a82..d324effe41c22 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -8,19 +8,27 @@ package org.elasticsearch.xpack.rank.rrf; import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.usage.SearchUsageHolder; +import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + public class RRFRetrieverBuilderParsingTests 
extends AbstractXContentTestCase { /** @@ -53,7 +61,10 @@ protected RRFRetrieverBuilder createTestInstance() { @Override protected RRFRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { - return RRFRetrieverBuilder.PARSER.apply(parser, new RetrieverParserContext(new SearchUsage(), nf -> true)); + return (RRFRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), nf -> true) + ); } @Override @@ -81,4 +92,48 @@ protected NamedXContentRegistry xContentRegistry() { ); return new NamedXContentRegistry(entries); } + + public void testRRFRetrieverParsing() throws IOException { + String restContent = "{" + + " \"retriever\": {" + + " \"rrf\": {" + + " \"retrievers\": [" + + " {" + + " \"test\": {" + + " \"value\": \"foo\"" + + " }" + + " }," + + " {" + + " \"test\": {" + + " \"value\": \"bar\"" + + " }" + + " }" + + " ]," + + " \"rank_window_size\": 100," + + " \"rank_constant\": 10," + + " \"min_score\": 20.0," + + " \"_name\": \"foo_rrf\"" + + " }" + + " }" + + "}"; + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(20f)); + assertThat(parsed.retrieverName(), equalTo("foo_rrf")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java index 3a577eb62faa3..32b5aedd5d99a 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java @@ -23,7 +23,9 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { .nodes(2) .module("rank-rrf") .module("lang-painless") + .module("x-pack-inference") .setting("xpack.license.self_generated.type", "trial") + .plugin("inference-service-test") .build(); public RRFRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml index 47ba3658bb38d..d5d7a5de1dc71 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/350_rrf_retriever_pagination.yml @@ -1,6 +1,8 @@ setup: - skip: - features: close_to + features: + - 
close_to + - contains - requires: cluster_features: 'rrf_retriever_composition_supported' @@ -10,8 +12,6 @@ setup: indices.create: index: test body: - settings: - number_of_shards: 1 mappings: properties: number_val: @@ -81,35 +81,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -124,35 +138,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -198,35 +228,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -241,35 +285,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -306,35 +366,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, 
+ boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -349,35 +423,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "A", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "B", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "C", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "D", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 7.0 } } ] @@ -422,35 +512,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -465,35 +569,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -533,35 +653,49 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "1", - boost: 10.0 - } - } - }, - { - term: { - number_val: { - value: "2", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 10.0 } - }, - { - term: { - number_val: { - value: "3", - boost: 8.0 - } + },{ + constant_score: { + filter: { + term: { + number_val: { + value: "2" + } + } + }, + boost: 9.0 + } }, + { + constant_score: { + filter: { + term: { + number_val: { + value: "3" + } + } + }, + boost: 8.0 } }, { - term: { - number_val: { - value: "4", - boost: 7.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "4" + } + } + }, + boost: 7.0 } } ] @@ -576,35 +710,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } 
+ constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -632,9 +782,9 @@ setup: "Pagination within interleaved results, different result set sizes, rank_window_size covering all results": # perform multiple searches with different "from" parameter, ensuring that results are consistent # rank_window_size covers the entire result set for both queries, so pagination should be consistent - # queryA has a result set of [5, 1] and + # queryA has a result set of [1] and # queryB has a result set of [4, 3, 1, 2] - # so for rank_constant=10, the expected order is [1, 4, 5, 3, 2] + # so for rank_constant=10, the expected order is [1, 4, 3, 2] - do: search: index: test @@ -645,19 +795,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -678,35 +820,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -721,11 +879,11 @@ setup: from : 0 size : 2 - - match: { hits.total.value : 5 } + - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - match: { hits.hits.0._id: "1" } # score for doc 1 is (1/12 + 1/13) - - close_to: {hits.hits.0._score: {value: 0.1602, error: 0.001}} + - close_to: {hits.hits.0._score: {value: 0.1678, error: 0.001}} - match: { hits.hits.1._id: "4" } # score for doc 4 is (1/11) - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} @@ -740,19 +898,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -773,35 +923,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -816,14 +982,15 @@ setup: from : 2 size : 2 - - match: { hits.total.value : 5 } + - match: { hits.total.value : 4 } - length: { hits.hits : 2 } - - match: { hits.hits.0._id: "5" } - # score for doc 5 is (1/11) - - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} - - match: { hits.hits.1._id: "3" } + - match: { hits.hits.0._id: "3" } # score for doc 3 is (1/12) - - close_to: {hits.hits.1._score: 
{value: 0.0833, error: 0.001}} + - close_to: {hits.hits.0._score: {value: 0.0833, error: 0.001}} + - match: { hits.hits.1._id: "2" } + # score for doc 2 is (1/14) + - close_to: {hits.hits.1._score: {value: 0.0714, error: 0.001}} + - do: search: @@ -835,19 +1002,11 @@ setup: { retrievers: [ { - # this should clause would generate the result set [5, 1] + # this should clause would generate the result set [1] standard: { query: { bool: { should: [ - { - term: { - number_val: { - value: "5", - boost: 10.0 - } - } - }, { term: { number_val: { @@ -868,35 +1027,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -911,12 +1086,8 @@ setup: from: 4 size: 2 - - match: { hits.total.value: 5 } - - length: { hits.hits: 1 } - - match: { hits.hits.0._id: "2" } - # score for doc 2 is (1/14) - - close_to: {hits.hits.0._score: {value: 0.0714, error: 0.001}} - + - match: { hits.total.value: 4 } + - length: { hits.hits: 0 } --- "Pagination within interleaved results, different result set sizes, rank_window_size not covering all results": @@ -943,19 +1114,27 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "5", - boost: 10.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "5" + } + } + }, + boost: 10.0 } }, { - term: { - number_val: { - value: "1", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 9.0 } } ] @@ -970,35 +1149,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] @@ -1015,11 +1210,11 @@ setup: - match: { hits.total.value : 5 } - length: { hits.hits : 2 } - - match: { hits.hits.0._id: "4" } - # score for doc 4 is (1/11) + - contains: { hits.hits: { _id: "4" } } + - contains: { hits.hits: { _id: "5" } } + + # both docs have the same score (1/11) - close_to: {hits.hits.0._score: {value: 0.0909, error: 0.001}} - - match: { hits.hits.1._id: "5" } - # score for doc 5 is (1/11) - close_to: {hits.hits.1._score: {value: 0.0909, error: 0.001}} - do: @@ -1038,19 +1233,27 @@ setup: bool: { should: [ { - term: { - number_val: { - value: "5", - boost: 10.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "5" + } + } + }, + boost: 10.0 } }, { - term: { - number_val: { - value: "1", - boost: 9.0 - } + constant_score: { + filter: { + term: { + number_val: { + value: "1" + } + } + }, + boost: 9.0 } 
} ] @@ -1065,35 +1268,51 @@ setup: bool: { should: [ { - term: { - char_val: { - value: "D", - boost: 10.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "D" + } + } + }, + boost: 10.0 } }, { - term: { - char_val: { - value: "C", - boost: 9.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "C" + } + } + }, + boost: 9.0 } }, { - term: { - char_val: { - value: "A", - boost: 8.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "A" + } + } + }, + boost: 8.0 } }, { - term: { - char_val: { - value: "B", - boost: 7.0 - } + constant_score: { + filter: { + term: { + char_val: { + value: "B" + } + } + }, + boost: 7.0 } } ] diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml index 1f7125377b892..517c162c33e95 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml @@ -1,7 +1,6 @@ setup: - skip: features: close_to - - requires: cluster_features: 'rrf_retriever_composition_supported' reason: 'test requires rrf retriever composition support' @@ -10,8 +9,6 @@ setup: indices.create: index: test body: - settings: - number_of_shards: 1 mappings: properties: text: @@ -42,7 +39,7 @@ setup: index: test id: "1" body: - text: "term term term term term term term term term" + text: "term1" vector: [1.0] - do: @@ -50,7 +47,7 @@ setup: index: test id: "2" body: - text: "term term term term term term term term" + text: "term2" text_to_highlight: "search for the truth" keyword: "biology" vector: [2.0] @@ -60,8 +57,8 @@ setup: index: test id: "3" body: - text: "term term term term term term term" - text_to_highlight: "nothing related but still a match" + text: "term3" + text_to_highlight: "nothing related" keyword: "technology" vector: [3.0] @@ -70,14 +67,14 @@ setup: index: test id: "4" body: - text: "term term term term term term" + text: "term4" vector: [4.0] - do: index: index: test id: "5" body: - text: "term term term term term" + text: "term5" text_to_highlight: "You know, for Search!" 
keyword: "technology" integer: 5 @@ -87,7 +84,7 @@ setup: index: test id: "6" body: - text: "term term term term" + text: "term6" keyword: "biology" integer: 6 vector: [6.0] @@ -96,27 +93,26 @@ setup: index: test id: "7" body: - text: "term term term" + text: "term7" keyword: "astronomy" - vector: [7.0] + vector: [77.0] nested: { views: 50} - do: index: index: test id: "8" body: - text: "term term" + text: "term8" keyword: "technology" - vector: [8.0] nested: { views: 100} - do: index: index: test id: "9" body: - text: "term" + text: "term9" + integer: 2 keyword: "technology" - vector: [9.0] nested: { views: 10} - do: indices.refresh: {} @@ -133,6 +129,7 @@ setup: rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -141,10 +138,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -158,9 +217,13 @@ setup: terms: field: keyword - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - match: { aggregations.keyword_aggs.buckets.0.key: "technology" } - match: { aggregations.keyword_aggs.buckets.0.doc_count: 4 } @@ -181,6 +244,7 @@ setup: rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -189,10 +253,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -208,12 +334,14 @@ setup: lang: painless source: "_score" - - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - - close_to: { aggregations.max_score.value: { value: 0.15, error: 0.001 }} + - close_to: { aggregations.max_score.value: { value: 0.1678, error: 0.001 }} --- "rrf retriever with top-level collapse": @@ -228,6 +356,7 @@ setup: 
rrf: retrievers: [ { + # this one retrieves docs 6, 5, 4 knn: { field: vector, query_vector: [ 6.0 ], @@ -236,10 +365,72 @@ setup: } }, { + # this one retrieves docs 4, 5, 1, 2, 6 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + exists: { + field: text + } + }, + boost: 1 + } + } + ] } } } @@ -250,18 +441,23 @@ setup: size: 3 collapse: { field: keyword, inner_hits: { name: sub_hits, size: 2 } } - - match: { hits.hits.0._id: "5" } - - match: { hits.hits.1._id: "1" } + - match: { hits.total : 9 } + + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 0.1678, error: 0.001 } } + - match: { hits.hits.1._id: "5" } + - close_to: { hits.hits.1._score: { value: 0.1666, error: 0.001 } } - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.2._score: { value: 0.1575, error: 0.001 } } - - match: { hits.hits.0.inner_hits.sub_hits.hits.total : 4 } - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "5" } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "3" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "4" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "1" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.total : 4 } - length: { hits.hits.1.inner_hits.sub_hits.hits.hits : 2 } - - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.0._id: "1" } - - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.1._id: "4" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.0._id: "5" } + - match: { hits.hits.1.inner_hits.sub_hits.hits.hits.1._id: "3" } - length: { hits.hits.2.inner_hits.sub_hits.hits.hits: 2 } - match: { hits.hits.2.inner_hits.sub_hits.hits.hits.0._id: "6" } @@ -280,18 +476,132 @@ setup: rrf: retrievers: [ { - knn: { - field: vector, - query_vector: [ 6.0 ], - k: 3, - num_candidates: 10 + # this one retrieves docs 7, 3 + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } }, + boost: 9.0 + } + } + ] + } + } } }, { + # this one retrieves docs 1, 2, 3, 7 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 4.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term8 + } + }, + boost: 3.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term9 + } + }, 
+ boost: 2.0 + } + } + ] } }, collapse: { field: keyword, inner_hits: { name: sub_hits, size: 1 } } @@ -303,8 +613,9 @@ setup: size: 3 - match: { hits.hits.0._id: "7" } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.2._id: "6" } + - close_to: { hits.hits.0._score: { value: 0.1623, error: 0.001 } } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 0.1602, error: 0.001 } } --- "rrf retriever highlighting results": @@ -331,7 +642,7 @@ setup: standard: { query: { term: { - keyword: technology + text: term5 } } } @@ -349,7 +660,7 @@ setup: } } - - match: { hits.total : 5 } + - match: { hits.total : 2 } - match: { hits.hits.0._id: "5" } - match: { hits.hits.0.highlight.text_to_highlight.0: "You know, for Search!" } @@ -357,9 +668,6 @@ setup: - match: { hits.hits.1._id: "2" } - match: { hits.hits.1.highlight.text_to_highlight.0: "search for the truth" } - - match: { hits.hits.2._id: "3" } - - not_exists: hits.hits.2.highlight - --- "rrf retriever with custom nested sort": @@ -374,12 +682,103 @@ setup: retrievers: [ { # this one retrievers docs 1, 2, 3, .., 9 - # but due to sorting, it will revert the order to 6, 5, .., 9 which due to + # but due to sorting, it will revert the order to 6, 5, 9, ... which due to # rank_window_size: 2 will only return 6 and 5 standard: { query: { - term: { - text: term + bool: { + should: [ + { + constant_score: { + filter: { + term: { + text: term1 + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term2 + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term3 + } + }, + boost: 8.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term4 + } + }, + boost: 7.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term5 + } + }, + boost: 6.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term6 + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term7 + } + }, + boost: 4.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term8 + } + }, + boost: 3.0 + } + }, + { + constant_score: { + filter: { + term: { + text: term9 + } + }, + boost: 2.0 + } + } + ] } }, sort: [ @@ -410,7 +809,6 @@ setup: - length: {hits.hits: 2 } - match: { hits.hits.0._id: "6" } - - match: { hits.hits.1._id: "2" } --- "rrf retriever with nested query": @@ -427,7 +825,7 @@ setup: { knn: { field: vector, - query_vector: [ 7.0 ], + query_vector: [ 77.0 ], k: 1, num_candidates: 3 } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml new file mode 100644 index 0000000000000..3e758ae11f7e6 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml @@ -0,0 +1,334 @@ +setup: + - requires: + cluster_features: ['rrf_retriever_composition_supported', 'text_similarity_reranker_retriever_supported'] + reason: need to have support for rrf and semantic reranking composition + test_runner_features: "close_to" + + - do: + inference.put: + task_type: rerank + inference_id: my-rerank-model + body: > + { + "service": "test_reranking_service", + "service_settings": { + "model_id": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + + - do: + indices.create: + index: test-index + body: + settings: + number_of_shards: 
1 + mappings: + properties: + text: + type: text + topic: + type: keyword + subtopic: + type: keyword + integer: + type: integer + + - do: + index: + index: test-index + id: doc_1 + body: + text: "Sun Moon Lake is a lake in Nantou County, Taiwan. It is the largest lake in Taiwan." + topic: [ "geography" ] + integer: 1 + + - do: + index: + index: test-index + id: doc_2 + body: + text: "The phases of the Moon come from the position of the Moon relative to the Earth and Sun." + topic: [ "science" ] + subtopic: [ "astronomy" ] + integer: 2 + + - do: + index: + index: test-index + id: doc_3 + body: + text: "As seen from Earth, a solar eclipse happens when the Moon is directly between the Earth and the Sun." + topic: [ "science" ] + subtopic: [ "technology" ] + integer: 3 + + - do: + indices.refresh: {} + +--- +"rrf retriever with a nested text similarity reranker": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 2 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + size: 10 + from: 1 + aggs: + topics: + terms: + field: topic + size: 10 + + - match: { hits.total.value: 3 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_3" } + + - match: { aggregations.topics.buckets.0.key: "science" } + - match: { aggregations.topics.buckets.0.doc_count: 2 } + - match: { aggregations.topics.buckets.1.key: "geography" } + - match: { aggregations.topics.buckets.1.doc_count: 1 } + +--- +"Text similarity reranker on top of an RRF retriever": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + { + text_similarity_reranker: { + retriever: + { + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 3 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + standard: { + query: { + term: { + topic: "geography" + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + size: 10 + aggs: + topics: + terms: + field: topic + size: 10 + + - match: { hits.total.value: 2 } + - length: { hits.hits: 2 } + + - match: { hits.hits.0._id: "doc_3" } + - match: { hits.hits.1._id: "doc_1" } + + - match: { aggregations.topics.buckets.0.key: "geography" } + - match: { aggregations.topics.buckets.0.doc_count: 1 } + - match: { aggregations.topics.buckets.1.key: "science" } + - match: { aggregations.topics.buckets.1.doc_count: 1 } + + +--- +"explain using rrf retriever and text-similarity": + + - do: + search: + index: test-index + body: + track_total_hits: true + fields: [ "text", "topic" ] + retriever: + rrf: { + retrievers: + [ + { + standard: { + query: { + bool: { + should: + [ + { + 
constant_score: { + filter: { + term: { + integer: 1 + } + }, + boost: 10 + } + }, + { + constant_score: + { + filter: + { + term: + { + integer: 2 + } + }, + boost: 1 + } + } + ] + } + } + } + }, + { + text_similarity_reranker: { + retriever: + { + standard: { + query: { + term: { + topic: "science" + } + } + } + }, + rank_window_size: 10, + inference_id: my-rerank-model, + inference_text: "How often does the moon hide the sun?", + field: text + } + } + ], + rank_window_size: 10, + rank_constant: 1 + } + size: 10 + explain: true + + - match: { hits.hits.0._id: "doc_2" } + - match: { hits.hits.1._id: "doc_1" } + - match: { hits.hits.2._id: "doc_3" } + + - close_to: { hits.hits.0._explanation.value: { value: 0.6666667, error: 0.000001 } } + - match: {hits.hits.0._explanation.description: "/rrf.score:.\\[0.6666667\\].*/" } + - match: {hits.hits.0._explanation.details.0.value: 2} + - match: {hits.hits.0._explanation.details.0.description: "/rrf.score:.\\[0.33333334\\].*/" } + - match: {hits.hits.0._explanation.details.0.details.0.details.0.description: "/ConstantScore.*/" } + - match: {hits.hits.0._explanation.details.1.value: 2} + - match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.33333334\\].*/" } + - match: {hits.hits.0._explanation.details.1.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" } + - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/weight.*science.*/" }
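
Note on the expected scores asserted throughout these tests: they follow the standard reciprocal-rank-fusion formula, where each retriever contributes 1 / (rank_constant + rank) for every document it returns (ranks counted from 1), and a document's final score is the sum of its contributions across retrievers. Below is a minimal sketch verifying the asserted values. It assumes the compatibility tests run with rank_constant = 10, which is inferred from the expected scores rather than visible in the hunks shown, while the explain test sets rank_constant: 1 explicitly; the rrf_score helper is illustrative and not part of the test suite.

def rrf_score(ranks, rank_constant):
    # ranks: this document's 1-based rank in each retriever's result list,
    # or None when that retriever did not return the document at all
    return sum(1.0 / (rank_constant + r) for r in ranks if r is not None)

# 700_rrf_retriever_search_api_compatibility.yml: per the inline comments,
# the knn retriever returns [6, 5, 4] and the boosted constant_score
# retriever returns [4, 5, 1, 2, 6]
assert abs(rrf_score([3, 1], 10) - 0.1678) < 0.001  # doc 4: 1/13 + 1/11
assert abs(rrf_score([2, 2], 10) - 0.1666) < 0.001  # doc 5: 1/12 + 1/12
assert abs(rrf_score([1, 5], 10) - 0.1575) < 0.001  # doc 6: 1/11 + 1/15

# pagination tests: a document returned by only one retriever, at rank 2
# or rank 4, scores 1/12 or 1/14, matching the "# score for doc ..." comments
assert abs(rrf_score([2], 10) - 0.0833) < 0.001
assert abs(rrf_score([4], 10) - 0.0714) < 0.001

# 800_rrf_with_text_similarity_reranker_retriever.yml explain test, with
# rank_constant: 1 - doc_2 ranks 2nd in both retrievers, so its explained
# score is 1/3 + 1/3, and each per-retriever detail shows 0.33333334
assert abs(rrf_score([2, 2], 1) - 0.6666667) < 1e-6

Running this as a plain Python script should exit silently, confirming that every close_to assertion in these hunks is consistent with the summed 1 / (rank_constant + rank) form under the stated rank_constant assumptions.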