diff --git a/CHANGELOG.md b/CHANGELOG.md index 69fb6b42fe29f..15dedbde22d41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add a new node role 'search' which is dedicated to provide search capability ([#4689](https://github.com/opensearch-project/OpenSearch/pull/4689)) - Introduce experimental searchable snapshot API ([#4680](https://github.com/opensearch-project/OpenSearch/pull/4680)) - Recommissioning of zone. REST layer support. ([#4624](https://github.com/opensearch-project/OpenSearch/pull/4604)) +- Added in-flight cancellation of SearchShardTask based on resource consumption ([#4565](https://github.com/opensearch-project/OpenSearch/pull/4565)) + ### Dependencies - Bumps `log4j-core` from 2.18.0 to 2.19.0 - Bumps `reactor-netty-http` from 1.0.18 to 1.0.23 diff --git a/server/src/main/java/org/opensearch/cluster/ClusterModule.java b/server/src/main/java/org/opensearch/cluster/ClusterModule.java index 3ac3a08d4848a..6279447485348 100644 --- a/server/src/main/java/org/opensearch/cluster/ClusterModule.java +++ b/server/src/main/java/org/opensearch/cluster/ClusterModule.java @@ -96,7 +96,6 @@ import org.opensearch.script.ScriptMetadata; import org.opensearch.snapshots.SnapshotsInfoService; import org.opensearch.tasks.Task; -import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.tasks.TaskResultsService; import java.util.ArrayList; @@ -420,7 +419,6 @@ protected void configure() { bind(NodeMappingRefreshAction.class).asEagerSingleton(); bind(MappingUpdatedAction.class).asEagerSingleton(); bind(TaskResultsService.class).asEagerSingleton(); - bind(TaskResourceTrackingService.class).asEagerSingleton(); bind(AllocationDeciders.class).toInstance(allocationDeciders); bind(ShardsAllocator.class).toInstance(shardsAllocator); } diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java index 54579031aac08..320cb5457b21c 100644 --- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java @@ -41,6 +41,9 @@ import org.opensearch.index.ShardIndexingPressureMemoryManager; import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.ShardIndexingPressureStore; +import org.opensearch.search.backpressure.settings.NodeDuressSettings; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.settings.SearchShardTaskSettings; import org.opensearch.tasks.TaskManager; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.watcher.ResourceWatcherService; @@ -581,7 +584,22 @@ public void apply(Settings value, Settings current, Settings previous) { ShardIndexingPressureMemoryManager.MAX_OUTSTANDING_REQUESTS, IndexingPressure.MAX_INDEXING_BYTES, TaskResourceTrackingService.TASK_RESOURCE_TRACKING_ENABLED, - TaskManager.TASK_RESOURCE_CONSUMERS_ENABLED + TaskManager.TASK_RESOURCE_CONSUMERS_ENABLED, + + // Settings related to search backpressure + SearchBackpressureSettings.SETTING_ENABLED, + SearchBackpressureSettings.SETTING_ENFORCED, + SearchBackpressureSettings.SETTING_CANCELLATION_RATIO, + SearchBackpressureSettings.SETTING_CANCELLATION_RATE, + SearchBackpressureSettings.SETTING_CANCELLATION_BURST, + NodeDuressSettings.SETTING_NUM_SUCCESSIVE_BREACHES, + NodeDuressSettings.SETTING_CPU_THRESHOLD, + NodeDuressSettings.SETTING_HEAP_THRESHOLD, + SearchShardTaskSettings.SETTING_TOTAL_HEAP_THRESHOLD, + SearchShardTaskSettings.SETTING_HEAP_THRESHOLD, + SearchShardTaskSettings.SETTING_HEAP_VARIANCE_THRESHOLD, + SearchShardTaskSettings.SETTING_CPU_TIME_THRESHOLD, + SearchShardTaskSettings.SETTING_ELAPSED_TIME_THRESHOLD ) ) ); diff --git a/server/src/main/java/org/opensearch/common/util/MovingAverage.java b/server/src/main/java/org/opensearch/common/util/MovingAverage.java new file mode 100644 index 0000000000000..650ba62ecd8c8 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/MovingAverage.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +/** + * MovingAverage is used to calculate the moving average of last 'n' observations. + * + * @opensearch.internal + */ +public class MovingAverage { + private final int windowSize; + private final long[] observations; + + private long count = 0; + private long sum = 0; + private double average = 0; + + public MovingAverage(int windowSize) { + if (windowSize <= 0) { + throw new IllegalArgumentException("window size must be greater than zero"); + } + + this.windowSize = windowSize; + this.observations = new long[windowSize]; + } + + /** + * Records a new observation and evicts the n-th last observation. + */ + public synchronized double record(long value) { + long delta = value - observations[(int) (count % observations.length)]; + observations[(int) (count % observations.length)] = value; + + count++; + sum += delta; + average = (double) sum / Math.min(count, observations.length); + return average; + } + + public double getAverage() { + return average; + } + + public long getCount() { + return count; + } + + public boolean isReady() { + return count >= windowSize; + } +} diff --git a/server/src/main/java/org/opensearch/common/util/Streak.java b/server/src/main/java/org/opensearch/common/util/Streak.java new file mode 100644 index 0000000000000..5f6ad3021659e --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/Streak.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Streak is a data structure that keeps track of the number of successive successful events. + * + * @opensearch.internal + */ +public class Streak { + private final AtomicInteger successiveSuccessfulEvents = new AtomicInteger(); + + public int record(boolean isSuccessful) { + if (isSuccessful) { + return successiveSuccessfulEvents.incrementAndGet(); + } else { + successiveSuccessfulEvents.set(0); + return 0; + } + } + + public int length() { + return successiveSuccessfulEvents.get(); + } +} diff --git a/server/src/main/java/org/opensearch/common/util/TokenBucket.java b/server/src/main/java/org/opensearch/common/util/TokenBucket.java new file mode 100644 index 0000000000000..e47f152d71363 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/TokenBucket.java @@ -0,0 +1,124 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; + +/** + * TokenBucket is used to limit the number of operations at a constant rate while allowing for short bursts. + * + * @opensearch.internal + */ +public class TokenBucket { + /** + * Defines a monotonically increasing counter. + * + * Usage examples: + * 1. clock = System::nanoTime can be used to perform rate-limiting per unit time + * 2. clock = AtomicLong::get can be used to perform rate-limiting per unit number of operations + */ + private final LongSupplier clock; + + /** + * Defines the number of tokens added to the bucket per clock cycle. + */ + private final double rate; + + /** + * Defines the capacity and the maximum number of operations that can be performed per clock cycle before + * the bucket runs out of tokens. + */ + private final double burst; + + /** + * Defines the current state of the token bucket. + */ + private final AtomicReference state; + + public TokenBucket(LongSupplier clock, double rate, double burst) { + this(clock, rate, burst, burst); + } + + public TokenBucket(LongSupplier clock, double rate, double burst, double initialTokens) { + if (rate <= 0.0) { + throw new IllegalArgumentException("rate must be greater than zero"); + } + + if (burst <= 0.0) { + throw new IllegalArgumentException("burst must be greater than zero"); + } + + this.clock = clock; + this.rate = rate; + this.burst = burst; + this.state = new AtomicReference<>(new State(Math.min(initialTokens, burst), clock.getAsLong())); + } + + /** + * If there are enough tokens in the bucket, it requests/deducts 'n' tokens and returns true. + * Otherwise, returns false and leaves the bucket untouched. + */ + public boolean request(double n) { + if (n <= 0) { + throw new IllegalArgumentException("requested tokens must be greater than zero"); + } + + // Refill tokens + State currentState, updatedState; + do { + currentState = state.get(); + long now = clock.getAsLong(); + double incr = (now - currentState.lastRefilledAt) * rate; + updatedState = new State(Math.min(currentState.tokens + incr, burst), now); + } while (state.compareAndSet(currentState, updatedState) == false); + + // Deduct tokens + do { + currentState = state.get(); + if (currentState.tokens < n) { + return false; + } + updatedState = new State(currentState.tokens - n, currentState.lastRefilledAt); + } while (state.compareAndSet(currentState, updatedState) == false); + + return true; + } + + public boolean request() { + return request(1.0); + } + + /** + * Represents an immutable token bucket state. + */ + private static class State { + final double tokens; + final double lastRefilledAt; + + public State(double tokens, double lastRefilledAt) { + this.tokens = tokens; + this.lastRefilledAt = lastRefilledAt; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + State state = (State) o; + return Double.compare(state.tokens, tokens) == 0 && Double.compare(state.lastRefilledAt, lastRefilledAt) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(tokens, lastRefilledAt); + } + } +} diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 504550378e14f..54696b1f6b118 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -43,6 +43,8 @@ import org.opensearch.indices.replication.SegmentReplicationSourceFactory; import org.opensearch.indices.replication.SegmentReplicationTargetService; import org.opensearch.indices.replication.SegmentReplicationSourceService; +import org.opensearch.search.backpressure.SearchBackpressureService; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; import org.opensearch.tasks.TaskResourceTrackingService; import org.opensearch.threadpool.RunnableTaskExecutionListener; import org.opensearch.index.store.RemoteSegmentStoreDirectoryFactory; @@ -796,6 +798,23 @@ protected Node( // development. Then we can deprecate Getter and Setter for IndexingPressureService in ClusterService (#478). clusterService.setIndexingPressureService(indexingPressureService); + final TaskResourceTrackingService taskResourceTrackingService = new TaskResourceTrackingService( + settings, + clusterService.getClusterSettings(), + threadPool + ); + + final SearchBackpressureSettings searchBackpressureSettings = new SearchBackpressureSettings( + settings, + clusterService.getClusterSettings() + ); + + final SearchBackpressureService searchBackpressureService = new SearchBackpressureService( + searchBackpressureSettings, + taskResourceTrackingService, + threadPool + ); + final RecoverySettings recoverySettings = new RecoverySettings(settings, settingsModule.getClusterSettings()); RepositoriesModule repositoriesModule = new RepositoriesModule( this.environment, @@ -883,7 +902,8 @@ protected Node( responseCollectorService, searchTransportService, indexingPressureService, - searchModule.getValuesSourceRegistry().getUsageService() + searchModule.getValuesSourceRegistry().getUsageService(), + searchBackpressureService ); final SearchService searchService = newSearchService( @@ -941,6 +961,8 @@ protected Node( b.bind(AnalysisRegistry.class).toInstance(analysisModule.getAnalysisRegistry()); b.bind(IngestService.class).toInstance(ingestService); b.bind(IndexingPressureService.class).toInstance(indexingPressureService); + b.bind(TaskResourceTrackingService.class).toInstance(taskResourceTrackingService); + b.bind(SearchBackpressureService.class).toInstance(searchBackpressureService); b.bind(UsageService.class).toInstance(usageService); b.bind(AggregationUsageService.class).toInstance(searchModule.getValuesSourceRegistry().getUsageService()); b.bind(NamedWriteableRegistry.class).toInstance(namedWriteableRegistry); @@ -1104,6 +1126,7 @@ public Node start() throws NodeValidationException { injector.getInstance(SearchService.class).start(); injector.getInstance(FsHealthService.class).start(); nodeService.getMonitorService().start(); + nodeService.getSearchBackpressureService().start(); final ClusterService clusterService = injector.getInstance(ClusterService.class); @@ -1259,6 +1282,7 @@ private Node stop() { injector.getInstance(NodeConnectionsService.class).stop(); injector.getInstance(FsHealthService.class).stop(); nodeService.getMonitorService().stop(); + nodeService.getSearchBackpressureService().stop(); injector.getInstance(GatewayService.class).stop(); injector.getInstance(SearchService.class).stop(); injector.getInstance(TransportService.class).stop(); @@ -1318,6 +1342,7 @@ public synchronized void close() throws IOException { toClose.add(injector.getInstance(Discovery.class)); toClose.add(() -> stopWatch.stop().start("monitor")); toClose.add(nodeService.getMonitorService()); + toClose.add(nodeService.getSearchBackpressureService()); toClose.add(() -> stopWatch.stop().start("fsHealth")); toClose.add(injector.getInstance(FsHealthService.class)); toClose.add(() -> stopWatch.stop().start("gateway")); diff --git a/server/src/main/java/org/opensearch/node/NodeService.java b/server/src/main/java/org/opensearch/node/NodeService.java index ab98b47c7287b..f4e037db66ca2 100644 --- a/server/src/main/java/org/opensearch/node/NodeService.java +++ b/server/src/main/java/org/opensearch/node/NodeService.java @@ -53,6 +53,7 @@ import org.opensearch.plugins.PluginsService; import org.opensearch.script.ScriptService; import org.opensearch.search.aggregations.support.AggregationUsageService; +import org.opensearch.search.backpressure.SearchBackpressureService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -81,6 +82,7 @@ public class NodeService implements Closeable { private final SearchTransportService searchTransportService; private final IndexingPressureService indexingPressureService; private final AggregationUsageService aggregationUsageService; + private final SearchBackpressureService searchBackpressureService; private final Discovery discovery; @@ -101,7 +103,8 @@ public class NodeService implements Closeable { ResponseCollectorService responseCollectorService, SearchTransportService searchTransportService, IndexingPressureService indexingPressureService, - AggregationUsageService aggregationUsageService + AggregationUsageService aggregationUsageService, + SearchBackpressureService searchBackpressureService ) { this.settings = settings; this.threadPool = threadPool; @@ -119,6 +122,7 @@ public class NodeService implements Closeable { this.searchTransportService = searchTransportService; this.indexingPressureService = indexingPressureService; this.aggregationUsageService = aggregationUsageService; + this.searchBackpressureService = searchBackpressureService; clusterService.addStateApplier(ingestService); } @@ -203,6 +207,10 @@ public MonitorService getMonitorService() { return monitorService; } + public SearchBackpressureService getSearchBackpressureService() { + return searchBackpressureService; + } + @Override public void close() throws IOException { IOUtils.close(indicesService); diff --git a/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java new file mode 100644 index 0000000000000..885846a177d60 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureService.java @@ -0,0 +1,311 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.component.AbstractLifecycleComponent; +import org.opensearch.common.util.TokenBucket; +import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.monitor.process.ProcessProbe; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.trackers.NodeDuressTracker; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.tasks.TaskResourceTrackingService.TaskCompletionListener; +import org.opensearch.threadpool.Scheduler; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; +import java.util.stream.Collectors; + +/** + * SearchBackpressureService is responsible for monitoring and cancelling in-flight search tasks if they are + * breaching resource usage limits when the node is in duress. + * + * @opensearch.internal + */ +public class SearchBackpressureService extends AbstractLifecycleComponent + implements + TaskCompletionListener, + SearchBackpressureSettings.Listener { + private static final Logger logger = LogManager.getLogger(SearchBackpressureService.class); + + private volatile Scheduler.Cancellable scheduledFuture; + + private final SearchBackpressureSettings settings; + private final TaskResourceTrackingService taskResourceTrackingService; + private final ThreadPool threadPool; + private final LongSupplier timeNanosSupplier; + + private final List nodeDuressTrackers; + private final List taskResourceUsageTrackers; + + private final AtomicReference taskCancellationRateLimiter = new AtomicReference<>(); + private final AtomicReference taskCancellationRatioLimiter = new AtomicReference<>(); + + // Currently, only the state of SearchShardTask is being tracked. + // This can be generalized to Map once we start supporting cancellation of SearchTasks as well. + private final SearchBackpressureState state = new SearchBackpressureState(); + + public SearchBackpressureService( + SearchBackpressureSettings settings, + TaskResourceTrackingService taskResourceTrackingService, + ThreadPool threadPool + ) { + this( + settings, + taskResourceTrackingService, + threadPool, + System::nanoTime, + List.of( + new NodeDuressTracker( + () -> ProcessProbe.getInstance().getProcessCpuPercent() / 100.0 >= settings.getNodeDuressSettings().getCpuThreshold() + ), + new NodeDuressTracker( + () -> JvmStats.jvmStats().getMem().getHeapUsedPercent() / 100.0 >= settings.getNodeDuressSettings().getHeapThreshold() + ) + ), + Collections.emptyList() + ); + } + + public SearchBackpressureService( + SearchBackpressureSettings settings, + TaskResourceTrackingService taskResourceTrackingService, + ThreadPool threadPool, + LongSupplier timeNanosSupplier, + List nodeDuressTrackers, + List taskResourceUsageTrackers + ) { + this.settings = settings; + this.settings.setListener(this); + this.taskResourceTrackingService = taskResourceTrackingService; + this.taskResourceTrackingService.addTaskCompletionListener(this); + this.threadPool = threadPool; + this.timeNanosSupplier = timeNanosSupplier; + this.nodeDuressTrackers = nodeDuressTrackers; + this.taskResourceUsageTrackers = taskResourceUsageTrackers; + + this.taskCancellationRateLimiter.set( + new TokenBucket(timeNanosSupplier, getSettings().getCancellationRateNanos(), getSettings().getCancellationBurst()) + ); + + this.taskCancellationRatioLimiter.set( + new TokenBucket(state::getCompletionCount, getSettings().getCancellationRatio(), getSettings().getCancellationBurst()) + ); + } + + void doRun() { + if (getSettings().isEnabled() == false) { + return; + } + + if (isNodeInDuress() == false) { + return; + } + + // We are only targeting in-flight cancellation of SearchShardTask for now. + List searchShardTasks = getSearchShardTasks(); + + // Force-refresh usage stats of these tasks before making a cancellation decision. + taskResourceTrackingService.refreshResourceStats(searchShardTasks.toArray(new Task[0])); + + // Skip cancellation if the increase in heap usage is not due to search requests. + if (isHeapUsageDominatedBySearch(searchShardTasks) == false) { + return; + } + + for (TaskCancellation taskCancellation : getTaskCancellations(searchShardTasks)) { + logger.debug( + "cancelling task [{}] due to high resource consumption [{}]", + taskCancellation.getTask().getId(), + taskCancellation.getReasonString() + ); + + if (getSettings().isEnforced() == false) { + continue; + } + + // Independently remove tokens from both token buckets. + boolean rateLimitReached = taskCancellationRateLimiter.get().request() == false; + boolean ratioLimitReached = taskCancellationRatioLimiter.get().request() == false; + + // Stop cancelling tasks if there are no tokens in either of the two token buckets. + if (rateLimitReached && ratioLimitReached) { + logger.debug("task cancellation limit reached"); + state.incrementLimitReachedCount(); + break; + } + + taskCancellation.cancel(); + state.incrementCancellationCount(); + } + } + + /** + * Returns true if the node is in duress consecutively for the past 'n' observations. + */ + boolean isNodeInDuress() { + boolean isNodeInDuress = false; + int numSuccessiveBreaches = getSettings().getNodeDuressSettings().getNumSuccessiveBreaches(); + + for (NodeDuressTracker tracker : nodeDuressTrackers) { + if (tracker.check() >= numSuccessiveBreaches) { + isNodeInDuress = true; // not breaking the loop so that each tracker's streak gets updated. + } + } + + return isNodeInDuress; + } + + /** + * Returns true if the increase in heap usage is due to search requests. + */ + boolean isHeapUsageDominatedBySearch(List searchShardTasks) { + long runningTasksHeapUsage = searchShardTasks.stream().mapToLong(task -> task.getTotalResourceStats().getMemoryInBytes()).sum(); + long searchTasksHeapThreshold = getSettings().getSearchShardTaskSettings().getTotalHeapThresholdBytes(); + if (runningTasksHeapUsage < searchTasksHeapThreshold) { + logger.debug("heap usage not dominated by search requests [{}/{}]", runningTasksHeapUsage, searchTasksHeapThreshold); + return false; + } + + return true; + } + + /** + * Filters and returns the list of currently running SearchShardTasks. + */ + List getSearchShardTasks() { + return taskResourceTrackingService.getResourceAwareTasks() + .values() + .stream() + .filter(task -> task instanceof SearchShardTask) + .map(task -> (CancellableTask) task) + .collect(Collectors.toUnmodifiableList()); + } + + /** + * Returns a TaskCancellation wrapper containing the list of reasons (possibly zero), along with an overall + * cancellation score for the given task. Cancelling a task with a higher score has better chance of recovering the + * node from duress. + */ + TaskCancellation getTaskCancellation(CancellableTask task) { + List reasons = new ArrayList<>(); + List callbacks = new ArrayList<>(); + + for (TaskResourceUsageTracker tracker : taskResourceUsageTrackers) { + Optional reason = tracker.cancellationReason(task); + if (reason.isPresent()) { + reasons.add(reason.get()); + callbacks.add(tracker::incrementCancellations); + } + } + + return new TaskCancellation(task, reasons, callbacks); + } + + /** + * Returns a list of TaskCancellations sorted by descending order of their cancellation scores. + */ + List getTaskCancellations(List tasks) { + return tasks.stream() + .map(this::getTaskCancellation) + .filter(TaskCancellation::isEligibleForCancellation) + .sorted(Comparator.reverseOrder()) + .collect(Collectors.toUnmodifiableList()); + } + + SearchBackpressureSettings getSettings() { + return settings; + } + + SearchBackpressureState getState() { + return state; + } + + @Override + public void onTaskCompleted(Task task) { + if (getSettings().isEnabled() == false) { + return; + } + + if (task instanceof SearchShardTask == false) { + return; + } + + SearchShardTask searchShardTask = (SearchShardTask) task; + if (searchShardTask.isCancelled() == false) { + state.incrementCompletionCount(); + } + + List exceptions = new ArrayList<>(); + for (TaskResourceUsageTracker tracker : taskResourceUsageTrackers) { + try { + tracker.update(searchShardTask); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); + } + + @Override + public void onCancellationRatioChanged() { + taskCancellationRatioLimiter.set( + new TokenBucket(state::getCompletionCount, getSettings().getCancellationRatio(), getSettings().getCancellationBurst()) + ); + } + + @Override + public void onCancellationRateChanged() { + taskCancellationRateLimiter.set( + new TokenBucket(timeNanosSupplier, getSettings().getCancellationRateNanos(), getSettings().getCancellationBurst()) + ); + } + + @Override + public void onCancellationBurstChanged() { + onCancellationRatioChanged(); + onCancellationRateChanged(); + } + + @Override + protected void doStart() { + scheduledFuture = threadPool.scheduleWithFixedDelay(() -> { + try { + doRun(); + } catch (Exception e) { + logger.debug("failure in search search backpressure", e); + } + }, getSettings().getInterval(), ThreadPool.Names.GENERIC); + } + + @Override + protected void doStop() { + if (scheduledFuture != null) { + scheduledFuture.cancel(); + } + } + + @Override + protected void doClose() throws IOException {} +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java new file mode 100644 index 0000000000000..a62231ec29ede --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/SearchBackpressureState.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * Tracks the current state of task completions and cancellations. + * + * @opensearch.internal + */ +public class SearchBackpressureState { + /** + * The number of successful task completions. + */ + private final AtomicLong completionCount = new AtomicLong(); + + /** + * The number of task cancellations due to limit breaches. + */ + private final AtomicLong cancellationCount = new AtomicLong(); + + /** + * The number of times task cancellation limit was reached. + */ + private final AtomicLong limitReachedCount = new AtomicLong(); + + public long getCompletionCount() { + return completionCount.get(); + } + + long incrementCompletionCount() { + return completionCount.incrementAndGet(); + } + + public long getCancellationCount() { + return cancellationCount.get(); + } + + long incrementCancellationCount() { + return cancellationCount.incrementAndGet(); + } + + public long getLimitReachedCount() { + return limitReachedCount.get(); + } + + long incrementLimitReachedCount() { + return limitReachedCount.incrementAndGet(); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/package-info.java new file mode 100644 index 0000000000000..36d216993b2fc --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains classes responsible for search backpressure. + */ +package org.opensearch.search.backpressure; diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/NodeDuressSettings.java b/server/src/main/java/org/opensearch/search/backpressure/settings/NodeDuressSettings.java new file mode 100644 index 0000000000000..09c1e4fcef46c --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/NodeDuressSettings.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.settings; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; + +/** + * Defines the settings for a node to be considered in duress. + * + * @opensearch.internal + */ +public class NodeDuressSettings { + private static class Defaults { + private static final int NUM_SUCCESSIVE_BREACHES = 3; + private static final double CPU_THRESHOLD = 0.9; + private static final double HEAP_THRESHOLD = 0.7; + } + + /** + * Defines the number of successive limit breaches after the node is marked "in duress". + */ + private volatile int numSuccessiveBreaches; + public static final Setting SETTING_NUM_SUCCESSIVE_BREACHES = Setting.intSetting( + "search_backpressure.node_duress.num_successive_breaches", + Defaults.NUM_SUCCESSIVE_BREACHES, + 1, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the CPU usage threshold (in percentage) for a node to be considered "in duress". + */ + private volatile double cpuThreshold; + public static final Setting SETTING_CPU_THRESHOLD = Setting.doubleSetting( + "search_backpressure.node_duress.cpu_threshold", + Defaults.CPU_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the heap usage threshold (in percentage) for a node to be considered "in duress". + */ + private volatile double heapThreshold; + public static final Setting SETTING_HEAP_THRESHOLD = Setting.doubleSetting( + "search_backpressure.node_duress.heap_threshold", + Defaults.HEAP_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + public NodeDuressSettings(Settings settings, ClusterSettings clusterSettings) { + numSuccessiveBreaches = SETTING_NUM_SUCCESSIVE_BREACHES.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_NUM_SUCCESSIVE_BREACHES, this::setNumSuccessiveBreaches); + + cpuThreshold = SETTING_CPU_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CPU_THRESHOLD, this::setCpuThreshold); + + heapThreshold = SETTING_HEAP_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_HEAP_THRESHOLD, this::setHeapThreshold); + } + + public int getNumSuccessiveBreaches() { + return numSuccessiveBreaches; + } + + private void setNumSuccessiveBreaches(int numSuccessiveBreaches) { + this.numSuccessiveBreaches = numSuccessiveBreaches; + } + + public double getCpuThreshold() { + return cpuThreshold; + } + + private void setCpuThreshold(double cpuThreshold) { + this.cpuThreshold = cpuThreshold; + } + + public double getHeapThreshold() { + return heapThreshold; + } + + private void setHeapThreshold(double heapThreshold) { + this.heapThreshold = heapThreshold; + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureSettings.java b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureSettings.java new file mode 100644 index 0000000000000..4834808d768f1 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchBackpressureSettings.java @@ -0,0 +1,213 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.settings; + +import org.apache.lucene.util.SetOnce; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; + +import java.util.concurrent.TimeUnit; + +/** + * Settings related to search backpressure and cancellation of in-flight requests. + * + * @opensearch.internal + */ +public class SearchBackpressureSettings { + private static class Defaults { + private static final long INTERVAL = 1000; + + private static final boolean ENABLED = true; + private static final boolean ENFORCED = false; + + private static final double CANCELLATION_RATIO = 0.1; + private static final double CANCELLATION_RATE = 0.003; + private static final double CANCELLATION_BURST = 10.0; + } + + /** + * Defines the interval (in millis) at which the SearchBackpressureService monitors and cancels tasks. + */ + private final TimeValue interval; + public static final Setting SETTING_INTERVAL = Setting.longSetting( + "search_backpressure.interval", + Defaults.INTERVAL, + 1, + Setting.Property.NodeScope + ); + + /** + * Defines whether search backpressure is enabled or not. + */ + private volatile boolean enabled; + public static final Setting SETTING_ENABLED = Setting.boolSetting( + "search_backpressure.enabled", + Defaults.ENABLED, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines whether in-flight cancellation of tasks is enabled or not. + */ + private volatile boolean enforced; + public static final Setting SETTING_ENFORCED = Setting.boolSetting( + "search_backpressure.enforced", + Defaults.ENFORCED, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the percentage of tasks to cancel relative to the number of successful task completions. + * In other words, it is the number of tokens added to the bucket on each successful task completion. + */ + private volatile double cancellationRatio; + public static final Setting SETTING_CANCELLATION_RATIO = Setting.doubleSetting( + "search_backpressure.cancellation_ratio", + Defaults.CANCELLATION_RATIO, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the number of tasks to cancel per unit time (in millis). + * In other words, it is the number of tokens added to the bucket each millisecond. + */ + private volatile double cancellationRate; + public static final Setting SETTING_CANCELLATION_RATE = Setting.doubleSetting( + "search_backpressure.cancellation_rate", + Defaults.CANCELLATION_RATE, + 0.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the maximum number of tasks that can be cancelled before being rate-limited. + */ + private volatile double cancellationBurst; + public static final Setting SETTING_CANCELLATION_BURST = Setting.doubleSetting( + "search_backpressure.cancellation_burst", + Defaults.CANCELLATION_BURST, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Callback listeners. + */ + public interface Listener { + void onCancellationRatioChanged(); + + void onCancellationRateChanged(); + + void onCancellationBurstChanged(); + } + + private final SetOnce listener = new SetOnce<>(); + private final NodeDuressSettings nodeDuressSettings; + private final SearchShardTaskSettings searchShardTaskSettings; + + public SearchBackpressureSettings(Settings settings, ClusterSettings clusterSettings) { + this.nodeDuressSettings = new NodeDuressSettings(settings, clusterSettings); + this.searchShardTaskSettings = new SearchShardTaskSettings(settings, clusterSettings); + + interval = new TimeValue(SETTING_INTERVAL.get(settings)); + + enabled = SETTING_ENABLED.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_ENABLED, this::setEnabled); + + enforced = SETTING_ENFORCED.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_ENFORCED, this::setEnforced); + + cancellationRatio = SETTING_CANCELLATION_RATIO.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_RATIO, this::setCancellationRatio); + + cancellationRate = SETTING_CANCELLATION_RATE.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_RATE, this::setCancellationRate); + + cancellationBurst = SETTING_CANCELLATION_BURST.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CANCELLATION_BURST, this::setCancellationBurst); + } + + public void setListener(Listener listener) { + this.listener.set(listener); + } + + public NodeDuressSettings getNodeDuressSettings() { + return nodeDuressSettings; + } + + public SearchShardTaskSettings getSearchShardTaskSettings() { + return searchShardTaskSettings; + } + + public TimeValue getInterval() { + return interval; + } + + public boolean isEnabled() { + return enabled; + } + + private void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public boolean isEnforced() { + return enforced; + } + + private void setEnforced(boolean enforced) { + this.enforced = enforced; + } + + public double getCancellationRatio() { + return cancellationRatio; + } + + private void setCancellationRatio(double cancellationRatio) { + this.cancellationRatio = cancellationRatio; + if (listener.get() != null) { + listener.get().onCancellationRatioChanged(); + } + } + + public double getCancellationRate() { + return cancellationRate; + } + + public double getCancellationRateNanos() { + return getCancellationRate() / TimeUnit.MILLISECONDS.toNanos(1); // rate per nanoseconds + } + + private void setCancellationRate(double cancellationRate) { + this.cancellationRate = cancellationRate; + if (listener.get() != null) { + listener.get().onCancellationRateChanged(); + } + } + + public double getCancellationBurst() { + return cancellationBurst; + } + + private void setCancellationBurst(double cancellationBurst) { + this.cancellationBurst = cancellationBurst; + if (listener.get() != null) { + listener.get().onCancellationBurstChanged(); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java new file mode 100644 index 0000000000000..1126dad78f554 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/SearchShardTaskSettings.java @@ -0,0 +1,160 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.settings; + +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.monitor.jvm.JvmStats; + +/** + * Defines the settings related to the cancellation of SearchShardTasks. + * + * @opensearch.internal + */ +public class SearchShardTaskSettings { + private static final long HEAP_SIZE_BYTES = JvmStats.jvmStats().getMem().getHeapMax().getBytes(); + + private static class Defaults { + private static final double TOTAL_HEAP_THRESHOLD = 0.05; + private static final double HEAP_THRESHOLD = 0.005; + private static final double HEAP_VARIANCE_THRESHOLD = 2.0; + private static final long CPU_TIME_THRESHOLD = 15; + private static final long ELAPSED_TIME_THRESHOLD = 30000; + } + + /** + * Defines the heap usage threshold (in percentage) for the sum of heap usages across all search shard tasks + * before in-flight cancellation is applied. + */ + private volatile double totalHeapThreshold; + public static final Setting SETTING_TOTAL_HEAP_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.total_heap_threshold", + Defaults.TOTAL_HEAP_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the heap usage threshold (in percentage) for an individual task before it is considered for cancellation. + */ + private volatile double heapThreshold; + public static final Setting SETTING_HEAP_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.heap_threshold", + Defaults.HEAP_THRESHOLD, + 0.0, + 1.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the heap usage variance for an individual task before it is considered for cancellation. + * A task is considered for cancellation when taskHeapUsage is greater than or equal to heapUsageMovingAverage * variance. + */ + private volatile double heapVarianceThreshold; + public static final Setting SETTING_HEAP_VARIANCE_THRESHOLD = Setting.doubleSetting( + "search_backpressure.search_shard_task.heap_variance", + Defaults.HEAP_VARIANCE_THRESHOLD, + 0.0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the CPU usage threshold (in millis) for an individual task before it is considered for cancellation. + */ + private volatile long cpuTimeThreshold; + public static final Setting SETTING_CPU_TIME_THRESHOLD = Setting.longSetting( + "search_backpressure.search_shard_task.cpu_time_threshold", + Defaults.CPU_TIME_THRESHOLD, + 0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + /** + * Defines the elapsed time threshold (in millis) for an individual task before it is considered for cancellation. + */ + private volatile long elapsedTimeThreshold; + public static final Setting SETTING_ELAPSED_TIME_THRESHOLD = Setting.longSetting( + "search_backpressure.search_shard_task.elapsed_time_threshold", + Defaults.ELAPSED_TIME_THRESHOLD, + 0, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + + public SearchShardTaskSettings(Settings settings, ClusterSettings clusterSettings) { + totalHeapThreshold = SETTING_TOTAL_HEAP_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_TOTAL_HEAP_THRESHOLD, this::setTotalHeapThreshold); + + heapThreshold = SETTING_HEAP_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_HEAP_THRESHOLD, this::setHeapThreshold); + + heapVarianceThreshold = SETTING_HEAP_VARIANCE_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_HEAP_VARIANCE_THRESHOLD, this::setHeapVarianceThreshold); + + cpuTimeThreshold = SETTING_CPU_TIME_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_CPU_TIME_THRESHOLD, this::setCpuTimeThreshold); + + elapsedTimeThreshold = SETTING_ELAPSED_TIME_THRESHOLD.get(settings); + clusterSettings.addSettingsUpdateConsumer(SETTING_ELAPSED_TIME_THRESHOLD, this::setElapsedTimeThreshold); + } + + public double getTotalHeapThreshold() { + return totalHeapThreshold; + } + + public long getTotalHeapThresholdBytes() { + return (long) (HEAP_SIZE_BYTES * getTotalHeapThreshold()); + } + + private void setTotalHeapThreshold(double totalHeapThreshold) { + this.totalHeapThreshold = totalHeapThreshold; + } + + public double getHeapThreshold() { + return heapThreshold; + } + + public long getHeapThresholdBytes() { + return (long) (HEAP_SIZE_BYTES * getHeapThreshold()); + } + + private void setHeapThreshold(double heapThreshold) { + this.heapThreshold = heapThreshold; + } + + public double getHeapVarianceThreshold() { + return heapVarianceThreshold; + } + + private void setHeapVarianceThreshold(double heapVarianceThreshold) { + this.heapVarianceThreshold = heapVarianceThreshold; + } + + public long getCpuTimeThreshold() { + return cpuTimeThreshold; + } + + private void setCpuTimeThreshold(long cpuTimeThreshold) { + this.cpuTimeThreshold = cpuTimeThreshold; + } + + public long getElapsedTimeThreshold() { + return elapsedTimeThreshold; + } + + private void setElapsedTimeThreshold(long elapsedTimeThreshold) { + this.elapsedTimeThreshold = elapsedTimeThreshold; + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java new file mode 100644 index 0000000000000..a853a139b096b --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/settings/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains settings for search backpressure. + */ +package org.opensearch.search.backpressure.settings; diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java new file mode 100644 index 0000000000000..8e35c724a8fef --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/NodeDuressTracker.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.common.util.Streak; + +import java.util.function.BooleanSupplier; + +/** + * NodeDuressTracker is used to check if the node is in duress. + * + * @opensearch.internal + */ +public class NodeDuressTracker { + /** + * Tracks the number of consecutive breaches. + */ + private final Streak breaches = new Streak(); + + /** + * Predicate that returns true when the node is in duress. + */ + private final BooleanSupplier isNodeInDuress; + + public NodeDuressTracker(BooleanSupplier isNodeInDuress) { + this.isNodeInDuress = isNodeInDuress; + } + + /** + * Evaluates the predicate and returns the number of consecutive breaches. + */ + public int check() { + return breaches.record(isNodeInDuress.getAsBoolean()); + } +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java new file mode 100644 index 0000000000000..8f1842efa5771 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/TaskResourceUsageTracker.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.tasks.Task; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +/** + * TaskResourceUsageTracker is used to track completions and cancellations of search related tasks. + * + * @opensearch.internal + */ +public abstract class TaskResourceUsageTracker { + /** + * Counts the number of cancellations made due to this tracker. + */ + private final AtomicLong cancellations = new AtomicLong(); + + public long incrementCancellations() { + return cancellations.incrementAndGet(); + } + + public long getCancellations() { + return cancellations.get(); + } + + /** + * Returns a unique name for this tracker. + */ + public abstract String name(); + + /** + * Notifies the tracker to update its state when a task execution completes. + */ + public abstract void update(Task task); + + /** + * Returns the cancellation reason for the given task, if it's eligible for cancellation. + */ + public abstract Optional cancellationReason(Task task); +} diff --git a/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java b/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java new file mode 100644 index 0000000000000..da0532421391e --- /dev/null +++ b/server/src/main/java/org/opensearch/search/backpressure/trackers/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * This package contains trackers to check if resource usage limits are breached on a node or task level. + */ +package org.opensearch.search.backpressure.trackers; diff --git a/server/src/main/java/org/opensearch/tasks/CancellableTask.java b/server/src/main/java/org/opensearch/tasks/CancellableTask.java index 439be2b630e84..336f5b1f4c244 100644 --- a/server/src/main/java/org/opensearch/tasks/CancellableTask.java +++ b/server/src/main/java/org/opensearch/tasks/CancellableTask.java @@ -71,7 +71,7 @@ public CancellableTask( /** * This method is called by the task manager when this task is cancelled. */ - final void cancel(String reason) { + public void cancel(String reason) { assert reason != null; if (cancelled.compareAndSet(false, true)) { this.reason = reason; diff --git a/server/src/main/java/org/opensearch/tasks/TaskCancellation.java b/server/src/main/java/org/opensearch/tasks/TaskCancellation.java new file mode 100644 index 0000000000000..2c302172cf6bc --- /dev/null +++ b/server/src/main/java/org/opensearch/tasks/TaskCancellation.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.ExceptionsHelper; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * TaskCancellation is a wrapper for a task and its cancellation reasons. + * + * @opensearch.internal + */ +public class TaskCancellation implements Comparable { + private final CancellableTask task; + private final List reasons; + private final List onCancelCallbacks; + + public TaskCancellation(CancellableTask task, List reasons, List onCancelCallbacks) { + this.task = task; + this.reasons = reasons; + this.onCancelCallbacks = onCancelCallbacks; + } + + public CancellableTask getTask() { + return task; + } + + public List getReasons() { + return reasons; + } + + public String getReasonString() { + return reasons.stream().map(Reason::getMessage).collect(Collectors.joining(", ")); + } + + /** + * Cancels the task and invokes all onCancelCallbacks. + */ + public void cancel() { + if (isEligibleForCancellation() == false) { + return; + } + + task.cancel(getReasonString()); + + List exceptions = new ArrayList<>(); + for (Runnable callback : onCancelCallbacks) { + try { + callback.run(); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); + } + + /** + * Returns the sum of all cancellation scores. + * + * A zero score indicates no reason to cancel the task. + * A task with a higher score suggests greater possibility of recovering the node when it is cancelled. + */ + public int totalCancellationScore() { + return reasons.stream().mapToInt(Reason::getCancellationScore).sum(); + } + + /** + * A task is eligible for cancellation if it has one or more cancellation reasons, and is not already cancelled. + */ + public boolean isEligibleForCancellation() { + return (task.isCancelled() == false) && (reasons.size() > 0); + } + + @Override + public int compareTo(TaskCancellation other) { + return Integer.compare(totalCancellationScore(), other.totalCancellationScore()); + } + + /** + * Represents the cancellation reason for a task. + */ + public static class Reason { + private final String message; + private final int cancellationScore; + + public Reason(String message, int cancellationScore) { + this.message = message; + this.cancellationScore = cancellationScore; + } + + public String getMessage() { + return message; + } + + public int getCancellationScore() { + return cancellationScore; + } + } +} diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index c3cad117390e4..b4806b531429e 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; import org.opensearch.common.SuppressForbidden; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; @@ -50,6 +51,7 @@ public class TaskResourceTrackingService implements RunnableTaskExecutionListene private static final ThreadMXBean threadMXBean = (ThreadMXBean) ManagementFactory.getThreadMXBean(); private final ConcurrentMapLong resourceAwareTasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); + private final List taskCompletionListeners = new ArrayList<>(); private final ThreadPool threadPool; private volatile boolean taskResourceTrackingEnabled; @@ -116,6 +118,16 @@ public void stopTracking(Task task) { } finally { resourceAwareTasks.remove(task.getId()); } + + List exceptions = new ArrayList<>(); + for (TaskCompletionListener listener : taskCompletionListeners) { + try { + listener.onTaskCompleted(task); + } catch (Exception e) { + exceptions.add(e); + } + } + ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptions); } /** @@ -245,4 +257,14 @@ private ThreadContext.StoredContext addTaskIdToThreadContext(Task task) { return storedContext; } + /** + * Listener that gets invoked when a task execution completes. + */ + public interface TaskCompletionListener { + void onTaskCompleted(Task task); + } + + public void addTaskCompletionListener(TaskCompletionListener listener) { + this.taskCompletionListeners.add(listener); + } } diff --git a/server/src/test/java/org/opensearch/common/StreakTests.java b/server/src/test/java/org/opensearch/common/StreakTests.java new file mode 100644 index 0000000000000..80080ad3e4027 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/StreakTests.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common; + +import org.opensearch.common.util.Streak; +import org.opensearch.test.OpenSearchTestCase; + +public class StreakTests extends OpenSearchTestCase { + + public void testStreak() { + Streak streak = new Streak(); + + // Streak starts with zero. + assertEquals(0, streak.length()); + + // Streak increases on successive successful events. + streak.record(true); + assertEquals(1, streak.length()); + streak.record(true); + assertEquals(2, streak.length()); + streak.record(true); + assertEquals(3, streak.length()); + + // Streak resets to zero after an unsuccessful event. + streak.record(false); + assertEquals(0, streak.length()); + } +} diff --git a/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java b/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java new file mode 100644 index 0000000000000..415058992e081 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/MovingAverageTests.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +public class MovingAverageTests extends OpenSearchTestCase { + + public void testMovingAverage() { + MovingAverage ma = new MovingAverage(5); + + // No observations + assertEquals(0.0, ma.getAverage(), 0.0); + assertEquals(0, ma.getCount()); + + // Not enough observations + ma.record(1); + ma.record(2); + ma.record(3); + assertEquals(2.0, ma.getAverage(), 0.0); + assertEquals(3, ma.getCount()); + assertFalse(ma.isReady()); + + // Enough observations + ma.record(4); + ma.record(5); + ma.record(6); + assertEquals(4, ma.getAverage(), 0.0); + assertEquals(6, ma.getCount()); + assertTrue(ma.isReady()); + } + + public void testMovingAverageWithZeroSize() { + try { + new MovingAverage(0); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("window size must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } +} diff --git a/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java b/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java new file mode 100644 index 0000000000000..a52e97cdd835c --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/TokenBucketTests.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.LongSupplier; + +public class TokenBucketTests extends OpenSearchTestCase { + + public void testTokenBucket() { + AtomicLong mockTimeNanos = new AtomicLong(); + LongSupplier mockTimeNanosSupplier = mockTimeNanos::get; + + // Token bucket that refills at 2 tokens/second and allows short bursts up to 3 operations. + TokenBucket tokenBucket = new TokenBucket(mockTimeNanosSupplier, 2.0 / TimeUnit.SECONDS.toNanos(1), 3); + + // Three operations succeed, fourth fails. + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by one second. Two operations succeed, third fails. + mockTimeNanos.addAndGet(TimeUnit.SECONDS.toNanos(1)); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by half a second. One operation succeeds, second fails. + mockTimeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(500)); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Clock moves ahead by many seconds, but the token bucket should be capped at the 'burst' capacity. + mockTimeNanos.addAndGet(TimeUnit.SECONDS.toNanos(10)); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertTrue(tokenBucket.request()); + assertFalse(tokenBucket.request()); + + // Ability to request fractional tokens. + mockTimeNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(250)); + assertFalse(tokenBucket.request(1.0)); + assertTrue(tokenBucket.request(0.5)); + } + + public void testTokenBucketWithInvalidRate() { + try { + new TokenBucket(System::nanoTime, -1, 2); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("rate must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } + + public void testTokenBucketWithInvalidBurst() { + try { + new TokenBucket(System::nanoTime, 1, 0); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("burst must be greater than zero")); + return; + } + + fail("exception should have been thrown"); + } +} diff --git a/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java b/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java new file mode 100644 index 0000000000000..ac5a8229718ba --- /dev/null +++ b/server/src/test/java/org/opensearch/search/backpressure/SearchBackpressureServiceTests.java @@ -0,0 +1,213 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.search.backpressure.settings.SearchBackpressureSettings; +import org.opensearch.search.backpressure.settings.SearchShardTaskSettings; +import org.opensearch.search.backpressure.trackers.NodeDuressTracker; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskCancellation; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.search.backpressure.SearchBackpressureTestHelpers.createMockTaskWithResourceStats; + +public class SearchBackpressureServiceTests extends OpenSearchTestCase { + + public void testIsNodeInDuress() { + TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); + ThreadPool mockThreadPool = mock(ThreadPool.class); + + AtomicReference cpuUsage = new AtomicReference<>(); + AtomicReference heapUsage = new AtomicReference<>(); + NodeDuressTracker cpuUsageTracker = new NodeDuressTracker(() -> cpuUsage.get() >= 0.5); + NodeDuressTracker heapUsageTracker = new NodeDuressTracker(() -> heapUsage.get() >= 0.5); + + SearchBackpressureSettings settings = new SearchBackpressureSettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + SearchBackpressureService service = new SearchBackpressureService( + settings, + mockTaskResourceTrackingService, + mockThreadPool, + System::nanoTime, + List.of(cpuUsageTracker, heapUsageTracker), + Collections.emptyList() + ); + + // Node not in duress. + cpuUsage.set(0.0); + heapUsage.set(0.0); + assertFalse(service.isNodeInDuress()); + + // Node in duress; but not for many consecutive data points. + cpuUsage.set(1.0); + heapUsage.set(1.0); + assertFalse(service.isNodeInDuress()); + + // Node in duress for consecutive data points. + assertFalse(service.isNodeInDuress()); + assertTrue(service.isNodeInDuress()); + + // Node not in duress anymore. + cpuUsage.set(0.0); + heapUsage.set(0.0); + assertFalse(service.isNodeInDuress()); + } + + public void testTrackerStateUpdateOnTaskCompletion() { + TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); + ThreadPool mockThreadPool = mock(ThreadPool.class); + LongSupplier mockTimeNanosSupplier = () -> TimeUnit.SECONDS.toNanos(1234); + TaskResourceUsageTracker mockTaskResourceUsageTracker = mock(TaskResourceUsageTracker.class); + + SearchBackpressureSettings settings = new SearchBackpressureSettings( + Settings.EMPTY, + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ); + + SearchBackpressureService service = new SearchBackpressureService( + settings, + mockTaskResourceTrackingService, + mockThreadPool, + mockTimeNanosSupplier, + Collections.emptyList(), + List.of(mockTaskResourceUsageTracker) + ); + + // Record task completions to update the tracker state. Tasks other than SearchShardTask are ignored. + service.onTaskCompleted(createMockTaskWithResourceStats(CancellableTask.class, 100, 200)); + for (int i = 0; i < 100; i++) { + service.onTaskCompleted(createMockTaskWithResourceStats(SearchShardTask.class, 100, 200)); + } + assertEquals(100, service.getState().getCompletionCount()); + verify(mockTaskResourceUsageTracker, times(100)).update(any()); + } + + public void testInFlightCancellation() { + TaskResourceTrackingService mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); + ThreadPool mockThreadPool = mock(ThreadPool.class); + AtomicLong mockTime = new AtomicLong(0); + LongSupplier mockTimeNanosSupplier = mockTime::get; + NodeDuressTracker mockNodeDuressTracker = new NodeDuressTracker(() -> true); + + TaskResourceUsageTracker mockTaskResourceUsageTracker = new TaskResourceUsageTracker() { + @Override + public String name() { + return "mock_tracker"; + } + + @Override + public void update(Task task) {} + + @Override + public Optional cancellationReason(Task task) { + if (task.getTotalResourceStats().getCpuTimeInNanos() < 300) { + return Optional.empty(); + } + + return Optional.of(new TaskCancellation.Reason("limits exceeded", 5)); + } + }; + + // Mocking 'settings' with predictable rate limiting thresholds. + SearchBackpressureSettings settings = spy( + new SearchBackpressureSettings( + Settings.builder() + .put(SearchBackpressureSettings.SETTING_ENFORCED.getKey(), true) + .put(SearchBackpressureSettings.SETTING_CANCELLATION_RATIO.getKey(), 0.1) + .put(SearchBackpressureSettings.SETTING_CANCELLATION_RATE.getKey(), 0.003) + .put(SearchBackpressureSettings.SETTING_CANCELLATION_BURST.getKey(), 10.0) + .build(), + new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + ) + ); + + SearchBackpressureService service = new SearchBackpressureService( + settings, + mockTaskResourceTrackingService, + mockThreadPool, + mockTimeNanosSupplier, + List.of(mockNodeDuressTracker), + List.of(mockTaskResourceUsageTracker) + ); + + // Run two iterations so that node is marked 'in duress' from the third iteration onwards. + service.doRun(); + service.doRun(); + + // Mocking 'settings' with predictable searchHeapThresholdBytes so that cancellation logic doesn't get skipped. + long taskHeapUsageBytes = 500; + SearchShardTaskSettings shardTaskSettings = mock(SearchShardTaskSettings.class); + when(shardTaskSettings.getTotalHeapThresholdBytes()).thenReturn(taskHeapUsageBytes); + when(settings.getSearchShardTaskSettings()).thenReturn(shardTaskSettings); + + // Create a mix of low and high resource usage tasks (60 low + 15 high resource usage tasks). + Map activeTasks = new HashMap<>(); + for (long i = 0; i < 75; i++) { + if (i % 5 == 0) { + activeTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 500, taskHeapUsageBytes)); + } else { + activeTasks.put(i, createMockTaskWithResourceStats(SearchShardTask.class, 100, taskHeapUsageBytes)); + } + } + doReturn(activeTasks).when(mockTaskResourceTrackingService).getResourceAwareTasks(); + + // There are 15 tasks eligible for cancellation but only 10 will be cancelled (burst limit). + service.doRun(); + assertEquals(10, service.getState().getCancellationCount()); + assertEquals(1, service.getState().getLimitReachedCount()); + + // If the clock or completed task count haven't made sufficient progress, we'll continue to be rate-limited. + service.doRun(); + assertEquals(10, service.getState().getCancellationCount()); + assertEquals(2, service.getState().getLimitReachedCount()); + + // Simulate task completion to replenish some tokens. + // This will add 2 tokens (task count delta * cancellationRatio) to 'rateLimitPerTaskCompletion'. + for (int i = 0; i < 20; i++) { + service.onTaskCompleted(createMockTaskWithResourceStats(SearchShardTask.class, 100, taskHeapUsageBytes)); + } + service.doRun(); + assertEquals(12, service.getState().getCancellationCount()); + assertEquals(3, service.getState().getLimitReachedCount()); + + // Fast-forward the clock by one second to replenish some tokens. + // This will add 3 tokens (time delta * rate) to 'rateLimitPerTime'. + mockTime.addAndGet(TimeUnit.SECONDS.toNanos(1)); + service.doRun(); + assertEquals(15, service.getState().getCancellationCount()); + assertEquals(3, service.getState().getLimitReachedCount()); // no more tasks to cancel; limit not reached + } +} diff --git a/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java b/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java new file mode 100644 index 0000000000000..472ba95566523 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/backpressure/trackers/NodeDuressTrackerTests.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure.trackers; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.atomic.AtomicReference; + +public class NodeDuressTrackerTests extends OpenSearchTestCase { + + public void testNodeDuressTracker() { + AtomicReference cpuUsage = new AtomicReference<>(0.0); + NodeDuressTracker tracker = new NodeDuressTracker(() -> cpuUsage.get() >= 0.5); + + // Node not in duress. + assertEquals(0, tracker.check()); + + // Node in duress; the streak must keep increasing. + cpuUsage.set(0.7); + assertEquals(1, tracker.check()); + assertEquals(2, tracker.check()); + assertEquals(3, tracker.check()); + + // Node not in duress anymore. + cpuUsage.set(0.3); + assertEquals(0, tracker.check()); + assertEquals(0, tracker.check()); + } +} diff --git a/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java b/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java new file mode 100644 index 0000000000000..d6a82702e074d --- /dev/null +++ b/server/src/test/java/org/opensearch/tasks/TaskCancellationTests.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.tasks; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.search.backpressure.trackers.TaskResourceUsageTracker; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +public class TaskCancellationTests extends OpenSearchTestCase { + + public void testTaskCancellation() { + SearchShardTask mockTask = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()); + + TaskResourceUsageTracker mockTracker1 = createMockTaskResourceUsageTracker("mock_tracker_1"); + TaskResourceUsageTracker mockTracker2 = createMockTaskResourceUsageTracker("mock_tracker_2"); + TaskResourceUsageTracker mockTracker3 = createMockTaskResourceUsageTracker("mock_tracker_3"); + + List reasons = new ArrayList<>(); + List callbacks = List.of(mockTracker1::incrementCancellations, mockTracker2::incrementCancellations); + TaskCancellation taskCancellation = new TaskCancellation(mockTask, reasons, callbacks); + + // Task does not have any reason to be cancelled. + assertEquals(0, taskCancellation.totalCancellationScore()); + assertFalse(taskCancellation.isEligibleForCancellation()); + taskCancellation.cancel(); + assertEquals(0, mockTracker1.getCancellations()); + assertEquals(0, mockTracker2.getCancellations()); + assertEquals(0, mockTracker3.getCancellations()); + + // Task has one or more reasons to be cancelled. + reasons.add(new TaskCancellation.Reason("limits exceeded 1", 10)); + reasons.add(new TaskCancellation.Reason("limits exceeded 2", 20)); + reasons.add(new TaskCancellation.Reason("limits exceeded 3", 5)); + assertEquals(35, taskCancellation.totalCancellationScore()); + assertTrue(taskCancellation.isEligibleForCancellation()); + + // Cancel the task and validate the cancellation reason and invocation of callbacks. + taskCancellation.cancel(); + assertTrue(mockTask.getReasonCancelled().contains("limits exceeded 1, limits exceeded 2, limits exceeded 3")); + assertEquals(1, mockTracker1.getCancellations()); + assertEquals(1, mockTracker2.getCancellations()); + assertEquals(0, mockTracker3.getCancellations()); + } + + private static TaskResourceUsageTracker createMockTaskResourceUsageTracker(String name) { + return new TaskResourceUsageTracker() { + @Override + public String name() { + return name; + } + + @Override + public void update(Task task) {} + + @Override + public Optional cancellationReason(Task task) { + return Optional.empty(); + } + }; + } +} diff --git a/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java b/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java new file mode 100644 index 0000000000000..ba3653d0b4a84 --- /dev/null +++ b/test/framework/src/main/java/org/opensearch/search/backpressure/SearchBackpressureTestHelpers.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.backpressure; + +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.TaskResourceUsage; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SearchBackpressureTestHelpers extends OpenSearchTestCase { + + public static T createMockTaskWithResourceStats(Class type, long cpuUsage, long heapUsage) { + return createMockTaskWithResourceStats(type, cpuUsage, heapUsage, 0); + } + + public static T createMockTaskWithResourceStats( + Class type, + long cpuUsage, + long heapUsage, + long startTimeNanos + ) { + T task = mock(type); + when(task.getTotalResourceStats()).thenReturn(new TaskResourceUsage(cpuUsage, heapUsage)); + when(task.getStartTimeNanos()).thenReturn(startTimeNanos); + + AtomicBoolean isCancelled = new AtomicBoolean(false); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(task).cancel(anyString()); + doAnswer(invocation -> isCancelled.get()).when(task).isCancelled(); + + return task; + } +}