diff --git a/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java b/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java
index 1905a8d76405..ef7efecaec89 100644
--- a/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java
+++ b/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java
@@ -190,7 +190,7 @@ public interface Expirable
*
* @return the expiration time in nanoseconds, or {@link Long#MAX_VALUE} if this entity does not expire
*/
- public long getExpireNanoTime();
+ long getExpireNanoTime();
}
private class Timeouts extends CyclicTimeout
diff --git a/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml
new file mode 100644
index 000000000000..68b4b9de0087
--- /dev/null
+++ b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml
@@ -0,0 +1,75 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/jetty-core/jetty-server/src/main/config/modules/dos.mod b/jetty-core/jetty-server/src/main/config/modules/dos.mod
new file mode 100644
index 000000000000..45d5a9f5da4d
--- /dev/null
+++ b/jetty-core/jetty-server/src/main/config/modules/dos.mod
@@ -0,0 +1,62 @@
+# DO NOT EDIT THIS FILE - See: https://eclipse.dev/jetty/documentation/
+
+[description]
+Enables the DosHandler for the server.
+
+[tags]
+connector
+
+[depend]
+server
+
+[xml]
+etc/jetty-dos.xml
+
+[ini-template]
+
+## The algorithm to use for obtaining an remote client identifier from a Request: ID_FROM_REMOTE_ADDRESS, ID_FROM_REMOTE_PORT, ID_FROM_REMOTE_ADDRESS_PORT, ID_FROM_CONNECTION
+#jetty.dos.id.type=ID_FROM_REMOTE_ADDRESS
+#jetty.dos.id.class=org.eclipse.jetty.server.handler.DosHandler
+
+## The class to use to create Tracker instances to track the rate of requests
+#jetty.dos.trackerFactory=org.eclipse.jetty.server.handler.DoSHandler$LeakingBucketTrackerFactory
+
+## The Handler class to use to reject DOS requests
+#jetty.dos.rejectHandler=org.eclipse.jetty.server.handler.DoSHandler$DelayedRejectHandler
+
+## The maximum requests per second per client
+#jetty.dos.leakingBucketTracker.maxRequestsPerSecond=100
+
+## The size of the leaky bucket. Larger buckets allow longer bursts before enforcing the rate
+#jetty.dos.leakingBucketTracker.bucketSize=100
+
+## The time in seconds to retain an empty bucket.
+#jetty.dos.leakingBucketTracker.idleTimeout=1
+
+## The period to delay dos requests before rejecting them.
+#jetty.dos.rejectHandler.delayed.delayMs=1000
+
+## The maximum number of requests to be held in the delay queue
+#jetty.dos.rejectHandler.delayed.maxDelayQueue=1000
+
+## The maximum number of clients to track; or -1 for a default value; or 0 for unlimited
+#jetty.dos.maxTrackers=100000
+
+## Should untracked requests (due to maxTrackers) be rejected or allowed
+#jetty.dos.rejectUntracked=false
+
+## The status code used to reject requests; or 0 to abort the request; or -1 for a default
+#jetty.dos.rejectStatus=429
+
+## List of InetAddress patterns to include
+#jetty.dos.include.inet=10.10.10-14.0-128
+
+## List of InetAddressPatterns to exclude
+#jetty.dos.exclude.inet=10.10.10-14.0-128
+
+## List of path patterns to include
+#jetty.dos.include.path=/context/*
+
+## List of path to exclude
+#jetty.dos.exclude.path=/context/*
+
diff --git a/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java
new file mode 100644
index 000000000000..5aaa5d2be2a6
--- /dev/null
+++ b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java
@@ -0,0 +1,522 @@
+//
+// ========================================================================
+// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
+//
+// This program and the accompanying materials are made available under the
+// terms of the Eclipse Public License v. 2.0 which is available at
+// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
+// which is available at https://www.apache.org/licenses/LICENSE-2.0.
+//
+// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
+// ========================================================================
+//
+
+package org.eclipse.jetty.server.handler;
+
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.time.Duration;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+
+import org.eclipse.jetty.http.HttpStatus;
+import org.eclipse.jetty.io.CyclicTimeouts;
+import org.eclipse.jetty.server.ConnectionMetaData;
+import org.eclipse.jetty.server.Handler;
+import org.eclipse.jetty.server.Request;
+import org.eclipse.jetty.server.Response;
+import org.eclipse.jetty.server.Server;
+import org.eclipse.jetty.util.Callback;
+import org.eclipse.jetty.util.NanoTime;
+import org.eclipse.jetty.util.annotation.ManagedObject;
+import org.eclipse.jetty.util.annotation.Name;
+import org.eclipse.jetty.util.thread.AutoLock;
+import org.eclipse.jetty.util.thread.Scheduler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ *
A Denial of Service Handler that protects from attacks by limiting the request rate from remote clients.
+ */
+@ManagedObject("DoS Prevention Handler")
+public class DoSHandler extends ConditionalHandler.ElseNext
+{
+ private static final Logger LOG = LoggerFactory.getLogger(DoSHandler.class);
+
+ /**
+ * A {@link Function} to create a remote client identifier from the remote address and remote port of a {@link Request}.
+ */
+ public static final Function ID_FROM_REMOTE_ADDRESS_PORT = request ->
+ {
+ SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
+ if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
+ return inetSocketAddress.toString();
+ return remoteSocketAddress.toString();
+ };
+
+ /**
+ * A {@link Function} to create a remote client identifier from the remote address of a {@link Request}.
+ */
+ public static final Function ID_FROM_REMOTE_ADDRESS = request ->
+ {
+ SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
+ if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
+ return inetSocketAddress.getAddress().toString();
+ return remoteSocketAddress.toString();
+ };
+
+ /**
+ * A {@link Function} to create a remote client identifier from the remote port of a {@link Request}.
+ * Useful if there is an untrusted intermediary, where the remote port can be a surrogate for the connection.
+ */
+ public static final Function ID_FROM_REMOTE_PORT = request ->
+ {
+ SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress();
+ if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress)
+ return Integer.toString(inetSocketAddress.getPort());
+ return remoteSocketAddress.toString();
+ };
+
+ /**
+ * A {@link Function} to create a remote client identifier from {@link ConnectionMetaData#getId()} of a {@link Request}.
+ */
+ public static final Function ID_FROM_CONNECTION = request -> request.getConnectionMetaData().getId();
+
+ private final Map _trackers = new ConcurrentHashMap<>();
+ private final Function _clientIdFn;
+ private final Tracker.Factory _trackerFactory;
+ private final Request.Handler _rejectHandler;
+ private final int _maxTrackers;
+ private final boolean _rejectUntracked;
+ private CyclicTimeouts _cyclicTimeouts;
+
+ /**
+ * @param trackerFactory Factory to create a Tracker
+ */
+ public DoSHandler(@Name("trackerFactory") Tracker.Factory trackerFactory)
+ {
+ this(null, trackerFactory, null, -1);
+ }
+
+ /**
+ * @param clientIdFn Function to extract a remote client identifier from a request.
+ * @param trackerFactory Factory to create a Tracker
+ * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
+ * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
+ * If this limit is exceeded, then requests from additional remote clients are rejected.
+ */
+ public DoSHandler(
+ @Name("clientIdFn") Function clientIdFn,
+ @Name("trackerFactory") Tracker.Factory trackerFactory,
+ @Name("rejectHandler") Request.Handler rejectHandler,
+ @Name("maxTrackers") int maxTrackers)
+ {
+ this(null, clientIdFn, trackerFactory, rejectHandler, maxTrackers);
+ }
+
+ /**
+ * @param handler Then next {@link Handler} or {@code null}
+ * @param clientIdFn Function to extract a remote client identifier from a request.
+ * @param trackerFactory Factory to create a Tracker
+ * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
+ * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
+ * If this limit is exceeded, then requests from additional remote clients are rejected.
+ */
+ public DoSHandler(
+ @Name("handler") Handler handler,
+ @Name("clientIdFn") Function clientIdFn,
+ @Name("trackerFactory") Tracker.Factory trackerFactory,
+ @Name("rejectHandler") Request.Handler rejectHandler,
+ @Name("maxTrackers") int maxTrackers)
+ {
+ this(handler, clientIdFn, trackerFactory, rejectHandler, maxTrackers, false);
+ }
+
+ /**
+ * @param handler Then next {@link Handler} or {@code null}
+ * @param clientIdFn Function to extract a remote client identifier from a request.
+ * @param trackerFactory Factory to create a Tracker
+ * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default.
+ * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited.
+ * If this limit is exceeded, then requests from additional remote clients are rejected.
+ */
+ public DoSHandler(
+ @Name("handler") Handler handler,
+ @Name("clientIdFn") Function clientIdFn,
+ @Name("trackerFactory") Tracker.Factory trackerFactory,
+ @Name("rejectHandler") Request.Handler rejectHandler,
+ @Name("maxTrackers") int maxTrackers,
+ @Name("rejectUntracked") boolean rejectUntracked)
+ {
+ super(handler);
+ installBean(_trackers);
+ _clientIdFn = Objects.requireNonNullElse(clientIdFn, ID_FROM_REMOTE_ADDRESS);
+ installBean(_clientIdFn);
+ _trackerFactory = Objects.requireNonNull(trackerFactory);
+ installBean(_trackerFactory);
+ // default max trackers to a large, effectively infinite number, but ultimately bounded.
+ _maxTrackers = maxTrackers < 0 ? 100_000 : maxTrackers;
+ _rejectHandler = Objects.requireNonNullElseGet(rejectHandler, StatusRejectHandler::new);
+ installBean(_rejectHandler);
+ _rejectUntracked = rejectUntracked;
+ }
+
+ @Override
+ public void setServer(Server server)
+ {
+ super.setServer(server);
+ if (_rejectHandler instanceof Handler handler)
+ handler.setServer(server);
+ }
+
+ @Override
+ protected boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception
+ {
+
+ // Calculate an id for the request (which may be global empty string).
+ String id = _clientIdFn.apply(request);
+
+ // Reject or handle untracked request
+ if (id == null)
+ id = "";
+
+ // Obtain a tracker, creating a new one if necessary (and not too many)
+ // Trackers are removed if CyclicTimeouts#onExpired returns true.
+ Tracker tracker = _trackers.computeIfAbsent(id, this::newTracker);
+
+ // If we have too many trackers, then we will have a null tracker
+ if (tracker == null)
+ return _rejectUntracked ? _rejectHandler.handle(request, response, callback) : nextHandler(request, response, callback);
+
+ // IS the request allowed by the tracker?
+ boolean allowed = tracker.onRequest(NanoTime.now());
+ if (LOG.isDebugEnabled())
+ LOG.debug("allowed={} {}", allowed, tracker);
+ if (allowed)
+ return nextHandler(request, response, callback);
+
+ // Otherwise reject the request as it is over rate
+ return _rejectHandler.handle(request, response, callback);
+ }
+
+ Tracker newTracker(String id)
+ {
+ if (_maxTrackers > 0 && _trackers.size() >= _maxTrackers)
+ return null;
+
+ Tracker tracker = _trackerFactory.newTracker(id);
+ _cyclicTimeouts.schedule(tracker);
+ return tracker;
+ }
+
+ @Override
+ protected void doStart() throws Exception
+ {
+ _cyclicTimeouts = new CyclicTimeouts<>(getServer().getScheduler())
+ {
+ @Override
+ protected Iterator iterator()
+ {
+ return _trackers.values().iterator();
+ }
+
+ @Override
+ protected boolean onExpired(Tracker tracker)
+ {
+ return true;
+ }
+ };
+ addBean(_cyclicTimeouts);
+ super.doStart();
+ }
+
+ @Override
+ protected void doStop() throws Exception
+ {
+ super.doStop();
+ removeBean(_cyclicTimeouts);
+ _cyclicTimeouts.destroy();
+ _cyclicTimeouts = null;
+ }
+
+ /**
+ * A RateTracker is associated with an id, and stores request rate data.
+ */
+ public interface Tracker extends CyclicTimeouts.Expirable
+ {
+ /**
+ * Add a request to the tracker and check the rate limit
+ *
+ * @param now The timestamp of the request
+ * @return {@code true} if the request is below the limit
+ */
+ boolean onRequest(long now);
+
+ interface Factory
+ {
+ Tracker newTracker(String id);
+ }
+ }
+
+ /**
+ * The Tracker implements the classic Leaky Bucket Algorithm.
+ */
+ public static class LeakingBucketTrackerFactory implements Tracker.Factory
+ {
+ private final int _maxRequestsPerSecond;
+ private final int _bucketSize;
+ private final long _nanosPerDrip;
+ private final long _idleTimeout;
+
+ /**
+ * @param maxRequestsPerSecond the maximum requests per second allowed by this tracker
+ */
+ public LeakingBucketTrackerFactory(
+ @Name("maxRequestsPerSecond") int maxRequestsPerSecond)
+ {
+ this(maxRequestsPerSecond, -1, null);
+ }
+
+ /**
+ * @param maxRequestsPerSecond the maximum requests per second allowed by this tracker
+ * @param bucketSize the size of the bucket in request/drips, which is effectively the burst capacity, giving the number
+ * of request that can be handled in excess of the short term rate, before being rejected.
+ * Use -1 for a heuristic value.
+ * @param idleTimeout The period to keep an empty bucket before removal, or null to remove a bucket immediately when
+ * empty
+ */
+ public LeakingBucketTrackerFactory(
+ @Name("maxRequestsPerSecond") int maxRequestsPerSecond,
+ @Name("bucketSize") int bucketSize,
+ @Name("idleTimeout") Duration idleTimeout)
+ {
+ _maxRequestsPerSecond = maxRequestsPerSecond;
+ _nanosPerDrip = TimeUnit.SECONDS.toNanos(1) / _maxRequestsPerSecond;
+ _bucketSize = (bucketSize < 0)
+ ? _maxRequestsPerSecond
+ : bucketSize;
+ _idleTimeout = idleTimeout == null ? 0 : idleTimeout.toNanos();
+ }
+
+ @Override
+ public Tracker newTracker(String id)
+ {
+ return new LeakingBucketTracker(id);
+ }
+
+ private class LeakingBucketTracker implements Tracker
+ {
+ private final AutoLock _lock = new AutoLock();
+ private final String _id;
+ private long _lastDripNanoTime;
+ private long _expireNanoTime;
+ private int _bucket;
+
+ public LeakingBucketTracker(String id)
+ {
+ _id = id;
+ long now = NanoTime.now();
+ _lastDripNanoTime = now;
+ _expireNanoTime = now + _nanosPerDrip + _idleTimeout;
+ }
+
+ @Override
+ public long getExpireNanoTime()
+ {
+ try (AutoLock ignored = _lock.lock())
+ {
+ return _expireNanoTime;
+ }
+ }
+
+ @Override
+ public boolean onRequest(long now)
+ {
+ try (AutoLock ignored = _lock.lock())
+ {
+ long elapsedSinceLastDrip = NanoTime.elapsed(_lastDripNanoTime, now);
+ long drips = elapsedSinceLastDrip / _nanosPerDrip;
+ _lastDripNanoTime = _lastDripNanoTime + drips * _nanosPerDrip;
+ _bucket = Math.min(_bucketSize, Math.toIntExact(Math.max(0L, _bucket - drips) + 1));
+ _expireNanoTime = now + _bucket * _nanosPerDrip + _idleTimeout;
+ return _bucket < _bucketSize;
+ }
+ }
+
+ @Override
+ public String toString()
+ {
+ try (AutoLock ignored = _lock.lock())
+ {
+ return "%s@%s{%d/%d}".formatted(getClass().getSimpleName(), _id, _bucket, _maxRequestsPerSecond);
+ }
+ }
+ }
+ }
+
+ /**
+ * A Handler to reject DoS requests with a status code or failure.
+ */
+ public static class StatusRejectHandler implements Request.Handler
+ {
+ private final int _status;
+
+ public StatusRejectHandler()
+ {
+ this(-1);
+ }
+
+ /**
+ * @param status The status used to reject a request, or 0 to fail the request or -1 for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}.
+ */
+ public StatusRejectHandler(int status)
+ {
+ _status = status >= 0 ? status : HttpStatus.TOO_MANY_REQUESTS_429;
+ if (_status != 0 && _status != HttpStatus.OK_200 && !HttpStatus.isClientError(_status) && !HttpStatus.isServerError(_status))
+ throw new IllegalArgumentException("status must be a client or server error");
+ }
+
+ @Override
+ public boolean handle(Request request, Response response, Callback callback) throws Exception
+ {
+ if (_status == 0)
+ callback.failed(new RejectedExecutionException());
+ else
+ Response.writeError(request, response, callback, _status);
+ return true;
+ }
+ }
+
+ /**
+ * A Handler to reject DoS requests after first delaying them.
+ */
+ public static class DelayedRejectHandler extends Handler.Abstract
+ {
+ private record Exchange(Request request, Response response, Callback callback)
+ {
+ }
+
+ private final AutoLock _lock = new AutoLock();
+ private final Deque _delayQueue = new ArrayDeque<>();
+ private final int _maxDelayQueue;
+ private final long _delayMs;
+ private final Request.Handler _reject;
+ private Scheduler _scheduler;
+
+ public DelayedRejectHandler()
+ {
+ this(-1, -1, null);
+ }
+
+ /**
+ * @param delayMs The delay in milliseconds to hold rejected requests before sending a response or -1 for a default (1000ms)
+ * @param maxDelayQueue The maximum number of delayed requests to hold or -1 for a default (1000ms).
+ * @param reject The {@link Request.Handler} used to reject {@link Request}s or null for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}).
+ */
+ public DelayedRejectHandler(
+ @Name("delayMs") long delayMs,
+ @Name("maxDelayQueue") int maxDelayQueue,
+ @Name("reject") Request.Handler reject)
+ {
+ _delayMs = delayMs >= 0 ? delayMs : 1000;
+ _maxDelayQueue = maxDelayQueue >= 0 ? maxDelayQueue : 1000;
+ _reject = Objects.requireNonNullElseGet(reject, () -> new StatusRejectHandler(HttpStatus.TOO_MANY_REQUESTS_429));
+ }
+
+ @Override
+ protected void doStart() throws Exception
+ {
+ super.doStart();
+ _scheduler = getServer().getScheduler();
+ addBean(_scheduler);
+ }
+
+ @Override
+ protected void doStop() throws Exception
+ {
+ super.doStop();
+ removeBean(_scheduler);
+ _scheduler = null;
+ }
+
+ @Override
+ public boolean handle(Request request, Response response, Callback callback) throws Exception
+ {
+ List rejects = null;
+ try (AutoLock ignored = _lock.lock())
+ {
+ while (_delayQueue.size() >= _maxDelayQueue)
+ {
+ Exchange exchange = _delayQueue.removeFirst();
+ if (rejects == null)
+ rejects = new ArrayList<>();
+ rejects.add(exchange);
+ }
+
+ if (_delayQueue.isEmpty())
+ _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS);
+ _delayQueue.addLast(new Exchange(request, response, callback));
+ }
+
+ reject(rejects);
+
+ return true;
+ }
+
+ private void onTick()
+ {
+ long expired = NanoTime.now() - TimeUnit.MILLISECONDS.toNanos(_delayMs);
+
+ List rejects = null;
+ try (AutoLock ignored = _lock.lock())
+ {
+ Iterator iterator = _delayQueue.iterator();
+ while (iterator.hasNext())
+ {
+ Exchange exchange = iterator.next();
+ if (NanoTime.isBeforeOrSame(exchange.request.getBeginNanoTime(), expired))
+ {
+ iterator.remove();
+
+ if (rejects == null)
+ rejects = new ArrayList<>();
+ rejects.add(exchange);
+ }
+ }
+
+ if (!_delayQueue.isEmpty())
+ _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS);
+ }
+
+ reject(rejects);
+ }
+
+ private void reject(List rejects)
+ {
+ if (rejects != null)
+ {
+ for (Exchange exchange : rejects)
+ {
+ try
+ {
+ if (!_reject.handle(exchange.request, exchange.response, exchange.callback))
+ exchange.callback.failed(new RejectedExecutionException());
+ }
+ catch (Throwable t)
+ {
+ exchange.callback.failed(t);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java
new file mode 100644
index 000000000000..8dd8e2ae249a
--- /dev/null
+++ b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java
@@ -0,0 +1,322 @@
+//
+// ========================================================================
+// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
+//
+// This program and the accompanying materials are made available under the
+// terms of the Eclipse Public License v. 2.0 which is available at
+// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
+// which is available at https://www.apache.org/licenses/LICENSE-2.0.
+//
+// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
+// ========================================================================
+//
+
+package org.eclipse.jetty.server.handler;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Stream;
+
+import org.awaitility.Awaitility;
+import org.eclipse.jetty.server.LocalConnector;
+import org.eclipse.jetty.server.Server;
+import org.eclipse.jetty.util.NanoTime;
+import org.eclipse.jetty.util.component.LifeCycle;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DoSHandlerTest
+{
+ public static Stream factories()
+ {
+ return Stream.of(
+ Arguments.of(new DoSHandler.LeakingBucketTrackerFactory(100))
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerSteadyBelowRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ for (int sample = 0; sample < 400; sample++)
+ {
+ boolean exceeded = !tracker.onRequest(now);
+ assertFalse(exceeded);
+ now += TimeUnit.MILLISECONDS.toNanos(11);
+ }
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerSteadyAboveRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ boolean exceeded = false;
+ for (int sample = 0; sample < 2000; sample++)
+ {
+ exceeded = !tracker.onRequest(now);
+ if (exceeded)
+ break;
+ now += TimeUnit.MILLISECONDS.toNanos(9);
+ }
+
+ assertTrue(exceeded);
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerUnevenBelowRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ for (int sample = 0; sample < 20; sample++)
+ {
+ for (int burst = 0; burst < 9; burst++)
+ {
+ boolean exceeded = !tracker.onRequest(now);
+ assertFalse(exceeded);
+ }
+
+ now += TimeUnit.MILLISECONDS.toNanos(100);
+ }
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerUnevenAboveRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ boolean exceeded = false;
+ loop: for (int sample = 0; sample < 200; sample++)
+ {
+ for (int burst = 0; burst < 11; burst++)
+ {
+ exceeded = !tracker.onRequest(now);
+ if (exceeded)
+ break loop;
+ }
+
+ now += TimeUnit.MILLISECONDS.toNanos(100);
+ }
+
+ assertTrue(exceeded);
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerBurstBelowRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ for (int seconds = 0; seconds < 2; seconds++)
+ {
+ for (int burst = 0; burst < 99; burst++)
+ {
+ boolean exceeded = !tracker.onRequest(now++);
+ assertFalse(exceeded);
+ }
+ now += TimeUnit.MILLISECONDS.toNanos(1000);
+ }
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testTrackerBurstAboveRate(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ boolean exceeded = false;
+ for (int seconds = 0; seconds < 2; seconds++)
+ {
+ for (int burst = 0; burst < 101; burst++)
+ {
+ if (!tracker.onRequest(now))
+ {
+ exceeded = true;
+ break;
+ }
+ }
+
+ now += TimeUnit.MILLISECONDS.toNanos(1000);
+ }
+
+ assertTrue(exceeded);
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testRecoveryAfterBursts(DoSHandler.Tracker.Factory factory)
+ {
+ Server server = new Server();
+ DoSHandler handler = new DoSHandler(factory);
+ server.setHandler(handler);
+ LifeCycle.start(server);
+ DoSHandler.Tracker tracker = handler.newTracker("id");
+ long now = System.nanoTime();
+
+ boolean exceeded = false;
+ for (int burst = 0; burst < 1000; burst++)
+ {
+ now += TimeUnit.MILLISECONDS.toNanos(75);
+ exceeded = !tracker.onRequest(now);
+ }
+
+ for (int burst = 0; !exceeded && burst < 1000; burst++)
+ {
+ exceeded = !tracker.onRequest(now++);
+ }
+ assertTrue(exceeded);
+
+ exceeded = false;
+ for (int burst = 0; burst < 1000; burst++)
+ {
+ now += TimeUnit.MILLISECONDS.toNanos(75);
+ exceeded = !tracker.onRequest(now);
+ }
+ assertFalse(exceeded);
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testOKRequestRate(DoSHandler.Tracker.Factory factory) throws Exception
+ {
+ Server server = new Server();
+ LocalConnector connector = new LocalConnector(server);
+ server.addConnector(connector);
+
+ DoSHandler dosHandler = new DoSHandler(factory);
+ DumpHandler dumpHandler = new DumpHandler();
+ server.setHandler(dosHandler);
+ dosHandler.setHandler(dumpHandler);
+
+ server.start();
+
+ long now = System.nanoTime();
+ long end = now + TimeUnit.SECONDS.toNanos(5);
+ CountDownLatch latch = new CountDownLatch(90);
+ for (int thread = 0; thread < 90; thread++)
+ {
+ server.getThreadPool().execute(() ->
+ {
+ try
+ {
+ while (NanoTime.isBefore(NanoTime.now(), end))
+ {
+ String response = connector.getResponse("""
+ GET / HTTP/1.1\r
+ Host: local\r
+
+ """);
+ assertThat(response, containsString("200 OK"));
+ Thread.sleep(1000);
+ }
+ latch.countDown();
+ }
+ catch (Throwable x)
+ {
+ throw new RuntimeException(x);
+ }
+ });
+ }
+
+ assertTrue(latch.await(10, TimeUnit.SECONDS));
+ }
+
+ @ParameterizedTest
+ @MethodSource("factories")
+ public void testHighRequestRate(DoSHandler.Tracker.Factory factory) throws Exception
+ {
+ Server server = new Server();
+ LocalConnector connector = new LocalConnector(server);
+ server.addConnector(connector);
+
+ DoSHandler dosHandler = new DoSHandler(factory);
+ DumpHandler dumpHandler = new DumpHandler();
+ server.setHandler(dosHandler);
+ dosHandler.setHandler(dumpHandler);
+
+ server.start();
+
+ long now = System.nanoTime();
+ long end = now + TimeUnit.SECONDS.toNanos(5);
+ AtomicInteger outstanding = new AtomicInteger(0);
+ AtomicInteger calm = new AtomicInteger();
+ for (int thread = 0; thread < 120; thread++)
+ {
+ server.getThreadPool().execute(() ->
+ {
+ try
+ {
+ while (NanoTime.isBefore(NanoTime.now(), end))
+ {
+ try
+ {
+ outstanding.incrementAndGet();
+ String response = connector.getResponse("""
+ GET / HTTP/1.1\r
+ Host: local\r
+
+ """);
+ if (response.contains(" 429 "))
+ calm.incrementAndGet();
+ Thread.sleep(1000);
+ }
+ finally
+ {
+ outstanding.decrementAndGet();
+ }
+ }
+ }
+ catch (Throwable x)
+ {
+ throw new RuntimeException(x);
+ }
+ });
+ }
+
+ Awaitility.waitAtMost(10, TimeUnit.SECONDS).until(() -> outstanding.get() == 0);
+ assertThat(calm.get(), greaterThan(0));
+ }
+}