diff --git a/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/GitInfoPlugin.java b/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/GitInfoPlugin.java index 8a1e3cabf890e..28b90714508bd 100644 --- a/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/GitInfoPlugin.java +++ b/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/GitInfoPlugin.java @@ -44,7 +44,7 @@ public void apply(Project project) { gitInfo.disallowChanges(); gitInfo.finalizeValueOnRead(); - revision = gitInfo.map(info -> info.getRevision() == null ? info.getRevision() : "master"); + revision = gitInfo.map(info -> info.getRevision() == null ? info.getRevision() : "main"); } public Property getGitInfo() { diff --git a/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/LicensingPlugin.java b/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/LicensingPlugin.java index 92ce2a3658a0e..ba170d083c886 100644 --- a/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/LicensingPlugin.java +++ b/build-conventions/src/main/java/org/elasticsearch/gradle/internal/conventions/LicensingPlugin.java @@ -21,6 +21,7 @@ public class LicensingPlugin implements Plugin { static final String ELASTIC_LICENSE_URL_PREFIX = "https://raw.githubusercontent.com/elastic/elasticsearch/"; static final String ELASTIC_LICENSE_URL_POSTFIX = "/licenses/ELASTIC-LICENSE-2.0.txt"; + static final String AGPL_ELASTIC_LICENSE_URL_POSTFIX = "/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt"; private ProviderFactory providerFactory; @@ -36,15 +37,18 @@ public void apply(Project project) { isSnapshotVersion(project) ? 
revision.get() : "v" + project.getVersion() ); - Provider projectLicenseURL = licenseCommitProvider.map(licenseCommit -> ELASTIC_LICENSE_URL_PREFIX + + Provider elasticLicenseURL = licenseCommitProvider.map(licenseCommit -> ELASTIC_LICENSE_URL_PREFIX + licenseCommit + ELASTIC_LICENSE_URL_POSTFIX); + Provider agplLicenseURL = licenseCommitProvider.map(licenseCommit -> ELASTIC_LICENSE_URL_PREFIX + + licenseCommit + AGPL_ELASTIC_LICENSE_URL_POSTFIX); // But stick the Elastic license url in project.ext so we can get it if we need to switch to it - project.getExtensions().getExtraProperties().set("elasticLicenseUrl", projectLicenseURL); + project.getExtensions().getExtraProperties().set("elasticLicenseUrl", elasticLicenseURL); MapProperty licensesProperty = project.getObjects().mapProperty(String.class, String.class).convention( providerFactory.provider(() -> Map.of( "Server Side Public License, v 1", "https://www.mongodb.com/licensing/server-side-public-license", - "Elastic License 2.0", projectLicenseURL.get()) + "Elastic License 2.0", elasticLicenseURL.get(), + "GNU Affero General Public License Version 3", agplLicenseURL.get()) ) ); diff --git a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy index e275a56682c01..c0b85ed7450f6 100644 --- a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy +++ b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/PublishPluginFuncTest.groovy @@ -74,6 +74,11 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { https://www.mongodb.com/licensing/server-side-public-license repo + + The OSI-approved Open Source license Version 3.0 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + repo + @@ -149,6 +154,11 @@ class 
PublishPluginFuncTest extends AbstractGradleFuncTest { https://www.mongodb.com/licensing/server-side-public-license repo + + The OSI-approved Open Source license Version 3.0 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + repo + @@ -233,6 +243,11 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { https://www.mongodb.com/licensing/server-side-public-license repo + + The OSI-approved Open Source license Version 3.0 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + repo + @@ -326,6 +341,11 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { https://www.mongodb.com/licensing/server-side-public-license repo + + The OSI-approved Open Source license Version 3.0 + https://raw.githubusercontent.com/elastic/elasticsearch/v1.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + repo + @@ -399,6 +419,11 @@ class PublishPluginFuncTest extends AbstractGradleFuncTest { https://www.mongodb.com/licensing/server-side-public-license repo + + The OSI-approved Open Source license Version 3.0 + https://raw.githubusercontent.com/elastic/elasticsearch/v2.0/licenses/AGPL-3.0+SSPL-1.0+ELASTIC-LICENSE-2.0.txt + repo + diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/DockerBase.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/DockerBase.java index ac83a01ffc294..95f279bfa5162 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/DockerBase.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/DockerBase.java @@ -31,7 +31,7 @@ public enum DockerBase { // Chainguard based wolfi image with latest jdk // This is usually updated via renovatebot // spotless:off - WOLFI("docker.elastic.co/wolfi/chainguard-base:latest@sha256:c16d3ad6cebf387e8dd2ad769f54320c4819fbbaa21e729fad087c7ae223b4d0", + 
WOLFI("docker.elastic.co/wolfi/chainguard-base:latest@sha256:90888b190da54062f67f3fef1372eb0ae7d81ea55f5a1f56d748b13e4853d984", "-wolfi", "apk" ), diff --git a/distribution/docker/src/docker/Dockerfile.ess b/distribution/docker/src/docker/Dockerfile.ess index 4a230bb562786..3ca5e8f2b42a3 100644 --- a/distribution/docker/src/docker/Dockerfile.ess +++ b/distribution/docker/src/docker/Dockerfile.ess @@ -25,7 +25,7 @@ USER root COPY plugins/*.zip /opt/plugins/archive/ -RUN chown root.root /opt/plugins/archive/* +RUN chown 1000:1000 /opt/plugins/archive/* RUN chmod 0444 /opt/plugins/archive/* FROM ${base_image} diff --git a/docs/changelog/114109.yaml b/docs/changelog/114109.yaml new file mode 100644 index 0000000000000..ce51ed50f724c --- /dev/null +++ b/docs/changelog/114109.yaml @@ -0,0 +1,5 @@ +pr: 114109 +summary: Update cluster stats for retrievers +area: Search +type: enhancement +issues: [] diff --git a/docs/changelog/114358.yaml b/docs/changelog/114358.yaml new file mode 100644 index 0000000000000..972bc5bfdbe1c --- /dev/null +++ b/docs/changelog/114358.yaml @@ -0,0 +1,5 @@ +pr: 114358 +summary: "ESQL: Use less memory in listener" +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/114363.yaml b/docs/changelog/114363.yaml new file mode 100644 index 0000000000000..51ca9ed34a7ca --- /dev/null +++ b/docs/changelog/114363.yaml @@ -0,0 +1,5 @@ +pr: 114363 +summary: Give the kibana system user permission to read security entities +area: Infra/Core +type: enhancement +issues: [] diff --git a/docs/changelog/114375.yaml b/docs/changelog/114375.yaml new file mode 100644 index 0000000000000..7ff7cc60b34ba --- /dev/null +++ b/docs/changelog/114375.yaml @@ -0,0 +1,5 @@ +pr: 114375 +summary: Handle `InternalSendException` inline for non-forking handlers +area: Distributed +type: bug +issues: [] diff --git a/docs/reference/cluster/stats.asciidoc b/docs/reference/cluster/stats.asciidoc index 5dd84abc96e1f..bd818a538f78b 100644 --- 
a/docs/reference/cluster/stats.asciidoc +++ b/docs/reference/cluster/stats.asciidoc @@ -762,6 +762,10 @@ Queries are counted once per search request, meaning that if the same query type (object) Search sections used in selected nodes. For each section, name and number of times it's been used is reported. +`retrievers`:: +(object) Retriever types that were used in selected nodes. +For each retriever, name and number of times it's been used is reported. + ===== `dense_vector`:: diff --git a/docs/reference/ingest/processors.asciidoc b/docs/reference/ingest/processors.asciidoc index 8622e0b98602c..8f7cef06d12a0 100644 --- a/docs/reference/ingest/processors.asciidoc +++ b/docs/reference/ingest/processors.asciidoc @@ -185,6 +185,9 @@ Executes another pipeline. <>:: Reroutes documents to another target index or data stream. +<>:: +Terminates the current ingest pipeline, causing no further processors to be run. + [discrete] [[ingest-process-category-array-json-handling]] === Array/JSON handling processors @@ -258,6 +261,7 @@ include::processors/set.asciidoc[] include::processors/set-security-user.asciidoc[] include::processors/sort.asciidoc[] include::processors/split.asciidoc[] +include::processors/terminate.asciidoc[] include::processors/trim.asciidoc[] include::processors/uppercase.asciidoc[] include::processors/url-decode.asciidoc[] diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataIndexTemplateServiceTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataIndexTemplateServiceTests.java index 199611d6b85ef..29e49c8ddfa17 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataIndexTemplateServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataIndexTemplateServiceTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.ComponentTemplate; import 
org.elasticsearch.cluster.metadata.ComposableIndexTemplate; -import org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetentionSettings; import org.elasticsearch.cluster.metadata.DataStreamLifecycle; import org.elasticsearch.cluster.metadata.MetadataCreateIndexService; @@ -217,10 +216,7 @@ private MetadataIndexTemplateService getMetadataIndexTemplateService() { xContentRegistry(), EmptySystemIndices.INSTANCE, indexSettingProviders, - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings()) ); } diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsActionTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsActionTests.java index a9ebd04d30f73..2efe881266c1b 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsActionTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/TransportGetDataStreamsActionTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.DataStream; -import org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetentionSettings; import org.elasticsearch.cluster.metadata.DataStreamTestHelper; @@ -47,8 +46,7 @@ public class TransportGetDataStreamsActionTests extends ESTestCase { private final IndexNameExpressionResolver resolver = TestIndexNameExpressionResolver.newInstance(); private final SystemIndices systemIndices = new SystemIndices(List.of()); private final 
DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings = DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() + ClusterSettings.createBuiltInClusterSettings() ); public void testGetDataStream() { @@ -356,8 +354,7 @@ public void testPassingGlobalRetention() { ) .put(DataStreamGlobalRetentionSettings.DATA_STREAMS_MAX_RETENTION_SETTING.getKey(), globalRetention.maxRetention()) .build() - ), - DataStreamFactoryRetention.emptyFactoryRetention() + ) ); response = TransportGetDataStreamsAction.innerOperation( state, diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java index 05128e164e865..0d5ce54c44b56 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java @@ -37,7 +37,6 @@ import org.elasticsearch.cluster.block.ClusterBlock; import org.elasticsearch.cluster.block.ClusterBlocks; import org.elasticsearch.cluster.metadata.DataStream; -import org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetentionSettings; import org.elasticsearch.cluster.metadata.DataStreamLifecycle; import org.elasticsearch.cluster.metadata.DataStreamLifecycle.Downsampling; @@ -142,8 +141,7 @@ public class DataStreamLifecycleServiceTests extends ESTestCase { private DoExecuteDelegate clientDelegate; private ClusterService clusterService; private final DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() + 
ClusterSettings.createBuiltInClusterSettings() ); @Before diff --git a/muted-tests.yml b/muted-tests.yml index 68f1eea5c577e..67ce93f8703c4 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -351,27 +351,28 @@ tests: - class: org.elasticsearch.xpack.inference.services.cohere.CohereServiceTests method: testInfer_StreamRequest issue: https://github.com/elastic/elasticsearch/issues/114385 -- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT - method: test {p0=synonyms/60_synonym_rule_get/Synonym set not found} - issue: https://github.com/elastic/elasticsearch/issues/114432 -- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT - method: test {p0=synonyms/60_synonym_rule_get/Get a synonym rule} - issue: https://github.com/elastic/elasticsearch/issues/114443 -- class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT - method: test {p0=synonyms/60_synonym_rule_get/Synonym rule not found} - issue: https://github.com/elastic/elasticsearch/issues/114444 - class: org.elasticsearch.xpack.inference.InferenceRestIT method: test {p0=inference/30_semantic_text_inference/Calculates embeddings using the default ELSER 2 endpoint} issue: https://github.com/elastic/elasticsearch/issues/114412 - class: org.elasticsearch.xpack.inference.InferenceRestIT method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint} issue: https://github.com/elastic/elasticsearch/issues/114376 -- class: org.elasticsearch.search.retriever.StandardRetrieverBuilderParsingTests - method: testRewrite - issue: https://github.com/elastic/elasticsearch/issues/114466 - class: org.elasticsearch.search.retriever.RankDocsRetrieverBuilderTests method: testRewrite issue: https://github.com/elastic/elasticsearch/issues/114467 +- class: org.elasticsearch.xpack.logsdb.LogsdbTestSuiteIT + issue: https://github.com/elastic/elasticsearch/issues/114471 +- class: org.elasticsearch.packaging.test.DockerTests + method: 
test022InstallPluginsFromLocalArchive + issue: https://github.com/elastic/elasticsearch/issues/111063 +- class: org.elasticsearch.smoketest.DocsClientYamlTestSuiteIT + method: test {yaml=reference/esql/esql-across-clusters/line_196} + issue: https://github.com/elastic/elasticsearch/issues/114488 +- class: org.elasticsearch.gradle.internal.PublishPluginFuncTest + issue: https://github.com/elastic/elasticsearch/issues/114492 +- class: org.elasticsearch.xpack.inference.DefaultElserIT + method: testInferCreatesDefaultElser + issue: https://github.com/elastic/elasticsearch/issues/114503 # Examples: # diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/60_synonym_rule_get.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/60_synonym_rule_get.yml index 0a4a32448666e..2a7c8aff89d8e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/60_synonym_rule_get.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/synonyms/60_synonym_rule_get.yml @@ -15,7 +15,8 @@ setup: id: "test-id-3" - do: cluster.health: - index: .synonyms-2 + index: .synonyms + timeout: 1m wait_for_status: green --- diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java new file mode 100644 index 0000000000000..537ace30e88f0 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverTelemetryIT.java @@ -0,0 +1,151 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest; +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse; +import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) +public class RetrieverTelemetryIT extends ESIntegTestCase { + + private static final String INDEX_NAME = "test_index"; + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + @Before + public void setup() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("vector") + .field("type", "dense_vector") + .field("dims", 1) + .field("index", true) + .field("similarity", "l2_norm") + 
.startObject("index_options") + .field("type", "hnsw") + .endObject() + .endObject() + .startObject("text") + .field("type", "text") + .endObject() + .startObject("integer") + .field("type", "integer") + .endObject() + .startObject("topic") + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + + assertAcked(prepareCreate(INDEX_NAME).setMapping(builder)); + ensureGreen(INDEX_NAME); + } + + private void performSearch(SearchSourceBuilder source) throws IOException { + Request request = new Request("GET", INDEX_NAME + "/_search"); + request.setJsonEntity(Strings.toString(source)); + getRestClient().performRequest(request); + } + + public void testTelemetryForRetrievers() throws IOException { + + if (false == isRetrieverTelemetryEnabled()) { + return; + } + + // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` + { + performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + } + + // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under + // `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2)))); + } + + // search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under + // `queries` + { + performSearch( + new SearchSourceBuilder().retriever( + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + ) + ); + } + + // search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term" + // under `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo")))); + } + + // search#5 - t + // his will record 1 entry for "knn" in 
`sections` + { + performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + } + + // search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` + { + performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())); + } + + // cluster stats + { + SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats(); + assertEquals(6, stats.getTotalSearchCount()); + + assertThat(stats.getSectionsUsage().size(), equalTo(3)); + assertThat(stats.getSectionsUsage().get("retriever"), equalTo(4L)); + assertThat(stats.getSectionsUsage().get("query"), equalTo(1L)); + assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L)); + + assertThat(stats.getRetrieversUsage().size(), equalTo(2)); + assertThat(stats.getRetrieversUsage().get("standard"), equalTo(3L)); + assertThat(stats.getRetrieversUsage().get("knn"), equalTo(1L)); + + assertThat(stats.getQueryUsage().size(), equalTo(4)); + assertThat(stats.getQueryUsage().get("range"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("term"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("knn"), equalTo(1L)); + } + } + + private boolean isRetrieverTelemetryEnabled() throws IOException { + NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities( + new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats") + ).actionGet(); + return res != null && res.isSupported().orElse(false); + } +} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 11965abf1dcd2..70b748c86ec96 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -416,7 +416,6 @@ uses org.elasticsearch.internal.BuildExtension; uses org.elasticsearch.features.FeatureSpecification; 
uses org.elasticsearch.plugins.internal.LoggingDataProvider; - uses org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; provides org.elasticsearch.features.FeatureSpecification with diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 78fddad603cab..d136aac8a2e5c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -238,6 +238,7 @@ static TransportVersion def(int id) { public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0); public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0); public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0); + public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java index da78e04d2b0d7..0f6c56fd21bd7 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.RETRIEVERS_TELEMETRY_ADDED; import static org.elasticsearch.TransportVersions.V_8_12_0; /** @@ -34,6 +35,7 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment { private final Map queries; private final Map rescorers; private final Map sections; + private final Map retrievers; /** * Creates a new empty stats instance, that will get additional stats added through {@link #add(SearchUsageStats)} @@ -43,17 +45,25 @@ public SearchUsageStats() { this.queries = new HashMap<>(); 
this.sections = new HashMap<>(); this.rescorers = new HashMap<>(); + this.retrievers = new HashMap<>(); } /** * Creates a new stats instance with the provided info. The expectation is that when a new instance is created using * this constructor, the provided stats are final and won't be modified further. */ - public SearchUsageStats(Map queries, Map rescorers, Map sections, long totalSearchCount) { + public SearchUsageStats( + Map queries, + Map rescorers, + Map sections, + Map retrievers, + long totalSearchCount + ) { this.totalSearchCount = totalSearchCount; this.queries = queries; this.sections = sections; this.rescorers = rescorers; + this.retrievers = retrievers; } public SearchUsageStats(StreamInput in) throws IOException { @@ -61,6 +71,7 @@ public SearchUsageStats(StreamInput in) throws IOException { this.sections = in.readMap(StreamInput::readLong); this.totalSearchCount = in.readVLong(); this.rescorers = in.getTransportVersion().onOrAfter(V_8_12_0) ? in.readMap(StreamInput::readLong) : Map.of(); + this.retrievers = in.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED) ? 
in.readMap(StreamInput::readLong) : Map.of(); } @Override @@ -72,6 +83,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(V_8_12_0)) { out.writeMap(rescorers, StreamOutput::writeLong); } + if (out.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED)) { + out.writeMap(retrievers, StreamOutput::writeLong); + } } /** @@ -81,6 +95,7 @@ public void add(SearchUsageStats stats) { stats.queries.forEach((query, count) -> queries.merge(query, count, Long::sum)); stats.rescorers.forEach((rescorer, count) -> rescorers.merge(rescorer, count, Long::sum)); stats.sections.forEach((query, count) -> sections.merge(query, count, Long::sum)); + stats.retrievers.forEach((query, count) -> retrievers.merge(query, count, Long::sum)); this.totalSearchCount += stats.totalSearchCount; } @@ -95,6 +110,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.map(rescorers); builder.field("sections"); builder.map(sections); + builder.field("retrievers"); + builder.map(retrievers); } builder.endObject(); return builder; @@ -112,6 +129,10 @@ public Map getSectionsUsage() { return Collections.unmodifiableMap(sections); } + public Map getRetrieversUsage() { + return Collections.unmodifiableMap(retrievers); + } + public long getTotalSearchCount() { return totalSearchCount; } @@ -128,12 +149,13 @@ public boolean equals(Object o) { return totalSearchCount == that.totalSearchCount && queries.equals(that.queries) && rescorers.equals(that.rescorers) - && sections.equals(that.sections); + && sections.equals(that.sections) + && retrievers.equals(that.retrievers); } @Override public int hashCode() { - return Objects.hash(totalSearchCount, queries, rescorers, sections); + return Objects.hash(totalSearchCount, queries, rescorers, sections, retrievers); } @Override diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFactoryRetention.java 
b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFactoryRetention.java deleted file mode 100644 index 656c63889a79d..0000000000000 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamFactoryRetention.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.cluster.metadata; - -import org.elasticsearch.common.settings.ClusterSettings; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.plugins.PluginsService; - -/** - * Holds the factory retention configuration. Factory retention is the global retention configuration meant to be - * used if a user hasn't provided other retention configuration via {@link DataStreamGlobalRetention} metadata in the - * cluster state. - * @deprecated This interface is deprecated, please use {@link DataStreamGlobalRetentionSettings}. - */ -@Deprecated -public interface DataStreamFactoryRetention { - - @Nullable - TimeValue getMaxRetention(); - - @Nullable - TimeValue getDefaultRetention(); - - /** - * @return true, if at least one of the two settings is not null, false otherwise. - */ - default boolean isDefined() { - return getMaxRetention() != null || getDefaultRetention() != null; - } - - /** - * Applies any post constructor initialisation, for example, listening to cluster setting changes. 
- */ - void init(ClusterSettings clusterSettings); - - /** - * Loads a single instance of a DataStreamFactoryRetention from the {@link PluginsService} and finalises the - * initialisation by calling {@link DataStreamFactoryRetention#init(ClusterSettings)} - */ - static DataStreamFactoryRetention load(PluginsService pluginsService, ClusterSettings clusterSettings) { - DataStreamFactoryRetention factoryRetention = pluginsService.loadSingletonServiceProvider( - DataStreamFactoryRetention.class, - DataStreamFactoryRetention::emptyFactoryRetention - ); - factoryRetention.init(clusterSettings); - return factoryRetention; - } - - /** - * Returns empty factory global retention settings. - */ - static DataStreamFactoryRetention emptyFactoryRetention() { - return new DataStreamFactoryRetention() { - - @Override - public TimeValue getMaxRetention() { - return null; - } - - @Override - public TimeValue getDefaultRetention() { - return null; - } - - @Override - public void init(ClusterSettings clusterSettings) { - - } - }; - } -} diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettings.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettings.java index fd4df18551c30..9e7256d6818bb 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettings.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettings.java @@ -26,8 +26,6 @@ * The global retention settings apply to non-system data streams that are managed by the data stream lifecycle. They consist of: * - The default retention which applies to data streams that do not have a retention defined. * - The max retention which applies to all data streams that do not have retention or their retention has exceeded this value. - *

- * Temporarily, we fall back to {@link DataStreamFactoryRetention} to facilitate a smooth transition to these settings. */ public class DataStreamGlobalRetentionSettings { @@ -84,42 +82,35 @@ public Iterator> settings() { Setting.Property.Dynamic ); - private final DataStreamFactoryRetention factoryRetention; - @Nullable private volatile TimeValue defaultRetention; @Nullable private volatile TimeValue maxRetention; - private DataStreamGlobalRetentionSettings(DataStreamFactoryRetention factoryRetention) { - this.factoryRetention = factoryRetention; + private DataStreamGlobalRetentionSettings() { + } @Nullable public TimeValue getMaxRetention() { - return shouldFallbackToFactorySettings() ? factoryRetention.getMaxRetention() : maxRetention; + return maxRetention; } @Nullable public TimeValue getDefaultRetention() { - return shouldFallbackToFactorySettings() ? factoryRetention.getDefaultRetention() : defaultRetention; + return defaultRetention; } public boolean areDefined() { return getDefaultRetention() != null || getMaxRetention() != null; } - private boolean shouldFallbackToFactorySettings() { - return defaultRetention == null && maxRetention == null; - } - /** * Creates an instance and initialises the cluster settings listeners * @param clusterSettings it will register the cluster settings listeners to monitor for changes - * @param factoryRetention for migration purposes, it will be removed shortly */ - public static DataStreamGlobalRetentionSettings create(ClusterSettings clusterSettings, DataStreamFactoryRetention factoryRetention) { - DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings = new DataStreamGlobalRetentionSettings(factoryRetention); + public static DataStreamGlobalRetentionSettings create(ClusterSettings clusterSettings) { + DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings = new DataStreamGlobalRetentionSettings(); clusterSettings.initializeAndWatch(DATA_STREAMS_DEFAULT_RETENTION_SETTING, 
dataStreamGlobalRetentionSettings::setDefaultRetention); clusterSettings.initializeAndWatch(DATA_STREAMS_MAX_RETENTION_SETTING, dataStreamGlobalRetentionSettings::setMaxRetention); return dataStreamGlobalRetentionSettings; diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index 16b180a82acc6..16dcdf00fbbef 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -44,7 +44,6 @@ import org.elasticsearch.cluster.coordination.MasterHistoryService; import org.elasticsearch.cluster.coordination.StableMasterHealthIndicatorService; import org.elasticsearch.cluster.features.NodeFeaturesFixupListener; -import org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetentionSettings; import org.elasticsearch.cluster.metadata.IndexMetadataVerifier; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -621,8 +620,7 @@ private DataStreamGlobalRetentionSettings createDataStreamServicesAndGlobalReten MetadataCreateIndexService metadataCreateIndexService ) { DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings = DataStreamGlobalRetentionSettings.create( - clusterService.getClusterSettings(), - DataStreamFactoryRetention.load(pluginsService, clusterService.getClusterSettings()) + clusterService.getClusterSettings() ); modules.bindToInstance(DataStreamGlobalRetentionSettings.class, dataStreamGlobalRetentionSettings); modules.bindToInstance( diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterStatsAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterStatsAction.java index 6427e6139a7aa..63bd4523f9bd1 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterStatsAction.java +++ 
b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestClusterStatsAction.java @@ -32,7 +32,8 @@ public class RestClusterStatsAction extends BaseRestHandler { private static final Set SUPPORTED_CAPABILITIES = Set.of( "human-readable-total-docs-size", "verbose-dense-vector-mapping-stats", - "ccs-stats" + "ccs-stats", + "retrievers-usage-stats" ); private static final Set SUPPORTED_QUERY_PARAMETERS = Set.of("include_remotes", "nodeId", REST_TIMEOUT_PARAM); diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index fc0cb72bb82e0..9f94ec1452019 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -1409,6 +1409,7 @@ private SearchSourceBuilder parseXContent( parser, new RetrieverParserContext(searchUsage, clusterSupportsFeature) ); + searchUsage.trackSectionUsage(RETRIEVER.getPreferredName()); } else if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { if (subSearchSourceBuilders.isEmpty() == false) { throw new IllegalArgumentException( diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 1c6f8c4a7ce44..882d44adb79c3 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -62,11 +62,11 @@ protected static void declareBaseParserFields( String name, AbstractObjectParser parser ) { - parser.declareObjectArray((r, v) -> r.preFilterQueryBuilders = v, (p, c) -> { - QueryBuilder preFilterQueryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage); - c.trackSectionUsage(name + ":" + PRE_FILTER_FIELD.getPreferredName()); - return 
preFilterQueryBuilder; - }, PRE_FILTER_FIELD); + parser.declareObjectArray( + (r, v) -> r.preFilterQueryBuilders = v, + (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage), + PRE_FILTER_FIELD + ); parser.declareString(RetrieverBuilder::retrieverName, NAME_FIELD); parser.declareFloat(RetrieverBuilder::minScore, MIN_SCORE_FIELD); } @@ -138,7 +138,7 @@ protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser pars throw new ParsingException(new XContentLocation(nonfe.getLineNumber(), nonfe.getColumnNumber()), message, nonfe); } - context.trackSectionUsage(retrieverName); + context.trackRetrieverUsage(retrieverName); if (parser.currentToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverParserContext.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverParserContext.java index 1f9444fc284fc..bdf3f8a194546 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverParserContext.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverParserContext.java @@ -37,6 +37,10 @@ public void trackRescorerUsage(String name) { searchUsage.trackRescorerUsage(name); } + public void trackRetrieverUsage(String name) { + searchUsage.trackRetrieverUsage(name); + } + public boolean clusterSupportsFeature(NodeFeature nodeFeature) { return clusterSupportsFeature != null && clusterSupportsFeature.test(nodeFeature); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java index 108aafd8c7771..4e875a97fdfc4 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java @@ -55,36 +55,28 @@ public final class StandardRetrieverBuilder extends 
RetrieverBuilder implements static { PARSER.declareObject((r, v) -> r.queryBuilder = v, (p, c) -> { QueryBuilder queryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage); - c.trackSectionUsage(NAME + ":" + QUERY_FIELD.getPreferredName()); return queryBuilder; }, QUERY_FIELD); - PARSER.declareField((r, v) -> r.searchAfterBuilder = v, (p, c) -> { - SearchAfterBuilder searchAfterBuilder = SearchAfterBuilder.fromXContent(p); - c.trackSectionUsage(NAME + ":" + SEARCH_AFTER_FIELD.getPreferredName()); - return searchAfterBuilder; - }, SEARCH_AFTER_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); - - PARSER.declareField((r, v) -> r.terminateAfter = v, (p, c) -> { - int terminateAfter = p.intValue(); - c.trackSectionUsage(NAME + ":" + TERMINATE_AFTER_FIELD.getPreferredName()); - return terminateAfter; - }, TERMINATE_AFTER_FIELD, ObjectParser.ValueType.INT); - - PARSER.declareField((r, v) -> r.sortBuilders = v, (p, c) -> { - List> sortBuilders = SortBuilder.fromXContent(p); - c.trackSectionUsage(NAME + ":" + SORT_FIELD.getPreferredName()); - return sortBuilders; - }, SORT_FIELD, ObjectParser.ValueType.OBJECT_ARRAY); - - PARSER.declareField((r, v) -> r.collapseBuilder = v, (p, c) -> { - CollapseBuilder collapseBuilder = CollapseBuilder.fromXContent(p); - if (collapseBuilder.getField() != null) { - c.trackSectionUsage(COLLAPSE_FIELD.getPreferredName()); - } - return collapseBuilder; - }, COLLAPSE_FIELD, ObjectParser.ValueType.OBJECT); - + PARSER.declareField( + (r, v) -> r.searchAfterBuilder = v, + (p, c) -> SearchAfterBuilder.fromXContent(p), + SEARCH_AFTER_FIELD, + ObjectParser.ValueType.OBJECT_ARRAY + ); + PARSER.declareField((r, v) -> r.terminateAfter = v, (p, c) -> p.intValue(), TERMINATE_AFTER_FIELD, ObjectParser.ValueType.INT); + PARSER.declareField( + (r, v) -> r.sortBuilders = v, + (p, c) -> SortBuilder.fromXContent(p), + SORT_FIELD, + ObjectParser.ValueType.OBJECT_ARRAY + ); + PARSER.declareField( + (r, v) -> r.collapseBuilder = v, + (p, c) -> 
CollapseBuilder.fromXContent(p), + COLLAPSE_FIELD, + ObjectParser.ValueType.OBJECT + ); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -121,29 +113,29 @@ private StandardRetrieverBuilder(StandardRetrieverBuilder clone) { @Override public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { boolean changed = false; - List> newSortBuilders = null; + List> rewrittenSortBuilders = null; if (sortBuilders != null) { - newSortBuilders = new ArrayList<>(sortBuilders.size()); + rewrittenSortBuilders = new ArrayList<>(sortBuilders.size()); for (var sort : sortBuilders) { var newSort = sort.rewrite(ctx); - newSortBuilders.add(newSort); - changed = newSort != sort; + rewrittenSortBuilders.add(newSort); + changed |= newSort != sort; } } var rewrittenFilters = rewritePreFilters(ctx); changed |= rewrittenFilters != preFilterQueryBuilders; - QueryBuilder queryBuilderRewrite = null; + QueryBuilder rewrittenQuery = null; if (queryBuilder != null) { - queryBuilderRewrite = queryBuilder.rewrite(ctx); - changed |= queryBuilderRewrite != queryBuilder; + rewrittenQuery = queryBuilder.rewrite(ctx); + changed |= rewrittenQuery != queryBuilder; } if (changed) { var rewritten = new StandardRetrieverBuilder(this); - rewritten.sortBuilders = newSortBuilders; - rewritten.preFilterQueryBuilders = preFilterQueryBuilders; - rewritten.queryBuilder = queryBuilderRewrite; + rewritten.sortBuilders = rewrittenSortBuilders; + rewritten.preFilterQueryBuilders = rewrittenFilters; + rewritten.queryBuilder = rewrittenQuery; return rewritten; } return this; diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index 27db078788506..0fb767c5789f9 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -1059,8 +1059,9 @@ private Executor 
getInternalSendExceptionExecutor(Executor handlerExecutor) { if (lifecycle.stoppedOrClosed()) { // too late to try and dispatch anywhere else, let's just use the calling thread return EsExecutors.DIRECT_EXECUTOR_SERVICE; - } else if (handlerExecutor == EsExecutors.DIRECT_EXECUTOR_SERVICE) { - // if the handler is non-forking then dispatch to GENERIC to avoid a possible stack overflow + } else if (handlerExecutor == EsExecutors.DIRECT_EXECUTOR_SERVICE && enableStackOverflowAvoidance) { + // If the handler is non-forking and stack overflow protection is enabled then dispatch to GENERIC + // Otherwise we let the handler deal with any potential stack overflow (this is the default) return threadPool.generic(); } else { return handlerExecutor; diff --git a/server/src/main/java/org/elasticsearch/usage/SearchUsage.java b/server/src/main/java/org/elasticsearch/usage/SearchUsage.java index 7df7a302f1b19..e35594fb161ac 100644 --- a/server/src/main/java/org/elasticsearch/usage/SearchUsage.java +++ b/server/src/main/java/org/elasticsearch/usage/SearchUsage.java @@ -20,6 +20,7 @@ public final class SearchUsage { private final Set queries = new HashSet<>(); private final Set rescorers = new HashSet<>(); private final Set sections = new HashSet<>(); + private final Set retrievers = new HashSet<>(); /** * Track the usage of the provided query @@ -42,6 +43,13 @@ public void trackRescorerUsage(String name) { rescorers.add(name); } + /** + * Track retriever usage + */ + public void trackRetrieverUsage(String retriever) { + retrievers.add(retriever); + } + /** * Returns the query types that have been used at least once in the tracked search request */ @@ -62,4 +70,11 @@ public Set getRescorerUsage() { public Set getSectionsUsage() { return Collections.unmodifiableSet(sections); } + + /** + * Returns the retriever names that have been used at least once in the tracked search request + */ + public Set getRetrieverUsage() { + return Collections.unmodifiableSet(retrievers); + } } diff --git
a/server/src/main/java/org/elasticsearch/usage/SearchUsageHolder.java b/server/src/main/java/org/elasticsearch/usage/SearchUsageHolder.java index 652dfbdd20c57..ef802723cf164 100644 --- a/server/src/main/java/org/elasticsearch/usage/SearchUsageHolder.java +++ b/server/src/main/java/org/elasticsearch/usage/SearchUsageHolder.java @@ -27,6 +27,7 @@ public final class SearchUsageHolder { private final Map queriesUsage = new ConcurrentHashMap<>(); private final Map rescorersUsage = new ConcurrentHashMap<>(); private final Map sectionsUsage = new ConcurrentHashMap<>(); + private final Map retrieversUsage = new ConcurrentHashMap<>(); SearchUsageHolder() {} @@ -44,6 +45,9 @@ public void updateUsage(SearchUsage searchUsage) { for (String rescorer : searchUsage.getRescorerUsage()) { rescorersUsage.computeIfAbsent(rescorer, q -> new LongAdder()).increment(); } + for (String retriever : searchUsage.getRetrieverUsage()) { + retrieversUsage.computeIfAbsent(retriever, q -> new LongAdder()).increment(); + } } /** @@ -56,10 +60,13 @@ public SearchUsageStats getSearchUsageStats() { sectionsUsage.forEach((query, adder) -> sectionsUsageMap.put(query, adder.longValue())); Map rescorersUsageMap = Maps.newMapWithExpectedSize(rescorersUsage.size()); rescorersUsage.forEach((query, adder) -> rescorersUsageMap.put(query, adder.longValue())); + Map retrieversUsageMap = Maps.newMapWithExpectedSize(retrieversUsage.size()); + retrieversUsage.forEach((retriever, adder) -> retrieversUsageMap.put(retriever, adder.longValue())); return new SearchUsageStats( Collections.unmodifiableMap(queriesUsageMap), Collections.unmodifiableMap(rescorersUsageMap), Collections.unmodifiableMap(sectionsUsageMap), + Collections.unmodifiableMap(retrieversUsageMap), totalSearchCount.longValue() ); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java index 
a705514f56592..89ccd4ab63d7f 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java @@ -10,6 +10,7 @@ package org.elasticsearch.action.admin.cluster.stats; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable.Reader; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -43,9 +44,12 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase RETRIEVERS = List.of("standard", "knn", "rrf", "random", "text_similarity_reranker"); + @Override protected Reader instanceReader() { return SearchUsageStats::new; @@ -75,6 +79,14 @@ private static Map randomRescorerUsage(int size) { return rescorerUsage; } + private static Map randomRetrieversUsage(int size) { + Map retrieversUsage = new HashMap<>(); + while (retrieversUsage.size() < size) { + retrieversUsage.put(randomFrom(RETRIEVERS), randomLongBetween(1, Long.MAX_VALUE)); + } + return retrieversUsage; + } + @Override protected SearchUsageStats createTestInstance() { if (randomBoolean()) { @@ -84,6 +96,7 @@ protected SearchUsageStats createTestInstance() { randomQueryUsage(randomIntBetween(0, QUERY_TYPES.size())), randomRescorerUsage(randomIntBetween(0, RESCORER_TYPES.size())), randomSectionsUsage(randomIntBetween(0, SECTIONS.size())), + randomRetrieversUsage(randomIntBetween(0, RETRIEVERS.size())), randomLongBetween(10, Long.MAX_VALUE) ); } @@ -96,26 +109,38 @@ protected SearchUsageStats mutateInstance(SearchUsageStats instance) { randomValueOtherThan(instance.getQueryUsage(), () -> randomQueryUsage(randomIntBetween(0, QUERY_TYPES.size()))), instance.getRescorerUsage(), instance.getSectionsUsage(), + instance.getRetrieversUsage(), instance.getTotalSearchCount() ); case 1 -> new SearchUsageStats( instance.getQueryUsage(), 
randomValueOtherThan(instance.getRescorerUsage(), () -> randomRescorerUsage(randomIntBetween(0, RESCORER_TYPES.size()))), instance.getSectionsUsage(), + instance.getRetrieversUsage(), instance.getTotalSearchCount() ); case 2 -> new SearchUsageStats( instance.getQueryUsage(), instance.getRescorerUsage(), randomValueOtherThan(instance.getSectionsUsage(), () -> randomSectionsUsage(randomIntBetween(0, SECTIONS.size()))), + instance.getRetrieversUsage(), instance.getTotalSearchCount() ); - default -> new SearchUsageStats( + case 3 -> new SearchUsageStats( instance.getQueryUsage(), instance.getRescorerUsage(), instance.getSectionsUsage(), - randomLongBetween(10, Long.MAX_VALUE) + randomValueOtherThan(instance.getRetrieversUsage(), () -> randomRetrieversUsage(randomIntBetween(0, RETRIEVERS.size()))), + instance.getTotalSearchCount() ); + case 4 -> new SearchUsageStats( + instance.getQueryUsage(), + instance.getRescorerUsage(), + instance.getSectionsUsage(), + instance.getRetrieversUsage(), + randomValueOtherThan(instance.getTotalSearchCount(), () -> randomLongBetween(10, Long.MAX_VALUE)) + ); + default -> throw new IllegalStateException("Unexpected value: " + i); }; } @@ -126,7 +151,9 @@ public void testAdd() { assertEquals(Map.of(), searchUsageStats.getSectionsUsage()); assertEquals(0, searchUsageStats.getTotalSearchCount()); - searchUsageStats.add(new SearchUsageStats(Map.of("match", 10L), Map.of("query", 5L), Map.of("query", 10L), 10L)); + searchUsageStats.add( + new SearchUsageStats(Map.of("match", 10L), Map.of("query", 5L), Map.of("query", 10L), Map.of("knn", 10L), 10L) + ); assertEquals(Map.of("match", 10L), searchUsageStats.getQueryUsage()); assertEquals(Map.of("query", 10L), searchUsageStats.getSectionsUsage()); assertEquals(Map.of("query", 5L), searchUsageStats.getRescorerUsage()); @@ -137,19 +164,28 @@ public void testAdd() { Map.of("term", 1L, "match", 1L), Map.of("query", 5L, "learning_to_rank", 2L), Map.of("query", 10L, "knn", 1L), + Map.of("knn", 10L, "rrf",
2L), 10L ) ); assertEquals(Map.of("match", 11L, "term", 1L), searchUsageStats.getQueryUsage()); assertEquals(Map.of("query", 20L, "knn", 1L), searchUsageStats.getSectionsUsage()); assertEquals(Map.of("query", 10L, "learning_to_rank", 2L), searchUsageStats.getRescorerUsage()); + assertEquals(Map.of("knn", 20L, "rrf", 2L), searchUsageStats.getRetrieversUsage()); assertEquals(20L, searchUsageStats.getTotalSearchCount()); } public void testToXContent() throws IOException { - SearchUsageStats searchUsageStats = new SearchUsageStats(Map.of("term", 1L), Map.of("query", 2L), Map.of("query", 10L), 10L); + SearchUsageStats searchUsageStats = new SearchUsageStats( + Map.of("term", 1L), + Map.of("query", 2L), + Map.of("query", 10L), + Map.of("knn", 10L), + 10L + ); assertEquals( - "{\"search\":{\"total\":10,\"queries\":{\"term\":1},\"rescorers\":{\"query\":2},\"sections\":{\"query\":10}}}", + "{\"search\":{\"total\":10,\"queries\":{\"term\":1},\"rescorers\":{\"query\":2}," + + "\"sections\":{\"query\":10},\"retrievers\":{\"knn\":10}}}", Strings.toString(searchUsageStats) ); } @@ -161,8 +197,9 @@ public void testSerializationBWC() throws IOException { for (TransportVersion version : TransportVersionUtils.allReleasedVersions()) { SearchUsageStats testInstance = new SearchUsageStats( randomQueryUsage(QUERY_TYPES.size()), - Map.of(), + version.onOrAfter(TransportVersions.V_8_12_0) ? randomRescorerUsage(RESCORER_TYPES.size()) : Map.of(), randomSectionsUsage(SECTIONS.size()), + version.onOrAfter(TransportVersions.RETRIEVERS_TELEMETRY_ADDED) ? 
randomRetrieversUsage(RETRIEVERS.size()) : Map.of(), randomLongBetween(0, Long.MAX_VALUE) ); assertSerialization(testInstance, version); diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/template/reservedstate/ReservedComposableIndexTemplateActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/template/reservedstate/ReservedComposableIndexTemplateActionTests.java index f0f4e37f31c19..3e49cbe774eef 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/template/reservedstate/ReservedComposableIndexTemplateActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/template/reservedstate/ReservedComposableIndexTemplateActionTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.ComposableIndexTemplate; -import org.elasticsearch.cluster.metadata.DataStreamFactoryRetention; import org.elasticsearch.cluster.metadata.DataStreamGlobalRetentionSettings; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; @@ -94,10 +93,7 @@ public void setup() throws IOException { doReturn(mapperService).when(indexService).mapperService(); doReturn(indexService).when(indicesService).createIndex(any(), any(), anyBoolean()); - globalRetentionSettings = DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() - ); + globalRetentionSettings = DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings()); templateService = new MetadataIndexTemplateService( clusterService, mock(MetadataCreateIndexService.class), diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettingsTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettingsTests.java index 
9de653d29e686..17fa520ad1c4a 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettingsTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamGlobalRetentionSettingsTests.java @@ -22,46 +22,16 @@ public class DataStreamGlobalRetentionSettingsTests extends ESTestCase { public void testDefaults() { DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() + ClusterSettings.createBuiltInClusterSettings() ); assertThat(globalRetentionSettings.getDefaultRetention(), nullValue()); assertThat(globalRetentionSettings.getMaxRetention(), nullValue()); - - // Fallback to factory settings - TimeValue maxFactoryValue = randomPositiveTimeValue(); - TimeValue defaultFactoryValue = randomPositiveTimeValue(); - DataStreamGlobalRetentionSettings withFactorySettings = DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - new DataStreamFactoryRetention() { - @Override - public TimeValue getMaxRetention() { - return maxFactoryValue; - } - - @Override - public TimeValue getDefaultRetention() { - return defaultFactoryValue; - } - - @Override - public void init(ClusterSettings clusterSettings) { - - } - } - ); - - assertThat(withFactorySettings.getDefaultRetention(), equalTo(defaultFactoryValue)); - assertThat(withFactorySettings.getMaxRetention(), equalTo(maxFactoryValue)); } public void testMonitorsDefaultRetention() { ClusterSettings clusterSettings = ClusterSettings.createBuiltInClusterSettings(); - DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create( - clusterSettings, - DataStreamFactoryRetention.emptyFactoryRetention() - ); + DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create(clusterSettings); // Test valid update TimeValue 
newDefaultRetention = TimeValue.timeValueDays(randomIntBetween(1, 10)); @@ -91,10 +61,7 @@ public void testMonitorsDefaultRetention() { public void testMonitorsMaxRetention() { ClusterSettings clusterSettings = ClusterSettings.createBuiltInClusterSettings(); - DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create( - clusterSettings, - DataStreamFactoryRetention.emptyFactoryRetention() - ); + DataStreamGlobalRetentionSettings globalRetentionSettings = DataStreamGlobalRetentionSettings.create(clusterSettings); // Test valid update TimeValue newMaxRetention = TimeValue.timeValueDays(randomIntBetween(10, 30)); @@ -121,7 +88,7 @@ public void testMonitorsMaxRetention() { public void testCombinationValidation() { ClusterSettings clusterSettings = ClusterSettings.createBuiltInClusterSettings(); - DataStreamGlobalRetentionSettings.create(clusterSettings, DataStreamFactoryRetention.emptyFactoryRetention()); + DataStreamGlobalRetentionSettings.create(clusterSettings); // Test invalid update Settings newInvalidSettings = Settings.builder() diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamLifecycleWithRetentionWarningsTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamLifecycleWithRetentionWarningsTests.java index d7f10f484165b..27198f51ed97e 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamLifecycleWithRetentionWarningsTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamLifecycleWithRetentionWarningsTests.java @@ -141,10 +141,7 @@ public void testUpdatingLifecycleOnADataStream() { MetadataDataStreamsService metadataDataStreamsService = new MetadataDataStreamsService( mock(ClusterService.class), mock(IndicesService.class), - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(settingsWithDefaultRetention), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + 
DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings(settingsWithDefaultRetention)) ); ClusterState after = metadataDataStreamsService.updateDataLifecycle(before, List.of(dataStream), DataStreamLifecycle.DEFAULT); @@ -281,10 +278,7 @@ public void testValidateLifecycleInComponentTemplate() throws Exception { xContentRegistry(), EmptySystemIndices.INSTANCE, new IndexSettingProviders(Set.of()), - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(settingsWithDefaultRetention), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings(settingsWithDefaultRetention)) ); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java index 6ef516f67014c..92c1103c950c0 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java @@ -402,10 +402,7 @@ public void testUpdateLifecycle() { MetadataDataStreamsService service = new MetadataDataStreamsService( mock(ClusterService.class), mock(IndicesService.class), - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings()) ); { // Remove lifecycle diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java index 8d4b04746e7a4..873b185e6be28 100644 --- 
a/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateServiceTests.java @@ -2496,10 +2496,7 @@ private static List putTemplate(NamedXContentRegistry xContentRegistr xContentRegistry, EmptySystemIndices.INSTANCE, new IndexSettingProviders(Set.of()), - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings()) ); final List throwables = new ArrayList<>(); @@ -2563,10 +2560,7 @@ private MetadataIndexTemplateService getMetadataIndexTemplateService() { xContentRegistry(), EmptySystemIndices.INSTANCE, new IndexSettingProviders(Set.of()), - DataStreamGlobalRetentionSettings.create( - ClusterSettings.createBuiltInClusterSettings(), - DataStreamFactoryRetention.emptyFactoryRetention() - ) + DataStreamGlobalRetentionSettings.create(ClusterSettings.createBuiltInClusterSettings()) ); } diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceLifecycleTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceLifecycleTests.java index a4a6ef6c5c5f2..b631eddc5173b 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceLifecycleTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceLifecycleTests.java @@ -149,8 +149,13 @@ public void testInternalSendExceptionForksToHandlerExecutor() { } } - public void testInternalSendExceptionForksToGenericIfHandlerDoesNotFork() { - try (var nodeA = new TestNode("node-A")) { + public void testInternalSendExceptionForksToGenericIfHandlerDoesNotForkAndStackOverflowProtectionEnabled() { + try ( + var nodeA = new TestNode( + "node-A", + Settings.builder().put(TransportService.ENABLE_STACK_OVERFLOW_AVOIDANCE.getKey(), true).build() + ) + ) { final var future = 
new PlainActionFuture(); nodeA.transportService.sendRequest( nodeA.getThrowingConnection(), @@ -165,6 +170,33 @@ public void testInternalSendExceptionForksToGenericIfHandlerDoesNotFork() { assertEquals("simulated exception in sendRequest", getSendRequestException(future, IOException.class).getMessage()); } + assertWarnings( + "[transport.enable_stack_protection] setting was deprecated in Elasticsearch and will be removed in a future release." + ); + } + + public void testInternalSendExceptionWithNonForkingResponseHandlerCompletesListenerInline() { + try (var nodeA = new TestNode("node-A")) { + final Thread callingThread = Thread.currentThread(); + assertEquals( + "simulated exception in sendRequest", + safeAwaitAndUnwrapFailure( + IOException.class, + TransportResponse.Empty.class, + l -> nodeA.transportService.sendRequest( + nodeA.getThrowingConnection(), + TestNode.randomActionName(random()), + new EmptyRequest(), + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>( + ActionListener.runBefore(l, () -> assertSame(callingThread, Thread.currentThread())), + unusedReader(), + EsExecutors.DIRECT_EXECUTOR_SERVICE + ) + ) + ).getMessage() + ); + } } public void testInternalSendExceptionForcesExecutionOnHandlerExecutor() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java index 6c28c6f3053ab..0028508e87f32 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java @@ -458,13 +458,13 @@ static RoleDescriptor kibanaSystem(String name) { TransportUpdateSettingsAction.TYPE.name() ) .build(), - + // security entity analytics indices 
RoleDescriptor.IndicesPrivileges.builder().indices("risk-score.risk-*").privileges("all").build(), RoleDescriptor.IndicesPrivileges.builder() .indices(".asset-criticality.asset-criticality-*") .privileges("create_index", "manage", "read", "write") .build(), - + RoleDescriptor.IndicesPrivileges.builder().indices(".entities.v1.latest.security*").privileges("read").build(), // For cloud_defend usageCollection RoleDescriptor.IndicesPrivileges.builder() .indices("logs-cloud_defend.*", "metrics-cloud_defend.*") diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index 1543ba039dc6d..319e67512c7ac 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -279,9 +279,9 @@ private Object valueMapper(CsvTestUtils.Type type, Object value) { } return values; } else if (value instanceof Double d) { - return new BigDecimal(d).round(new MathContext(10, RoundingMode.DOWN)).doubleValue(); + return new BigDecimal(d).round(new MathContext(7, RoundingMode.DOWN)).doubleValue(); } else if (value instanceof String s) { - return new BigDecimal(s).round(new MathContext(10, RoundingMode.DOWN)).doubleValue(); + return new BigDecimal(s).round(new MathContext(7, RoundingMode.DOWN)).doubleValue(); } } return value.toString(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 3ec39d1b0ac4b..ce2a1d7a5f660 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -60,6 +60,7 @@ 
import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.action.EsqlQueryAction; import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec; @@ -206,13 +207,19 @@ public void execute( ); long start = configuration.getQueryStartTimeNanos(); String local = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + /* + * Grab the output attributes here, so we can pass them to + * the listener without holding on to a reference to the + * entire plan. + */ + List outputAttributes = physicalPlan.output(); try ( Releasable ignored = exchangeSource.addEmptySink(); // this is the top level ComputeListener called once at the end (e.g., once all clusters have finished for a CCS) var computeListener = ComputeListener.create(local, transportService, rootTask, execInfo, start, listener.map(r -> { long tookTimeNanos = System.nanoTime() - configuration.getQueryStartTimeNanos(); execInfo.overallTook(new TimeValue(tookTimeNanos, TimeUnit.NANOSECONDS)); - return new Result(physicalPlan.output(), collectedPages, r.getProfiles(), execInfo); + return new Result(outputAttributes, collectedPages, r.getProfiles(), execInfo); })) ) { // run compute on the coordinator diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 3ddaab12eca14..8bccf6e7d1022 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -64,7 +64,11 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder }); static { - PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c), RETRIEVER_FIELD); + PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { + RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c); + c.trackRetrieverUsage(innerRetriever.getName()); + return innerRetriever; + }, RETRIEVER_FIELD); PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD); PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD); PARSER.declareString(constructorArg(), FIELD_FIELD); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java index 32301bf9efea9..478f3b2f33c93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java @@ -117,7 +117,10 @@ public void testParserDefaults() throws IOException { }"""; try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) { - TextSimilarityRankRetrieverBuilder parsed = TextSimilarityRankRetrieverBuilder.PARSER.parse(parser, null); + TextSimilarityRankRetrieverBuilder parsed = TextSimilarityRankRetrieverBuilder.PARSER.parse( + parser, + new RetrieverParserContext(new SearchUsage(), nf -> true) + ); assertEquals(DEFAULT_RANK_WINDOW_SIZE, parsed.rankWindowSize()); } } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java new file mode 100644 index 0000000000000..916703446995d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rank.textsimilarity; + +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest; +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse; +import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.junit.Before; + +import 
java.io.IOException; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) +public class TextSimilarityRankRetrieverTelemetryTests extends ESIntegTestCase { + + private static final String INDEX_NAME = "test_index"; + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + @Override + protected Collection> nodePlugins() { + return List.of(InferencePlugin.class, XPackPlugin.class, TextSimilarityTestPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put("xpack.license.self_generated.type", "trial") + .build(); + } + + @Before + public void setup() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("vector") + .field("type", "dense_vector") + .field("dims", 1) + .field("index", true) + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", "hnsw") + .endObject() + .endObject() + .startObject("text") + .field("type", "text") + .endObject() + .startObject("integer") + .field("type", "integer") + .endObject() + .startObject("topic") + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + + assertAcked(prepareCreate(INDEX_NAME).setMapping(builder)); + ensureGreen(INDEX_NAME); + } + + private void performSearch(SearchSourceBuilder source) throws IOException { + Request request = new Request("GET", INDEX_NAME + "/_search"); + request.setJsonEntity(Strings.toString(source)); + getRestClient().performRequest(request); + } + + public void testTelemetryForRRFRetriever() throws IOException { + + if (false == isRetrieverTelemetryEnabled()) { + return; 
+ } + + // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` + { + performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + } + + // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under + // `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2)))); + } + + // search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under + // `queries` + { + performSearch( + new SearchSourceBuilder().retriever( + new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + ) + ); + } + + // search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term" + // under `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo")))); + } + + // search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "text_similarity_reranker" under `retrievers`, as + // well as + // 1 "standard" under `retrievers`, and eventually 1 for "match" under `queries` + { + performSearch( + new SearchSourceBuilder().retriever( + new TextSimilarityRankRetrieverBuilder( + new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")), + "some_inference_id", + "some_inference_text", + "some_field", + 10 + ) + ) + ); + } + + // search#6 - this will record 1 entry for "knn" in `sections` + { + performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + } + + // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` + { + performSearch(new 
SearchSourceBuilder().query(QueryBuilders.matchAllQuery())); + } + + // cluster stats + { + SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats(); + assertEquals(7, stats.getTotalSearchCount()); + + assertThat(stats.getSectionsUsage().size(), equalTo(3)); + assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L)); + assertThat(stats.getSectionsUsage().get("query"), equalTo(1L)); + assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L)); + + assertThat(stats.getRetrieversUsage().size(), equalTo(3)); + assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L)); + assertThat(stats.getRetrieversUsage().get("knn"), equalTo(1L)); + assertThat(stats.getRetrieversUsage().get("text_similarity_reranker"), equalTo(1L)); + + assertThat(stats.getQueryUsage().size(), equalTo(5)); + assertThat(stats.getQueryUsage().get("range"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("term"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("match"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("knn"), equalTo(1L)); + } + } + + private boolean isRetrieverTelemetryEnabled() throws IOException { + NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities( + new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats") + ).actionGet(); + return res != null && res.isSupported().orElse(false); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetector.java new file mode 100644 index 0000000000000..e771fb3b94568 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetector.java @@ -0,0 +1,560 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.changepoint; + +import org.apache.commons.math3.distribution.UniformRealDistribution; +import org.apache.commons.math3.random.RandomGeneratorFactory; +import org.apache.commons.math3.special.Beta; +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; +import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.xpack.ml.aggs.MlAggsHelper; + +import java.util.Arrays; +import java.util.Random; +import java.util.Set; +import java.util.function.IntToDoubleFunction; +import java.util.stream.IntStream; + +/** + * Detects whether a time series is stationary or changing + * (either continuously or at a specific change point). + */ +public class ChangeDetector { + + private static final int MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST = 500; + private static final int MAXIMUM_CANDIDATE_CHANGE_POINTS = 1000; + + private static final KolmogorovSmirnovTest KOLMOGOROV_SMIRNOV_TEST = new KolmogorovSmirnovTest(); + + private static final Logger logger = LogManager.getLogger(ChangeDetector.class); + + private final MlAggsHelper.DoubleBucketValues bucketValues; + private final double[] values; + + ChangeDetector(MlAggsHelper.DoubleBucketValues bucketValues) { + this.bucketValues = bucketValues; + this.values = bucketValues.getValues(); + } + + ChangeType detect(double minBucketsPValue) { + // This was obtained by simulating the test power for a fixed size effect as a + // function of the bucket value count. 
+ double pValueThreshold = minBucketsPValue * Math.exp(-0.04 * (values.length - 2 * (ChangePointDetector.MINIMUM_BUCKETS + 1))); + return testForChange(pValueThreshold).changeType(bucketValues, slope(values)); + } + + private TestStats testForChange(double pValueThreshold) { + + int[] candidateChangePoints = computeCandidateChangePoints(values); + logger.trace("candidatePoints: [{}]", Arrays.toString(candidateChangePoints)); + + double[] valuesWeights = outlierWeights(values); + logger.trace("values: [{}]", Arrays.toString(values)); + logger.trace("valuesWeights: [{}]", Arrays.toString(valuesWeights)); + RunningStats dataRunningStats = RunningStats.from(values, i -> valuesWeights[i]); + DataStats dataStats = new DataStats( + dataRunningStats.count(), + dataRunningStats.mean(), + dataRunningStats.variance(), + candidateChangePoints.length + ); + logger.trace("dataStats: [{}]", dataStats); + TestStats stationary = new TestStats(Type.STATIONARY, 1.0, dataStats.var(), 1.0, dataStats); + + if (dataStats.varianceZeroToWorkingPrecision()) { + return stationary; + } + + TestStats trendVsStationary = testTrendVs(stationary, values, valuesWeights); + logger.trace("trend vs stationary: [{}]", trendVsStationary); + + TestStats best = stationary; + Set discoveredChangePoints = Sets.newHashSetWithExpectedSize(4); + if (trendVsStationary.accept(pValueThreshold)) { + // Check if there is a change in the trend. + TestStats trendChangeVsTrend = testTrendChangeVs(trendVsStationary, values, valuesWeights, candidateChangePoints); + discoveredChangePoints.add(trendChangeVsTrend.changePoint()); + logger.trace("trend change vs trend: [{}]", trendChangeVsTrend); + + if (trendChangeVsTrend.accept(pValueThreshold)) { + // Check if modeling a trend change adds much over modeling a step change. 
+ best = testVsStepChange(trendChangeVsTrend, values, valuesWeights, candidateChangePoints, pValueThreshold); + } else { + best = trendVsStationary; + } + + } else { + // Check if there is a step change. + TestStats stepChangeVsStationary = testStepChangeVs(stationary, values, valuesWeights, candidateChangePoints); + discoveredChangePoints.add(stepChangeVsStationary.changePoint()); + logger.trace("step change vs stationary: [{}]", stepChangeVsStationary); + + if (stepChangeVsStationary.accept(pValueThreshold)) { + // Check if modeling a trend change adds much over modeling a step change. + TestStats trendChangeVsStepChange = testTrendChangeVs(stepChangeVsStationary, values, valuesWeights, candidateChangePoints); + discoveredChangePoints.add(trendChangeVsStepChange.changePoint()); + logger.trace("trend change vs step change: [{}]", trendChangeVsStepChange); + if (trendChangeVsStepChange.accept(pValueThreshold)) { + best = trendChangeVsStepChange; + } else { + best = stepChangeVsStationary; + } + + } else { + // Check if there is a trend change. + TestStats trendChangeVsStationary = testTrendChangeVs(stationary, values, valuesWeights, candidateChangePoints); + discoveredChangePoints.add(trendChangeVsStationary.changePoint()); + logger.trace("trend change vs stationary: [{}]", trendChangeVsStationary); + if (trendChangeVsStationary.accept(pValueThreshold)) { + best = trendChangeVsStationary; + } + } + } + + logger.trace("best: [{}]", best.pValueVsStationary()); + + // We're not very confident in the change point, so check if a distribution change + // fits the data better.
+ if (best.pValueVsStationary() > 1e-5) { + TestStats distChange = testDistributionChange(dataStats, values, valuesWeights, candidateChangePoints, discoveredChangePoints); + logger.trace("distribution change: [{}]", distChange); + if (distChange.pValue() < Math.min(pValueThreshold, 0.1 * best.pValueVsStationary())) { + best = distChange; + } + } + + return best; + } + + private int[] computeCandidateChangePoints(double[] values) { + int minValues = Math.max((int) (0.1 * values.length + 0.5), ChangePointDetector.MINIMUM_BUCKETS); + if (values.length - 2 * minValues <= MAXIMUM_CANDIDATE_CHANGE_POINTS) { + return IntStream.range(minValues, values.length - minValues).toArray(); + } else { + int step = (int) Math.ceil((double) (values.length - 2 * minValues) / MAXIMUM_CANDIDATE_CHANGE_POINTS); + return IntStream.range(minValues, values.length - minValues).filter(i -> i % step == 0).toArray(); + } + } + + private double[] outlierWeights(double[] values) { + int i = (int) Math.ceil(0.025 * values.length); + double[] weights = Arrays.copyOf(values, values.length); + Arrays.sort(weights); + // We have to be careful here if we have a lot of duplicate values. To avoid marking + // runs of duplicates as outliers we define outliers to be the smallest (largest) + // value strictly less (greater) than the value at i (values.length - i - 1). This + // means if i lands in a run of duplicates the entire run will be marked as inliers. 
+ double a = weights[i]; + double b = weights[values.length - i - 1]; + for (int j = 0; j < values.length; j++) { + if (values[j] <= b && values[j] >= a) { + weights[j] = 1.0; + } else { + weights[j] = 0.01; + } + } + return weights; + } + + private double slope(double[] values) { + SimpleRegression regression = new SimpleRegression(); + for (int i = 0; i < values.length; i++) { + regression.addData(i, values[i]); + } + return regression.getSlope(); + } + + private static double independentTrialsPValue(double pValue, int nTrials) { + return pValue > 1e-10 ? 1.0 - Math.pow(1.0 - pValue, nTrials) : nTrials * pValue; + } + + private TestStats testTrendVs(TestStats H0, double[] values, double[] weights) { + LeastSquaresOnlineRegression allLeastSquares = new LeastSquaresOnlineRegression(2); + for (int i = 0; i < values.length; i++) { + allLeastSquares.add(i, values[i], weights[i]); + } + double vTrend = H0.dataStats().var() * (1.0 - allLeastSquares.rSquared()); + double pValue = fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vTrend, 3.0); + return new TestStats(Type.NON_STATIONARY, pValue, vTrend, 3.0, H0.dataStats()); + } + + private TestStats testStepChangeVs(TestStats H0, double[] values, double[] weights, int[] candidateChangePoints) { + + double vStep = Double.MAX_VALUE; + int changePoint = -1; + + // Initialize running stats so that they are only missing the individual changepoint values + RunningStats lowerRange = new RunningStats(); + RunningStats upperRange = new RunningStats(); + upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); + lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); + double mean = H0.dataStats().mean(); + int last = candidateChangePoints[0]; + for (int cp : candidateChangePoints) { + lowerRange.addValues(values, i -> weights[i], last, cp); + upperRange.removeValues(values, i -> weights[i], last, cp); + last = cp; + double nl = lowerRange.count(); + double nu = 
upperRange.count(); + double ml = lowerRange.mean(); + double mu = upperRange.mean(); + double vl = lowerRange.variance(); + double vu = upperRange.variance(); + double v = (nl * vl + nu * vu) / (nl + nu); + if (v < vStep) { + vStep = v; + changePoint = cp; + } + } + + double pValue = independentTrialsPValue( + fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vStep, 2.0), + candidateChangePoints.length + ); + + return new TestStats(Type.STEP_CHANGE, pValue, vStep, 2.0, changePoint, H0.dataStats()); + } + + private TestStats testTrendChangeVs(TestStats H0, double[] values, double[] weights, int[] candidateChangePoints) { + + double vChange = Double.MAX_VALUE; + int changePoint = -1; + + // Initialize running stats so that they are only missing the individual changepoint values + RunningStats lowerRange = new RunningStats(); + RunningStats upperRange = new RunningStats(); + lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); + upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); + LeastSquaresOnlineRegression lowerLeastSquares = new LeastSquaresOnlineRegression(2); + LeastSquaresOnlineRegression upperLeastSquares = new LeastSquaresOnlineRegression(2); + int first = candidateChangePoints[0]; + int last = candidateChangePoints[0]; + for (int i = 0; i < candidateChangePoints[0]; i++) { + lowerLeastSquares.add(i, values[i], weights[i]); + } + for (int i = candidateChangePoints[0]; i < values.length; i++) { + upperLeastSquares.add(i - first, values[i], weights[i]); + } + for (int cp : candidateChangePoints) { + for (int i = last; i < cp; i++) { + lowerRange.addValue(values[i], weights[i]); + upperRange.removeValue(values[i], weights[i]); + lowerLeastSquares.add(i, values[i], weights[i]); + upperLeastSquares.remove(i - first, values[i], weights[i]); + } + last = cp; + double nl = lowerRange.count(); + double nu = upperRange.count(); + double rl = lowerLeastSquares.rSquared(); + double ru = 
upperLeastSquares.rSquared(); + double vl = lowerRange.variance() * (1.0 - rl); + double vu = upperRange.variance() * (1.0 - ru); + double v = (nl * vl + nu * vu) / (nl + nu); + if (v < vChange) { + vChange = v; + changePoint = cp; + } + } + + double pValue = independentTrialsPValue( + fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vChange, 6.0), + candidateChangePoints.length + ); + + return new TestStats(Type.TREND_CHANGE, pValue, vChange, 6.0, changePoint, H0.dataStats()); + } + + private TestStats testVsStepChange( + TestStats trendChange, + double[] values, + double[] weights, + int[] candidateChangePoints, + double pValueThreshold + ) { + DataStats dataStats = trendChange.dataStats(); + TestStats stationary = new TestStats(Type.STATIONARY, 1.0, dataStats.var(), 1.0, dataStats); + TestStats stepChange = testStepChangeVs(stationary, values, weights, candidateChangePoints); + double n = dataStats.nValues(); + double pValue = fTestNestedPValue(n, stepChange.var(), 2.0, trendChange.var(), 6.0); + return pValue < pValueThreshold ? trendChange : stepChange; + } + + private static double fTestNestedPValue(double n, double vNull, double pNull, double vAlt, double pAlt) { + if (vAlt == vNull) { + return 1.0; + } + if (vAlt == 0.0) { + return 0.0; + } + double F = (vNull - vAlt) / (pAlt - pNull) * (n - pAlt) / vAlt; + double sf = fDistribSf(pAlt - pNull, n - pAlt, F); + return Math.min(2 * sf, 1.0); + } + + private static int lowerBound(int[] x, int start, int end, int xs) { + int retVal = Arrays.binarySearch(x, start, end, xs); + if (retVal < 0) { + retVal = -1 - retVal; + } + return retVal; + } + + private SampleData sample(double[] values, double[] weights, Set changePoints) { + Integer[] adjChangePoints = changePoints.toArray(new Integer[changePoints.size()]); + if (values.length <= MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST) { + return new SampleData(values, weights, adjChangePoints); + } + + // Just want repeatable random numbers. 
+ Random rng = new Random(126832678); + UniformRealDistribution uniform = new UniformRealDistribution(RandomGeneratorFactory.createRandomGenerator(rng), 0.0, 0.99999); + + // Fisher–Yates shuffle (why isn't this in Arrays?). + int[] choice = IntStream.range(0, values.length).toArray(); + for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) { + int index = i + (int) Math.floor(uniform.sample() * (values.length - i)); + int tmp = choice[i]; + choice[i] = choice[index]; + choice[index] = tmp; + } + + double[] sample = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST]; + double[] sampleWeights = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST]; + Arrays.sort(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST); + for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) { + sample[i] = values[choice[i]]; + sampleWeights[i] = weights[choice[i]]; + } + for (int i = 0; i < adjChangePoints.length; ++i) { + adjChangePoints[i] = lowerBound(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST, adjChangePoints[i].intValue()); + } + + return new SampleData(sample, sampleWeights, adjChangePoints); + } + + private TestStats testDistributionChange( + DataStats stats, + double[] values, + double[] weights, + int[] candidateChangePoints, + Set discoveredChangePoints + ) { + + double maxDiff = 0.0; + int changePoint = -1; + + // Initialize running stats so that they are only missing the individual changepoint values + RunningStats lowerRange = new RunningStats(); + RunningStats upperRange = new RunningStats(); + upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); + lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); + int last = candidateChangePoints[0]; + for (int cp : candidateChangePoints) { + lowerRange.addValues(values, i -> weights[i], last, cp); + upperRange.removeValues(values, i -> weights[i], last, cp); + last = cp; + double scale = Math.min(cp, values.length - cp); + double meanDiff = Math.abs(lowerRange.mean() - upperRange.mean()); + 
double stdDiff = Math.abs(lowerRange.std() - upperRange.std()); + double diff = scale * (meanDiff + stdDiff); + if (diff >= maxDiff) { + maxDiff = diff; + changePoint = cp; + } + } + discoveredChangePoints.add(changePoint); + + // Note that statistical tests become increasingly powerful as the number of samples + // increases. We are not interested in detecting visually small distribution changes + // in splits of long windows so we randomly downsample the data if it is too large + // before we run the tests. + SampleData sampleData = sample(values, weights, discoveredChangePoints); + final double[] sampleValues = sampleData.values(); + final double[] sampleWeights = sampleData.weights(); + + double pValue = 1; + for (int cp : sampleData.changePoints()) { + double[] x = Arrays.copyOfRange(sampleValues, 0, cp); + double[] y = Arrays.copyOfRange(sampleValues, cp, sampleValues.length); + double statistic = KOLMOGOROV_SMIRNOV_TEST.kolmogorovSmirnovStatistic(x, y); + double ksTestPValue = KOLMOGOROV_SMIRNOV_TEST.exactP(statistic, x.length, y.length, false); + if (ksTestPValue < pValue) { + changePoint = cp; + pValue = ksTestPValue; + } + } + + // We start to get false positives if we have too many candidate change points. This + // is the classic p-value hacking problem. However, the Sidak style correction we use + // elsewhere is too conservative because test statistics for different split positions + // are strongly correlated. We assume that we have some effective number of independent + // trials equal to f * n for f < 1. Simulation shows the f = 1/50 yields low Type I + // error rates. 
+ pValue = independentTrialsPValue(pValue, (sampleValues.length + 49) / 50); + logger.trace("distribution change p-value: [{}]", pValue); + + return new TestStats(Type.DISTRIBUTION_CHANGE, pValue, changePoint, stats); + } + + private static double fDistribSf(double numeratorDegreesOfFreedom, double denominatorDegreesOfFreedom, double x) { + if (x <= 0) { + return 1; + } + if (Double.isInfinite(x) || Double.isNaN(x)) { + return 0; + } + + return Beta.regularizedBeta( + denominatorDegreesOfFreedom / (denominatorDegreesOfFreedom + numeratorDegreesOfFreedom * x), + 0.5 * denominatorDegreesOfFreedom, + 0.5 * numeratorDegreesOfFreedom + ); + } + + private enum Type { + STATIONARY, + NON_STATIONARY, + STEP_CHANGE, + TREND_CHANGE, + DISTRIBUTION_CHANGE + } + + private record SampleData(double[] values, double[] weights, Integer[] changePoints) {} + + private record DataStats(double nValues, double mean, double var, int nCandidateChangePoints) { + boolean varianceZeroToWorkingPrecision() { + // Our variance calculation is only accurate to ulp(length * mean)^(1/2), + // i.e. we compute it using the difference of squares method and don't use + // the Kahan correction. We treat anything as zero to working precision as + // zero. We should at some point switch to a more numerically stable approach + // for computing data statistics. 
+ return var < Math.sqrt(Math.ulp(2.0 * nValues * mean)); + } + + @Override + public String toString() { + return "DataStats{nValues=" + nValues + ", mean=" + mean + ", var=" + var + ", nCandidates=" + nCandidateChangePoints + "}"; + } + } + + private record TestStats(Type type, double pValue, double var, double nParams, int changePoint, DataStats dataStats) { + TestStats(Type type, double pValue, int changePoint, DataStats dataStats) { + this(type, pValue, 0.0, 0.0, changePoint, dataStats); + } + + TestStats(Type type, double pValue, double var, double nParams, DataStats dataStats) { + this(type, pValue, var, nParams, -1, dataStats); + } + + boolean accept(double pValueThreshold) { + // Check the change is: + // 1. Statistically significant. + // 2. That we explain enough of the data variance overall. + return pValue < pValueThreshold && rSquared() >= 0.5; + } + + double rSquared() { + return 1.0 - var / dataStats.var(); + } + + double pValueVsStationary() { + return independentTrialsPValue( + fTestNestedPValue(dataStats.nValues(), dataStats.var(), 1.0, var, nParams), + dataStats.nCandidateChangePoints() + ); + } + + ChangeType changeType(MlAggsHelper.DoubleBucketValues bucketValues, double slope) { + switch (type) { + case STATIONARY: + return new ChangeType.Stationary(); + case NON_STATIONARY: + return new ChangeType.NonStationary(pValueVsStationary(), rSquared(), slope < 0.0 ? 
"decreasing" : "increasing"); + case STEP_CHANGE: + return new ChangeType.StepChange(pValueVsStationary(), bucketValues.getBucketIndex(changePoint)); + case TREND_CHANGE: + return new ChangeType.TrendChange(pValueVsStationary(), rSquared(), bucketValues.getBucketIndex(changePoint)); + case DISTRIBUTION_CHANGE: + return new ChangeType.DistributionChange(pValue, bucketValues.getBucketIndex(changePoint)); + } + throw new RuntimeException("Unknown change type [" + type + "]."); + } + + @Override + public String toString() { + return "TestStats{" + + ("type=" + type) + + (", dataStats=" + dataStats) + + (", var=" + var) + + (", rSquared=" + rSquared()) + + (", pValue=" + pValue) + + (", nParams=" + nParams) + + (", changePoint=" + changePoint) + + '}'; + } + } + + private static class RunningStats { + double sumOfSqrs; + double sum; + double count; + + static RunningStats from(double[] values, IntToDoubleFunction weightFunction) { + return new RunningStats().addValues(values, weightFunction, 0, values.length); + } + + RunningStats() {} + + double count() { + return count; + } + + double mean() { + return sum / count; + } + + double variance() { + return Math.max((sumOfSqrs - ((sum * sum) / count)) / count, 0.0); + } + + double std() { + return Math.sqrt(variance()); + } + + RunningStats addValues(double[] value, IntToDoubleFunction weightFunction, int start, int end) { + for (int i = start; i < value.length && i < end; i++) { + addValue(value[i], weightFunction.applyAsDouble(i)); + } + return this; + } + + RunningStats addValue(double value, double weight) { + sumOfSqrs += (value * value * weight); + count += weight; + sum += (value * weight); + return this; + } + + RunningStats removeValue(double value, double weight) { + sumOfSqrs = Math.max(sumOfSqrs - value * value * weight, 0); + count = Math.max(count - weight, 0); + sum -= (value * weight); + return this; + } + + RunningStats removeValues(double[] value, IntToDoubleFunction weightFunction, int start, int end) { + 
for (int i = start; i < value.length && i < end; i++) { + removeValue(value[i], weightFunction.applyAsDouble(i)); + } + return this; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregator.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregator.java index faef29ff65070..d643a937180a1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregator.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregator.java @@ -7,86 +7,21 @@ package org.elasticsearch.xpack.ml.aggs.changepoint; -import org.apache.commons.math3.distribution.UniformRealDistribution; -import org.apache.commons.math3.exception.NotStrictlyPositiveException; -import org.apache.commons.math3.random.RandomGeneratorFactory; -import org.apache.commons.math3.special.Beta; -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; -import org.apache.commons.math3.stat.regression.SimpleRegression; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; import org.elasticsearch.search.aggregations.pipeline.BucketHelpers; import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator; import org.elasticsearch.xpack.ml.aggs.MlAggsHelper; -import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType.Indeterminable; -import java.util.Arrays; import java.util.Map; import java.util.Optional; -import java.util.Random; -import java.util.Set; -import java.util.function.IntToDoubleFunction; -import java.util.stream.IntStream; import static org.elasticsearch.xpack.ml.aggs.MlAggsHelper.extractBucket; import 
static org.elasticsearch.xpack.ml.aggs.MlAggsHelper.extractDoubleBucketedValues; public class ChangePointAggregator extends SiblingPipelineAggregator { - private static final Logger logger = LogManager.getLogger(ChangePointAggregator.class); - - static final double P_VALUE_THRESHOLD = 0.01; - private static final int MINIMUM_BUCKETS = 10; - private static final int MAXIMUM_CANDIDATE_CHANGE_POINTS = 1000; - private static final int MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST = 500; - private static final KolmogorovSmirnovTest KOLMOGOROV_SMIRNOV_TEST = new KolmogorovSmirnovTest(); - - private static double changePValueThreshold(int nValues) { - // This was obtained by simulating the test power for a fixed size effect as a - // function of the bucket value count. - return P_VALUE_THRESHOLD * Math.exp(-0.04 * (double) (nValues - 2 * (MINIMUM_BUCKETS + 1))); - } - - private static int lowerBound(int[] x, int start, int end, int xs) { - int retVal = Arrays.binarySearch(x, start, end, xs); - if (retVal < 0) { - retVal = -1 - retVal; - } - return retVal; - } - - private record SampleData(double[] values, double[] weights, Integer[] changePoints) {} - - private record DataStats(double nValues, double mean, double var, int nCandidateChangePoints) { - boolean varianceZeroToWorkingPrecision() { - // Our variance calculation is only accurate to ulp(length * mean)^(1/2), - // i.e. we compute it using the difference of squares method and don't use - // the Kahan correction. We treat anything as zero to working precision as - // zero. We should at some point switch to a more numerically stable approach - // for computing data statistics. 
- return var < Math.sqrt(Math.ulp(2.0 * nValues * mean)); - } - - @Override - public String toString() { - return "DataStats{nValues=" + nValues + ", mean=" + mean + ", var=" + var + ", nCandidates=" + nCandidateChangePoints + "}"; - } - } - - static int[] computeCandidateChangePoints(double[] values) { - int minValues = Math.max((int) (0.1 * values.length + 0.5), MINIMUM_BUCKETS); - if (values.length - 2 * minValues <= MAXIMUM_CANDIDATE_CHANGE_POINTS) { - return IntStream.range(minValues, values.length - minValues).toArray(); - } else { - int step = (int) Math.ceil((double) (values.length - 2 * minValues) / MAXIMUM_CANDIDATE_CHANGE_POINTS); - return IntStream.range(minValues, values.length - minValues).filter(i -> i % step == 0).toArray(); - } - } - public ChangePointAggregator(String name, String bucketsPath, Map metadata) { super(name, new String[] { bucketsPath }, metadata); } @@ -108,30 +43,8 @@ public InternalAggregation doReduce(InternalAggregations aggregations, Aggregati ); } MlAggsHelper.DoubleBucketValues bucketValues = maybeBucketValues.get(); - if (bucketValues.getValues().length < (2 * MINIMUM_BUCKETS) + 2) { - return new InternalChangePointAggregation( - name(), - metadata(), - null, - new ChangeType.Indeterminable( - "not enough buckets to calculate change_point. Requires at least [" - + ((2 * MINIMUM_BUCKETS) + 2) - + "]; found [" - + bucketValues.getValues().length - + "]" - ) - ); - } - - ChangeType spikeOrDip = testForSpikeOrDip(bucketValues, P_VALUE_THRESHOLD); - - // Test for change step, trend and distribution changes. 
- ChangeType change = testForChange(bucketValues, changePValueThreshold(bucketValues.getValues().length)); - logger.trace("change p-value: [{}]", change.pValue()); - if (spikeOrDip.pValue() < change.pValue()) { - change = spikeOrDip; - } + ChangeType change = ChangePointDetector.getChangeType(bucketValues); ChangePointBucket changePointBucket = null; if (change.changePoint() >= 0) { @@ -142,503 +55,4 @@ public InternalAggregation doReduce(InternalAggregations aggregations, Aggregati return new InternalChangePointAggregation(name(), metadata(), changePointBucket, change); } - - static ChangeType testForSpikeOrDip(MlAggsHelper.DoubleBucketValues bucketValues, double pValueThreshold) { - try { - SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues.getValues()); - ChangeType result = detect.at(pValueThreshold, bucketValues); - logger.trace("spike or dip p-value: [{}]", result.pValue()); - return result; - } catch (NotStrictlyPositiveException nspe) { - logger.debug("failure testing for dips and spikes", nspe); - } - return new Indeterminable("failure testing for dips and spikes"); - } - - static ChangeType testForChange(MlAggsHelper.DoubleBucketValues bucketValues, double pValueThreshold) { - double[] timeWindow = bucketValues.getValues(); - return testForChange(timeWindow, pValueThreshold).changeType(bucketValues, slope(timeWindow)); - } - - static TestStats testForChange(double[] timeWindow, double pValueThreshold) { - - int[] candidateChangePoints = computeCandidateChangePoints(timeWindow); - logger.trace("candidatePoints: [{}]", Arrays.toString(candidateChangePoints)); - - double[] timeWindowWeights = outlierWeights(timeWindow); - logger.trace("timeWindow: [{}]", Arrays.toString(timeWindow)); - logger.trace("timeWindowWeights: [{}]", Arrays.toString(timeWindowWeights)); - RunningStats dataRunningStats = RunningStats.from(timeWindow, i -> timeWindowWeights[i]); - DataStats dataStats = new DataStats( - dataRunningStats.count(), - dataRunningStats.mean(), 
- dataRunningStats.variance(), - candidateChangePoints.length - ); - logger.trace("dataStats: [{}]", dataStats); - TestStats stationary = new TestStats(Type.STATIONARY, 1.0, dataStats.var(), 1.0, dataStats); - - if (dataStats.varianceZeroToWorkingPrecision()) { - return stationary; - } - - TestStats trendVsStationary = testTrendVs(stationary, timeWindow, timeWindowWeights); - logger.trace("trend vs stationary: [{}]", trendVsStationary); - - TestStats best = stationary; - Set discoveredChangePoints = Sets.newHashSetWithExpectedSize(4); - if (trendVsStationary.accept(pValueThreshold)) { - // Check if there is a change in the trend. - TestStats trendChangeVsTrend = testTrendChangeVs(trendVsStationary, timeWindow, timeWindowWeights, candidateChangePoints); - discoveredChangePoints.add(trendChangeVsTrend.changePoint()); - logger.trace("trend change vs trend: [{}]", trendChangeVsTrend); - - if (trendChangeVsTrend.accept(pValueThreshold)) { - // Check if modeling a trend change adds much over modeling a step change. - best = testVsStepChange(trendChangeVsTrend, timeWindow, timeWindowWeights, candidateChangePoints, pValueThreshold); - } else { - best = trendVsStationary; - } - - } else { - // Check if there is a step change. - TestStats stepChangeVsStationary = testStepChangeVs(stationary, timeWindow, timeWindowWeights, candidateChangePoints); - discoveredChangePoints.add(stepChangeVsStationary.changePoint()); - logger.trace("step change vs stationary: [{}]", stepChangeVsStationary); - - if (stepChangeVsStationary.accept(pValueThreshold)) { - // Check if modeling a trend change adds much over modeling a step change. 
- TestStats trendChangeVsStepChange = testTrendChangeVs( - stepChangeVsStationary, - timeWindow, - timeWindowWeights, - candidateChangePoints - ); - discoveredChangePoints.add(stepChangeVsStationary.changePoint()); - logger.trace("trend change vs step change: [{}]", trendChangeVsStepChange); - if (trendChangeVsStepChange.accept(pValueThreshold)) { - best = trendChangeVsStepChange; - } else { - best = stepChangeVsStationary; - } - - } else { - // Check if there is a trend change. - TestStats trendChangeVsStationary = testTrendChangeVs(stationary, timeWindow, timeWindowWeights, candidateChangePoints); - discoveredChangePoints.add(stepChangeVsStationary.changePoint()); - logger.trace("trend change vs stationary: [{}]", trendChangeVsStationary); - if (trendChangeVsStationary.accept(pValueThreshold)) { - best = trendChangeVsStationary; - } - } - } - - logger.trace("best: [{}]", best.pValueVsStationary()); - - // We're not very confident in the change point, so check if a distribution change - // fits the data better. - if (best.pValueVsStationary() > 1e-5) { - TestStats distChange = testDistributionChange( - dataStats, - timeWindow, - timeWindowWeights, - candidateChangePoints, - discoveredChangePoints - ); - logger.trace("distribution change: [{}]", distChange); - if (distChange.pValue() < Math.min(pValueThreshold, 0.1 * best.pValueVsStationary())) { - best = distChange; - } - } - - return best; - } - - static double[] outlierWeights(double[] values) { - int i = (int) Math.ceil(0.025 * values.length); - double[] weights = Arrays.copyOf(values, values.length); - Arrays.sort(weights); - // We have to be careful here if we have a lot of duplicate values. To avoid marking - // runs of duplicates as outliers we define outliers to be the smallest (largest) - // value strictly less (greater) than the value at i (values.length - i - 1). This - // means if i lands in a run of duplicates the entire run will be marked as inliers. 
- double a = weights[i]; - double b = weights[values.length - i - 1]; - for (int j = 0; j < values.length; j++) { - if (values[j] <= b && values[j] >= a) { - weights[j] = 1.0; - } else { - weights[j] = 0.01; - } - } - return weights; - } - - static double slope(double[] values) { - SimpleRegression regression = new SimpleRegression(); - for (int i = 0; i < values.length; i++) { - regression.addData(i, values[i]); - } - return regression.getSlope(); - } - - static double independentTrialsPValue(double pValue, int nTrials) { - return pValue > 1e-10 ? 1.0 - Math.pow(1.0 - pValue, nTrials) : nTrials * pValue; - } - - static TestStats testTrendVs(TestStats H0, double[] values, double[] weights) { - LeastSquaresOnlineRegression allLeastSquares = new LeastSquaresOnlineRegression(2); - for (int i = 0; i < values.length; i++) { - allLeastSquares.add(i, values[i], weights[i]); - } - double vTrend = H0.dataStats().var() * (1.0 - allLeastSquares.rSquared()); - double pValue = fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vTrend, 3.0); - return new TestStats(Type.NON_STATIONARY, pValue, vTrend, 3.0, H0.dataStats()); - } - - static TestStats testStepChangeVs(TestStats H0, double[] values, double[] weights, int[] candidateChangePoints) { - - double vStep = Double.MAX_VALUE; - int changePoint = -1; - - // Initialize running stats so that they are only missing the individual changepoint values - RunningStats lowerRange = new RunningStats(); - RunningStats upperRange = new RunningStats(); - upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); - lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); - double mean = H0.dataStats().mean(); - int last = candidateChangePoints[0]; - for (int cp : candidateChangePoints) { - lowerRange.addValues(values, i -> weights[i], last, cp); - upperRange.removeValues(values, i -> weights[i], last, cp); - last = cp; - double nl = lowerRange.count(); - double nu = 
upperRange.count(); - double ml = lowerRange.mean(); - double mu = upperRange.mean(); - double vl = lowerRange.variance(); - double vu = upperRange.variance(); - double v = (nl * vl + nu * vu) / (nl + nu); - if (v < vStep) { - vStep = v; - changePoint = cp; - } - } - - double pValue = independentTrialsPValue( - fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vStep, 2.0), - candidateChangePoints.length - ); - - return new TestStats(Type.STEP_CHANGE, pValue, vStep, 2.0, changePoint, H0.dataStats()); - } - - static TestStats testTrendChangeVs(TestStats H0, double[] values, double[] weights, int[] candidateChangePoints) { - - double vChange = Double.MAX_VALUE; - int changePoint = -1; - - // Initialize running stats so that they are only missing the individual changepoint values - RunningStats lowerRange = new RunningStats(); - RunningStats upperRange = new RunningStats(); - lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); - upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); - LeastSquaresOnlineRegression lowerLeastSquares = new LeastSquaresOnlineRegression(2); - LeastSquaresOnlineRegression upperLeastSquares = new LeastSquaresOnlineRegression(2); - int first = candidateChangePoints[0]; - int last = candidateChangePoints[0]; - for (int i = 0; i < candidateChangePoints[0]; i++) { - lowerLeastSquares.add(i, values[i], weights[i]); - } - for (int i = candidateChangePoints[0]; i < values.length; i++) { - upperLeastSquares.add(i - first, values[i], weights[i]); - } - for (int cp : candidateChangePoints) { - for (int i = last; i < cp; i++) { - lowerRange.addValue(values[i], weights[i]); - upperRange.removeValue(values[i], weights[i]); - lowerLeastSquares.add(i, values[i], weights[i]); - upperLeastSquares.remove(i - first, values[i], weights[i]); - } - last = cp; - double nl = lowerRange.count(); - double nu = upperRange.count(); - double rl = lowerLeastSquares.rSquared(); - double ru = 
upperLeastSquares.rSquared(); - double vl = lowerRange.variance() * (1.0 - rl); - double vu = upperRange.variance() * (1.0 - ru); - double v = (nl * vl + nu * vu) / (nl + nu); - if (v < vChange) { - vChange = v; - changePoint = cp; - } - } - - double pValue = independentTrialsPValue( - fTestNestedPValue(H0.dataStats().nValues(), H0.var(), H0.nParams(), vChange, 6.0), - candidateChangePoints.length - ); - - return new TestStats(Type.TREND_CHANGE, pValue, vChange, 6.0, changePoint, H0.dataStats()); - } - - static TestStats testVsStepChange( - TestStats trendChange, - double[] values, - double[] weights, - int[] candidateChangePoints, - double pValueThreshold - ) { - DataStats dataStats = trendChange.dataStats(); - TestStats stationary = new TestStats(Type.STATIONARY, 1.0, dataStats.var(), 1.0, dataStats); - TestStats stepChange = testStepChangeVs(stationary, values, weights, candidateChangePoints); - double n = dataStats.nValues(); - double pValue = fTestNestedPValue(n, stepChange.var(), 2.0, trendChange.var(), 6.0); - return pValue < pValueThreshold ? trendChange : stepChange; - } - - static double fTestNestedPValue(double n, double vNull, double pNull, double vAlt, double pAlt) { - if (vAlt == vNull) { - return 1.0; - } - if (vAlt == 0.0) { - return 0.0; - } - double F = (vNull - vAlt) / (pAlt - pNull) * (n - pAlt) / vAlt; - double sf = fDistribSf(pAlt - pNull, n - pAlt, F); - return Math.min(2 * sf, 1.0); - } - - static SampleData sample(double[] values, double[] weights, Set changePoints) { - Integer[] adjChangePoints = changePoints.toArray(new Integer[changePoints.size()]); - if (values.length <= MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST) { - return new SampleData(values, weights, adjChangePoints); - } - - // Just want repeatable random numbers. - Random rng = new Random(126832678); - UniformRealDistribution uniform = new UniformRealDistribution(RandomGeneratorFactory.createRandomGenerator(rng), 0.0, 0.99999); - - // Fisher–Yates shuffle (why isn't this in Arrays?). 
- int[] choice = IntStream.range(0, values.length).toArray(); - for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) { - int index = i + (int) Math.floor(uniform.sample() * (values.length - i)); - int tmp = choice[i]; - choice[i] = choice[index]; - choice[index] = tmp; - } - - double[] sample = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST]; - double[] sampleWeights = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST]; - Arrays.sort(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST); - for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) { - sample[i] = values[choice[i]]; - sampleWeights[i] = weights[choice[i]]; - } - for (int i = 0; i < adjChangePoints.length; ++i) { - adjChangePoints[i] = lowerBound(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST, adjChangePoints[i].intValue()); - } - - return new SampleData(sample, sampleWeights, adjChangePoints); - } - - static TestStats testDistributionChange( - DataStats stats, - double[] values, - double[] weights, - int[] candidateChangePoints, - Set discoveredChangePoints - ) { - - double maxDiff = 0.0; - int changePoint = -1; - - // Initialize running stats so that they are only missing the individual changepoint values - RunningStats lowerRange = new RunningStats(); - RunningStats upperRange = new RunningStats(); - upperRange.addValues(values, i -> weights[i], candidateChangePoints[0], values.length); - lowerRange.addValues(values, i -> weights[i], 0, candidateChangePoints[0]); - int last = candidateChangePoints[0]; - for (int cp : candidateChangePoints) { - lowerRange.addValues(values, i -> weights[i], last, cp); - upperRange.removeValues(values, i -> weights[i], last, cp); - last = cp; - double scale = Math.min(cp, values.length - cp); - double meanDiff = Math.abs(lowerRange.mean() - upperRange.mean()); - double stdDiff = Math.abs(lowerRange.std() - upperRange.std()); - double diff = scale * (meanDiff + stdDiff); - if (diff >= maxDiff) { - maxDiff = diff; - changePoint = cp; - } - } - discoveredChangePoints.add(changePoint); - - 
// Note that statistical tests become increasingly powerful as the number of samples - // increases. We are not interested in detecting visually small distribution changes - // in splits of long windows so we randomly downsample the data if it is too large - // before we run the tests. - SampleData sampleData = sample(values, weights, discoveredChangePoints); - final double[] sampleValues = sampleData.values(); - final double[] sampleWeights = sampleData.weights(); - - double pValue = 1; - for (int cp : sampleData.changePoints()) { - double[] x = Arrays.copyOfRange(sampleValues, 0, cp); - double[] y = Arrays.copyOfRange(sampleValues, cp, sampleValues.length); - double statistic = KOLMOGOROV_SMIRNOV_TEST.kolmogorovSmirnovStatistic(x, y); - double ksTestPValue = KOLMOGOROV_SMIRNOV_TEST.exactP(statistic, x.length, y.length, false); - if (ksTestPValue < pValue) { - changePoint = cp; - pValue = ksTestPValue; - } - } - - // We start to get false positives if we have too many candidate change points. This - // is the classic p-value hacking problem. However, the Sidak style correction we use - // elsewhere is too conservative because test statistics for different split positions - // are strongly correlated. We assume that we have some effective number of independent - // trials equal to f * n for f < 1. Simulation shows the f = 1/50 yields low Type I - // error rates. 
- pValue = independentTrialsPValue(pValue, (sampleValues.length + 49) / 50); - logger.trace("distribution change p-value: [{}]", pValue); - - return new TestStats(Type.DISTRIBUTION_CHANGE, pValue, changePoint, stats); - } - - enum Type { - STATIONARY, - NON_STATIONARY, - STEP_CHANGE, - TREND_CHANGE, - DISTRIBUTION_CHANGE - } - - record TestStats(Type type, double pValue, double var, double nParams, int changePoint, DataStats dataStats) { - TestStats(Type type, double pValue, int changePoint, DataStats dataStats) { - this(type, pValue, 0.0, 0.0, changePoint, dataStats); - } - - TestStats(Type type, double pValue, double var, double nParams, DataStats dataStats) { - this(type, pValue, var, nParams, -1, dataStats); - } - - boolean accept(double pValueThreshold) { - // Check the change is: - // 1. Statistically significant. - // 2. That we explain enough of the data variance overall. - return pValue < pValueThreshold && rSquared() >= 0.5; - } - - double rSquared() { - return 1.0 - var / dataStats.var(); - } - - double pValueVsStationary() { - return independentTrialsPValue( - fTestNestedPValue(dataStats.nValues(), dataStats.var(), 1.0, var, nParams), - dataStats.nCandidateChangePoints() - ); - } - - ChangeType changeType(MlAggsHelper.DoubleBucketValues bucketValues, double slope) { - switch (type) { - case STATIONARY: - return new ChangeType.Stationary(); - case NON_STATIONARY: - return new ChangeType.NonStationary(pValueVsStationary(), rSquared(), slope < 0.0 ? 
"decreasing" : "increasing"); - case STEP_CHANGE: - return new ChangeType.StepChange(pValueVsStationary(), bucketValues.getBucketIndex(changePoint)); - case TREND_CHANGE: - return new ChangeType.TrendChange(pValueVsStationary(), rSquared(), bucketValues.getBucketIndex(changePoint)); - case DISTRIBUTION_CHANGE: - return new ChangeType.DistributionChange(pValue, bucketValues.getBucketIndex(changePoint)); - } - throw new RuntimeException("Unknown change type [" + type + "]."); - } - - @Override - public String toString() { - return "TestStats{" - + ("type=" + type) - + (", dataStats=" + dataStats) - + (", var=" + var) - + (", rSquared=" + rSquared()) - + (", pValue=" + pValue) - + (", nParams=" + nParams) - + (", changePoint=" + changePoint) - + '}'; - } - } - - static class RunningStats { - double sumOfSqrs; - double sum; - double count; - - static RunningStats from(double[] values, IntToDoubleFunction weightFunction) { - return new RunningStats().addValues(values, weightFunction, 0, values.length); - } - - RunningStats() {} - - double count() { - return count; - } - - double mean() { - return sum / count; - } - - double variance() { - return Math.max((sumOfSqrs - ((sum * sum) / count)) / count, 0.0); - } - - double std() { - return Math.sqrt(variance()); - } - - RunningStats addValues(double[] value, IntToDoubleFunction weightFunction, int start, int end) { - for (int i = start; i < value.length && i < end; i++) { - addValue(value[i], weightFunction.applyAsDouble(i)); - } - return this; - } - - RunningStats addValue(double value, double weight) { - sumOfSqrs += (value * value * weight); - count += weight; - sum += (value * weight); - return this; - } - - RunningStats removeValue(double value, double weight) { - sumOfSqrs = Math.max(sumOfSqrs - value * value * weight, 0); - count = Math.max(count - weight, 0); - sum -= (value * weight); - return this; - } - - RunningStats removeValues(double[] value, IntToDoubleFunction weightFunction, int start, int end) { - for 
(int i = start; i < value.length && i < end; i++) { - removeValue(value[i], weightFunction.applyAsDouble(i)); - } - return this; - } - } - - static double fDistribSf(double numeratorDegreesOfFreedom, double denominatorDegreesOfFreedom, double x) { - if (x <= 0) { - return 1; - } - if (Double.isInfinite(x) || Double.isNaN(x)) { - return 0; - } - - return Beta.regularizedBeta( - denominatorDegreesOfFreedom / (denominatorDegreesOfFreedom + numeratorDegreesOfFreedom * x), - 0.5 * denominatorDegreesOfFreedom, - 0.5 * numeratorDegreesOfFreedom - ); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointDetector.java new file mode 100644 index 0000000000000..d7708420994bb --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointDetector.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.aggs.changepoint; + +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.ml.aggs.MlAggsHelper; + +/** + * Detects whether a series of values has a change point, by running both + * ChangeDetector and SpikeAndDipDetector on it. This is the main entrypoint + * of change point detection. + */ +public class ChangePointDetector { + + private static final Logger logger = LogManager.getLogger(ChangePointDetector.class); + + static final double P_VALUE_THRESHOLD = 0.01; + static final int MINIMUM_BUCKETS = 10; + + /** + * Returns the ChangeType of a series of values. 
+ */ + public static ChangeType getChangeType(MlAggsHelper.DoubleBucketValues bucketValues) { + if (bucketValues.getValues().length < (2 * MINIMUM_BUCKETS) + 2) { + return new ChangeType.Indeterminable( + "not enough buckets to calculate change_point. Requires at least [" + + ((2 * MINIMUM_BUCKETS) + 2) + + "]; found [" + + bucketValues.getValues().length + + "]" + ); + } + + ChangeType spikeOrDip; + try { + SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues); + spikeOrDip = detect.detect(P_VALUE_THRESHOLD); + logger.trace("spike or dip p-value: [{}]", spikeOrDip.pValue()); + } catch (NotStrictlyPositiveException nspe) { + logger.debug("failure testing for dips and spikes", nspe); + spikeOrDip = new ChangeType.Indeterminable("failure testing for dips and spikes"); + } + + ChangeType change = new ChangeDetector(bucketValues).detect(P_VALUE_THRESHOLD); + logger.trace("change p-value: [{}]", change.pValue()); + + if (spikeOrDip.pValue() < change.pValue()) { + change = spikeOrDip; + } + return change; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetector.java index b628ea3324cf1..365ebe8562d6a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetector.java @@ -92,6 +92,7 @@ private double[] removeIf(ExcludedPredicate should, double[] values) { return newValues; } + private final MlAggsHelper.DoubleBucketValues bucketValues; private final int numValues; private final int dipIndex; private final int spikeIndex; @@ -100,7 +101,9 @@ private double[] removeIf(ExcludedPredicate should, double[] values) { private final KDE spikeTestKDE; private final KDE dipTestKDE; - SpikeAndDipDetector(double[] values) { + 
SpikeAndDipDetector(MlAggsHelper.DoubleBucketValues bucketValues) { + this.bucketValues = bucketValues; + double[] values = bucketValues.getValues(); numValues = values.length; @@ -135,7 +138,7 @@ private double[] removeIf(ExcludedPredicate should, double[] values) { spikeTestKDE = new KDE(spikeKDEValues, 1.36); } - ChangeType at(double pValueThreshold, MlAggsHelper.DoubleBucketValues bucketValues) { + ChangeType detect(double pValueThreshold) { if (dipIndex == -1 || spikeIndex == -1) { return new ChangeType.Indeterminable( "not enough buckets to check for dip or spike. Requires at least [3]; found [" + numValues + "]" diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetectorTests.java new file mode 100644 index 0000000000000..75f668a96e77e --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangeDetectorTests.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.aggs.changepoint; + +import org.apache.commons.math3.distribution.GammaDistribution; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.random.RandomGeneratorFactory; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.xpack.ml.aggs.MlAggsHelper; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.DoubleStream; + +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThan; + +public class ChangeDetectorTests extends AggregatorTestCase { + + public void testStationaryFalsePositiveRate() { + NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); + int fp = 0; + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.generate(() -> 10 + normal.sample()).limit(40).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(1e-4); + fp += type instanceof ChangeType.Stationary ? 0 : 1; + } + assertThat(fp, lessThan(10)); + + fp = 0; + GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.generate(gamma::sample).limit(40).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(1e-4); + fp += type instanceof ChangeType.Stationary ? 
0 : 1; + } + assertThat(fp, lessThan(10)); + } + + public void testSampledDistributionTestFalsePositiveRate() { + NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 1.0); + int fp = 0; + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.generate(() -> 10 + normal.sample()).limit(5000).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(1e-4); + fp += type instanceof ChangeType.Stationary ? 0 : 1; + } + assertThat(fp, lessThan(10)); + } + + public void testNonStationaryFalsePositiveRate() { + NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); + int fp = 0; + for (int i = 0; i < 100; i++) { + AtomicInteger j = new AtomicInteger(); + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.generate(() -> j.incrementAndGet() + normal.sample()).limit(40).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(1e-4); + fp += type instanceof ChangeType.NonStationary ? 0 : 1; + } + assertThat(fp, lessThan(10)); + + fp = 0; + GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); + for (int i = 0; i < 100; i++) { + AtomicInteger j = new AtomicInteger(); + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.generate(() -> j.incrementAndGet() + gamma.sample()).limit(40).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(1e-4); + fp += type instanceof ChangeType.NonStationary ? 
0 : 1; + } + assertThat(fp, lessThan(10)); + } + + public void testStepChangePower() { + NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); + int tp = 0; + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.generate(() -> normal.sample()).limit(20), + DoubleStream.generate(() -> 10 + normal.sample()).limit(20) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.StepChange ? 1 : 0; + } + assertThat(tp, greaterThan(80)); + + tp = 0; + GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.generate(() -> gamma.sample()).limit(20), + DoubleStream.generate(() -> 10 + gamma.sample()).limit(20) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.StepChange ? 1 : 0; + } + assertThat(tp, greaterThan(80)); + } + + public void testTrendChangePower() { + NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); + int tp = 0; + for (int i = 0; i < 100; i++) { + AtomicInteger j = new AtomicInteger(); + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.generate(() -> j.incrementAndGet() + normal.sample()).limit(20), + DoubleStream.generate(() -> 2.0 * j.incrementAndGet() + normal.sample()).limit(20) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.TrendChange ? 
1 : 0; + } + assertThat(tp, greaterThan(80)); + + tp = 0; + GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); + for (int i = 0; i < 100; i++) { + AtomicInteger j = new AtomicInteger(); + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.generate(() -> j.incrementAndGet() + gamma.sample()).limit(20), + DoubleStream.generate(() -> 2.0 * j.incrementAndGet() + gamma.sample()).limit(20) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.TrendChange ? 1 : 0; + } + assertThat(tp, greaterThan(80)); + } + + public void testDistributionChangeTestPower() { + NormalDistribution normal1 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 1.0); + NormalDistribution normal2 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 10.0); + int tp = 0; + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.generate(() -> 10 + normal1.sample()).limit(50), + DoubleStream.generate(() -> 10 + normal2.sample()).limit(50) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.DistributionChange ? 
1 : 0; + } + assertThat(tp, greaterThan(90)); + } + + public void testMultipleChanges() { + NormalDistribution normal1 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 78.0, 3.0); + NormalDistribution normal2 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 40.0, 6.0); + NormalDistribution normal3 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1.0, 0.3); + int tp = 0; + for (int i = 0; i < 100; i++) { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + DoubleStream.concat( + DoubleStream.concat(DoubleStream.generate(normal1::sample).limit(7), DoubleStream.generate(normal2::sample).limit(6)), + DoubleStream.generate(normal3::sample).limit(23) + ).toArray() + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + tp += type instanceof ChangeType.TrendChange ? 1 : 0; + } + assertThat(tp, greaterThan(90)); + } + + public void testProblemDistributionChange() { + MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues( + null, + new double[] { + 546.3651753325270, + 550.872738079514, + 551.1312487618040, + 550.3323904749380, + 549.2652495378930, + 548.9761274963630, + 549.3433969743010, + 549.0935313531350, + 551.1762550747600, + 551.3772184469220, + 548.6163495094490, + 548.5866591594080, + 546.9364791288570, + 548.1167839989470, + 549.3484016149320, + 550.4242803917040, + 551.2316023050940, + 548.4713993534340, + 546.0254901960780, + 548.4376996805110, + 561.1920529801320, + 557.3930041152260, + 565.8497217068650, + 566.787072243346, + 546.6094890510950, + 530.5905797101450, + 556.7340823970040, + 557.3857677902620, + 543.0754716981130, + 574.3297101449280, + 559.2962962962960, + 549.5202952029520, + 531.7217741935480, + 551.4333333333330, + 557.637168141593, + 545.1880733944950, + 564.6893203883500, + 543.0204081632650, + 571.820809248555, + 541.2589928057550, + 
520.4387755102040 } + ); + ChangeType type = new ChangeDetector(bucketValues).detect(0.05); + assertThat(type, instanceOf(ChangeType.DistributionChange.class)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregatorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregatorTests.java index 73131efbbcf4b..5cb66aaa5a58c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregatorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregatorTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.aggs.changepoint; -import org.apache.commons.math3.distribution.GammaDistribution; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.random.RandomGeneratorFactory; import org.apache.logging.log4j.LogManager; @@ -37,10 +36,7 @@ import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.lessThan; -import static org.junit.Assert.assertThat; public class ChangePointAggregatorTests extends AggregatorTestCase { @@ -55,194 +51,6 @@ protected List getSearchPlugins() { private static final String NUMERIC_FIELD_NAME = "value"; private static final String TIME_FIELD_NAME = "timestamp"; - public void testStationaryFalsePositiveRate() throws IOException { - NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); - int fp = 0; - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.generate(() -> 10 + normal.sample()).limit(40).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 1e-4); - fp += test.type() == 
ChangePointAggregator.Type.STATIONARY ? 0 : 1; - } - assertThat(fp, lessThan(10)); - - fp = 0; - GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.generate(() -> gamma.sample()).limit(40).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 1e-4); - fp += test.type() == ChangePointAggregator.Type.STATIONARY ? 0 : 1; - } - assertThat(fp, lessThan(10)); - } - - public void testSampledDistributionTestFalsePositiveRate() throws IOException { - NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 1.0); - int fp = 0; - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.generate(() -> 10 + normal.sample()).limit(5000).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 1e-4); - fp += test.type() == ChangePointAggregator.Type.STATIONARY ? 0 : 1; - } - assertThat(fp, lessThan(10)); - } - - public void testNonStationaryFalsePositiveRate() throws IOException { - NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); - int fp = 0; - for (int i = 0; i < 100; i++) { - AtomicInteger j = new AtomicInteger(); - double[] bucketValues = DoubleStream.generate(() -> j.incrementAndGet() + normal.sample()).limit(40).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 1e-4); - fp += test.type() == ChangePointAggregator.Type.NON_STATIONARY ? 
0 : 1; - } - assertThat(fp, lessThan(10)); - - fp = 0; - GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); - for (int i = 0; i < 100; i++) { - AtomicInteger j = new AtomicInteger(); - double[] bucketValues = DoubleStream.generate(() -> j.incrementAndGet() + gamma.sample()).limit(40).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 1e-4); - fp += test.type() == ChangePointAggregator.Type.NON_STATIONARY ? 0 : 1; - } - assertThat(fp, lessThan(10)); - } - - public void testStepChangePower() throws IOException { - NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); - int tp = 0; - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.concat( - DoubleStream.generate(() -> normal.sample()).limit(20), - DoubleStream.generate(() -> 10 + normal.sample()).limit(20) - ).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += test.type() == ChangePointAggregator.Type.STEP_CHANGE ? 1 : 0; - } - assertThat(tp, greaterThan(80)); - - tp = 0; - GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.concat( - DoubleStream.generate(() -> gamma.sample()).limit(20), - DoubleStream.generate(() -> 10 + gamma.sample()).limit(20) - ).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += test.type() == ChangePointAggregator.Type.STEP_CHANGE ? 
1 : 0; - } - assertThat(tp, greaterThan(80)); - } - - public void testTrendChangePower() throws IOException { - NormalDistribution normal = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0, 2); - int tp = 0; - for (int i = 0; i < 100; i++) { - AtomicInteger j = new AtomicInteger(); - double[] bucketValues = DoubleStream.concat( - DoubleStream.generate(() -> j.incrementAndGet() + normal.sample()).limit(20), - DoubleStream.generate(() -> 2.0 * j.incrementAndGet() + normal.sample()).limit(20) - ).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += test.type() == ChangePointAggregator.Type.TREND_CHANGE ? 1 : 0; - } - assertThat(tp, greaterThan(80)); - - tp = 0; - GammaDistribution gamma = new GammaDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1, 2); - for (int i = 0; i < 100; i++) { - AtomicInteger j = new AtomicInteger(); - double[] bucketValues = DoubleStream.concat( - DoubleStream.generate(() -> j.incrementAndGet() + gamma.sample()).limit(20), - DoubleStream.generate(() -> 2.0 * j.incrementAndGet() + gamma.sample()).limit(20) - ).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += test.type() == ChangePointAggregator.Type.TREND_CHANGE ? 
1 : 0; - } - assertThat(tp, greaterThan(80)); - } - - public void testDistributionChangeTestPower() throws IOException { - NormalDistribution normal1 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 1.0); - NormalDistribution normal2 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 0.0, 10.0); - int tp = 0; - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.concat( - DoubleStream.generate(() -> 10 + normal1.sample()).limit(50), - DoubleStream.generate(() -> 10 + normal2.sample()).limit(50) - ).toArray(); - ChangePointAggregator.TestStats test = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += test.type() == ChangePointAggregator.Type.DISTRIBUTION_CHANGE ? 1 : 0; - } - assertThat(tp, greaterThan(90)); - } - - public void testMultipleChanges() throws IOException { - NormalDistribution normal1 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 78.0, 3.0); - NormalDistribution normal2 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 40.0, 6.0); - NormalDistribution normal3 = new NormalDistribution(RandomGeneratorFactory.createRandomGenerator(Randomness.get()), 1.0, 0.3); - int tp = 0; - for (int i = 0; i < 100; i++) { - double[] bucketValues = DoubleStream.concat( - DoubleStream.concat( - DoubleStream.generate(() -> normal1.sample()).limit(7), - DoubleStream.generate(() -> normal2.sample()).limit(6) - ), - DoubleStream.generate(() -> normal3.sample()).limit(23) - ).toArray(); - ChangePointAggregator.TestStats result = ChangePointAggregator.testForChange(bucketValues, 0.05); - tp += result.type() == ChangePointAggregator.Type.TREND_CHANGE ? 
1 : 0; - } - assertThat(tp, greaterThan(90)); - } - - public void testProblemDistributionChange() throws IOException { - double[] bucketValues = new double[] { - 546.3651753325270, - 550.872738079514, - 551.1312487618040, - 550.3323904749380, - 549.2652495378930, - 548.9761274963630, - 549.3433969743010, - 549.0935313531350, - 551.1762550747600, - 551.3772184469220, - 548.6163495094490, - 548.5866591594080, - 546.9364791288570, - 548.1167839989470, - 549.3484016149320, - 550.4242803917040, - 551.2316023050940, - 548.4713993534340, - 546.0254901960780, - 548.4376996805110, - 561.1920529801320, - 557.3930041152260, - 565.8497217068650, - 566.787072243346, - 546.6094890510950, - 530.5905797101450, - 556.7340823970040, - 557.3857677902620, - 543.0754716981130, - 574.3297101449280, - 559.2962962962960, - 549.5202952029520, - 531.7217741935480, - 551.4333333333330, - 557.637168141593, - 545.1880733944950, - 564.6893203883500, - 543.0204081632650, - 571.820809248555, - 541.2589928057550, - 520.4387755102040 }; - ChangePointAggregator.TestStats result = ChangePointAggregator.testForChange(bucketValues, 0.05); - assertThat(result.type(), equalTo(ChangePointAggregator.Type.DISTRIBUTION_CHANGE)); - } - public void testConstant() throws IOException { double[] bucketValues = DoubleStream.generate(() -> 10).limit(100).toArray(); testChangeType( @@ -262,7 +70,6 @@ public void testSlopeUp() throws IOException { // Handle infrequent false positives. 
assertThat(changeType, instanceOf(ChangeType.TrendChange.class)); } - }); } @@ -600,5 +407,4 @@ private static void writeTestDocs(RandomIndexWriter w, double[] bucketValues) th epoch_timestamp += INTERVAL.estimateMillis(); } } - } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetectorTests.java index fe91aa3e6a600..b21a7c4625e83 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/SpikeAndDipDetectorTests.java @@ -25,14 +25,14 @@ public void testTooLittleData() { Arrays.fill(docCounts, 1); Arrays.fill(values, 1.0); MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues(docCounts, values); - SpikeAndDipDetector detect = new SpikeAndDipDetector(values); - assertThat(detect.at(0.01, bucketValues), instanceOf(ChangeType.Indeterminable.class)); + SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues); + assertThat(detect.detect(0.01), instanceOf(ChangeType.Indeterminable.class)); } } public void testSpikeAndDipValues() { double[] values = new double[] { 2.0, 1.0, 3.0, 5.0, 4.0 }; - SpikeAndDipDetector detector = new SpikeAndDipDetector(values); + SpikeAndDipDetector detector = new SpikeAndDipDetector(new MlAggsHelper.DoubleBucketValues(null, values)); assertThat(detector.spikeValue(), equalTo(5.0)); assertThat(detector.dipValue(), equalTo(1.0)); } @@ -133,7 +133,7 @@ public void testExludedValues() { Arrays.sort(expectedSpikeKDEValues); Arrays.sort(expectedDipKDEValues); - SpikeAndDipDetector detector = new SpikeAndDipDetector(values); + SpikeAndDipDetector detector = new SpikeAndDipDetector(new MlAggsHelper.DoubleBucketValues(null, values)); assertThat(detector.spikeValue(), equalTo(10.0)); assertThat(detector.dipValue(), 
equalTo(-2.0)); @@ -150,9 +150,9 @@ public void testDetection() { double[] values = new double[] { 0.1, 3.1, 1.2, 1.7, 0.9, 2.3, -0.8, 3.2, 1.2, 1.3, 1.1, 1.0, 8.5, 0.5, 2.6, 0.7 }; MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues(docCounts, values); - SpikeAndDipDetector detect = new SpikeAndDipDetector(values); + SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues); - ChangeType change = detect.at(0.05, bucketValues); + ChangeType change = detect.detect(0.05); assertThat(change, instanceOf(ChangeType.Spike.class)); assertThat(change.pValue(), closeTo(3.0465e-12, 1e-15)); @@ -162,9 +162,9 @@ public void testDetection() { double[] values = new double[] { 0.1, 3.1, 1.2, 1.7, 0.9, 2.3, -4.2, 3.2, 1.2, 1.3, 1.1, 1.0, 3.5, 0.5, 2.6, 0.7 }; MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues(docCounts, values); - SpikeAndDipDetector detect = new SpikeAndDipDetector(values); + SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues); - ChangeType change = detect.at(0.05, bucketValues); + ChangeType change = detect.detect(0.05); assertThat(change, instanceOf(ChangeType.Dip.class)); assertThat(change.pValue(), closeTo(1.2589e-08, 1e-11)); @@ -177,9 +177,9 @@ public void testMissingBuckets() { int[] buckets = new int[] { 0, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 15, 17, 18, 19, 20 }; MlAggsHelper.DoubleBucketValues bucketValues = new MlAggsHelper.DoubleBucketValues(docCounts, values, buckets); - SpikeAndDipDetector detect = new SpikeAndDipDetector(values); + SpikeAndDipDetector detect = new SpikeAndDipDetector(bucketValues); - ChangeType change = detect.at(0.01, bucketValues); + ChangeType change = detect.detect(0.01); assertThat(change, instanceOf(ChangeType.Spike.class)); assertThat(change.changePoint(), equalTo(10)); diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java 
b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java index c3d502e561bd7..3a9069dee064d 100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/collector/cluster/ClusterStatsMonitoringDocTests.java @@ -590,7 +590,8 @@ public void testToXContent() throws IOException { "total": 0, "queries": {}, "rescorers": {}, - "sections": {} + "sections": {}, + "retrievers": {} }, "dense_vector": { "value_count": 0 diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java new file mode 100644 index 0000000000000..4eaea9a596361 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverTelemetryIT.java @@ -0,0 +1,194 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.rank.rrf; + +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest; +import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse; +import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats; +import org.elasticsearch.client.Request; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) +public class RRFRetrieverTelemetryIT extends ESIntegTestCase { + + private static final String INDEX_NAME = "test_index"; + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + @Override + protected Collection> nodePlugins() { + return List.of(RRFRankPlugin.class, XPackPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + 
.put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put("xpack.license.self_generated.type", "trial") + .build(); + } + + @Before + public void setup() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("vector") + .field("type", "dense_vector") + .field("dims", 1) + .field("index", true) + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", "hnsw") + .endObject() + .endObject() + .startObject("text") + .field("type", "text") + .endObject() + .startObject("integer") + .field("type", "integer") + .endObject() + .startObject("topic") + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + + assertAcked(prepareCreate(INDEX_NAME).setMapping(builder)); + ensureGreen(INDEX_NAME); + } + + private void performSearch(SearchSourceBuilder source) throws IOException { + Request request = new Request("GET", INDEX_NAME + "/_search"); + request.setJsonEntity(Strings.toString(source)); + getRestClient().performRequest(request); + } + + public void testTelemetryForRRFRetriever() throws IOException { + + if (false == isRetrieverTelemetryEnabled()) { + return; + } + + // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers` + { + performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null))); + } + + // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under + // `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2)))); + } + + // search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under + // `queries` + { + performSearch( + new SearchSourceBuilder().retriever( + new StandardRetrieverBuilder(new 
KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null)) + ) + ); + } + + // search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term" + // under `queries` + { + performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo")))); + } + + // search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "rrf" under `retrievers`, as well as + // 1 for "knn" and 1 for "standard" under `retrievers`, and eventually 1 for "match" under `queries` + { + performSearch( + new SearchSourceBuilder().retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource( + new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null), + null + ), + new CompoundRetrieverBuilder.RetrieverSource( + new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")), + null + ) + ), + 10, + 10 + ) + ) + ); + } + + // search#6 - this will record 1 entry for "knn" in `sections` + { + performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null)))); + } + + // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries` + { + performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())); + } + + // cluster stats + { + SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats(); + assertEquals(7, stats.getTotalSearchCount()); + + assertThat(stats.getSectionsUsage().size(), equalTo(3)); + assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L)); + assertThat(stats.getSectionsUsage().get("query"), equalTo(1L)); + assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L)); + + assertThat(stats.getRetrieversUsage().size(), equalTo(3)); + assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L)); + 
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(2L)); + assertThat(stats.getRetrieversUsage().get("rrf"), equalTo(1L)); + + assertThat(stats.getQueryUsage().size(), equalTo(5)); + assertThat(stats.getQueryUsage().get("range"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("term"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("match"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L)); + assertThat(stats.getQueryUsage().get("knn"), equalTo(1L)); + } + } + + private boolean isRetrieverTelemetryEnabled() throws IOException { + NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities( + new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats") + ).actionGet(); + return res != null && res.isSupported().orElse(false); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 5f19e361d857d..12c43a2f169f8 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -68,6 +68,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder