diff --git a/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java b/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java index 1905a8d76405..ef7efecaec89 100644 --- a/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java +++ b/jetty-core/jetty-io/src/main/java/org/eclipse/jetty/io/CyclicTimeouts.java @@ -190,7 +190,7 @@ public interface Expirable * * @return the expiration time in nanoseconds, or {@link Long#MAX_VALUE} if this entity does not expire */ - public long getExpireNanoTime(); + long getExpireNanoTime(); } private class Timeouts extends CyclicTimeout diff --git a/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml new file mode 100644 index 000000000000..68b4b9de0087 --- /dev/null +++ b/jetty-core/jetty-server/src/main/config/etc/jetty-dos.xml @@ -0,0 +1,75 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/jetty-core/jetty-server/src/main/config/modules/dos.mod b/jetty-core/jetty-server/src/main/config/modules/dos.mod new file mode 100644 index 000000000000..45d5a9f5da4d --- /dev/null +++ b/jetty-core/jetty-server/src/main/config/modules/dos.mod @@ -0,0 +1,62 @@ +# DO NOT EDIT THIS FILE - See: https://eclipse.dev/jetty/documentation/ + +[description] +Enables the DosHandler for the server. + +[tags] +connector + +[depend] +server + +[xml] +etc/jetty-dos.xml + +[ini-template] + +## The algorithm to use for obtaining an remote client identifier from a Request: ID_FROM_REMOTE_ADDRESS, ID_FROM_REMOTE_PORT, ID_FROM_REMOTE_ADDRESS_PORT, ID_FROM_CONNECTION +#jetty.dos.id.type=ID_FROM_REMOTE_ADDRESS +#jetty.dos.id.class=org.eclipse.jetty.server.handler.DosHandler + +## The class to use to create Tracker instances to track the rate of requests +#jetty.dos.trackerFactory=org.eclipse.jetty.server.handler.DoSHandler$LeakingBucketTrackerFactory + +## The Handler class to use to reject DOS requests +#jetty.dos.rejectHandler=org.eclipse.jetty.server.handler.DoSHandler$DelayedRejectHandler + +## The maximum requests per second per client +#jetty.dos.leakingBucketTracker.maxRequestsPerSecond=100 + +## The size of the leaky bucket. Larger buckets allow longer bursts before enforcing the rate +#jetty.dos.leakingBucketTracker.bucketSize=100 + +## The time in seconds to retain an empty bucket. +#jetty.dos.leakingBucketTracker.idleTimeout=1 + +## The period to delay dos requests before rejecting them. +#jetty.dos.rejectHandler.delayed.delayMs=1000 + +## The maximum number of requests to be held in the delay queue +#jetty.dos.rejectHandler.delayed.maxDelayQueue=1000 + +## The maximum number of clients to track; or -1 for a default value; or 0 for unlimited +#jetty.dos.maxTrackers=100000 + +## Should untracked requests (due to maxTrackers) be rejected or allowed +#jetty.dos.rejectUntracked=false + +## The status code used to reject requests; or 0 to abort the request; or -1 for a default +#jetty.dos.rejectStatus=429 + +## List of InetAddress patterns to include +#jetty.dos.include.inet=10.10.10-14.0-128 + +## List of InetAddressPatterns to exclude +#jetty.dos.exclude.inet=10.10.10-14.0-128 + +## List of path patterns to include +#jetty.dos.include.path=/context/* + +## List of path to exclude +#jetty.dos.exclude.path=/context/* + diff --git a/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java new file mode 100644 index 000000000000..5aaa5d2be2a6 --- /dev/null +++ b/jetty-core/jetty-server/src/main/java/org/eclipse/jetty/server/handler/DoSHandler.java @@ -0,0 +1,522 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server.handler; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import org.eclipse.jetty.http.HttpStatus; +import org.eclipse.jetty.io.CyclicTimeouts; +import org.eclipse.jetty.server.ConnectionMetaData; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.NanoTime; +import org.eclipse.jetty.util.annotation.ManagedObject; +import org.eclipse.jetty.util.annotation.Name; +import org.eclipse.jetty.util.thread.AutoLock; +import org.eclipse.jetty.util.thread.Scheduler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + *

A Denial of Service Handler that protects from attacks by limiting the request rate from remote clients.

+ */ +@ManagedObject("DoS Prevention Handler") +public class DoSHandler extends ConditionalHandler.ElseNext +{ + private static final Logger LOG = LoggerFactory.getLogger(DoSHandler.class); + + /** + * A {@link Function} to create a remote client identifier from the remote address and remote port of a {@link Request}. + */ + public static final Function ID_FROM_REMOTE_ADDRESS_PORT = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return inetSocketAddress.toString(); + return remoteSocketAddress.toString(); + }; + + /** + * A {@link Function} to create a remote client identifier from the remote address of a {@link Request}. + */ + public static final Function ID_FROM_REMOTE_ADDRESS = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return inetSocketAddress.getAddress().toString(); + return remoteSocketAddress.toString(); + }; + + /** + * A {@link Function} to create a remote client identifier from the remote port of a {@link Request}. + * Useful if there is an untrusted intermediary, where the remote port can be a surrogate for the connection. + */ + public static final Function ID_FROM_REMOTE_PORT = request -> + { + SocketAddress remoteSocketAddress = request.getConnectionMetaData().getRemoteSocketAddress(); + if (remoteSocketAddress instanceof InetSocketAddress inetSocketAddress) + return Integer.toString(inetSocketAddress.getPort()); + return remoteSocketAddress.toString(); + }; + + /** + * A {@link Function} to create a remote client identifier from {@link ConnectionMetaData#getId()} of a {@link Request}. + */ + public static final Function ID_FROM_CONNECTION = request -> request.getConnectionMetaData().getId(); + + private final Map _trackers = new ConcurrentHashMap<>(); + private final Function _clientIdFn; + private final Tracker.Factory _trackerFactory; + private final Request.Handler _rejectHandler; + private final int _maxTrackers; + private final boolean _rejectUntracked; + private CyclicTimeouts _cyclicTimeouts; + + /** + * @param trackerFactory Factory to create a Tracker + */ + public DoSHandler(@Name("trackerFactory") Tracker.Factory trackerFactory) + { + this(null, trackerFactory, null, -1); + } + + /** + * @param clientIdFn Function to extract a remote client identifier from a request. + * @param trackerFactory Factory to create a Tracker + * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited. + * If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("clientIdFn") Function clientIdFn, + @Name("trackerFactory") Tracker.Factory trackerFactory, + @Name("rejectHandler") Request.Handler rejectHandler, + @Name("maxTrackers") int maxTrackers) + { + this(null, clientIdFn, trackerFactory, rejectHandler, maxTrackers); + } + + /** + * @param handler Then next {@link Handler} or {@code null} + * @param clientIdFn Function to extract a remote client identifier from a request. + * @param trackerFactory Factory to create a Tracker + * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited. + * If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("handler") Handler handler, + @Name("clientIdFn") Function clientIdFn, + @Name("trackerFactory") Tracker.Factory trackerFactory, + @Name("rejectHandler") Request.Handler rejectHandler, + @Name("maxTrackers") int maxTrackers) + { + this(handler, clientIdFn, trackerFactory, rejectHandler, maxTrackers, false); + } + + /** + * @param handler Then next {@link Handler} or {@code null} + * @param clientIdFn Function to extract a remote client identifier from a request. + * @param trackerFactory Factory to create a Tracker + * @param rejectHandler A {@link Handler} used to reject excess requests, or {@code null} for a default. + * @param maxTrackers The maximum number of remote clients to track or -1 for a default value, 0 for unlimited. + * If this limit is exceeded, then requests from additional remote clients are rejected. + */ + public DoSHandler( + @Name("handler") Handler handler, + @Name("clientIdFn") Function clientIdFn, + @Name("trackerFactory") Tracker.Factory trackerFactory, + @Name("rejectHandler") Request.Handler rejectHandler, + @Name("maxTrackers") int maxTrackers, + @Name("rejectUntracked") boolean rejectUntracked) + { + super(handler); + installBean(_trackers); + _clientIdFn = Objects.requireNonNullElse(clientIdFn, ID_FROM_REMOTE_ADDRESS); + installBean(_clientIdFn); + _trackerFactory = Objects.requireNonNull(trackerFactory); + installBean(_trackerFactory); + // default max trackers to a large, effectively infinite number, but ultimately bounded. + _maxTrackers = maxTrackers < 0 ? 100_000 : maxTrackers; + _rejectHandler = Objects.requireNonNullElseGet(rejectHandler, StatusRejectHandler::new); + installBean(_rejectHandler); + _rejectUntracked = rejectUntracked; + } + + @Override + public void setServer(Server server) + { + super.setServer(server); + if (_rejectHandler instanceof Handler handler) + handler.setServer(server); + } + + @Override + protected boolean onConditionsMet(Request request, Response response, Callback callback) throws Exception + { + + // Calculate an id for the request (which may be global empty string). + String id = _clientIdFn.apply(request); + + // Reject or handle untracked request + if (id == null) + id = ""; + + // Obtain a tracker, creating a new one if necessary (and not too many) + // Trackers are removed if CyclicTimeouts#onExpired returns true. + Tracker tracker = _trackers.computeIfAbsent(id, this::newTracker); + + // If we have too many trackers, then we will have a null tracker + if (tracker == null) + return _rejectUntracked ? _rejectHandler.handle(request, response, callback) : nextHandler(request, response, callback); + + // IS the request allowed by the tracker? + boolean allowed = tracker.onRequest(NanoTime.now()); + if (LOG.isDebugEnabled()) + LOG.debug("allowed={} {}", allowed, tracker); + if (allowed) + return nextHandler(request, response, callback); + + // Otherwise reject the request as it is over rate + return _rejectHandler.handle(request, response, callback); + } + + Tracker newTracker(String id) + { + if (_maxTrackers > 0 && _trackers.size() >= _maxTrackers) + return null; + + Tracker tracker = _trackerFactory.newTracker(id); + _cyclicTimeouts.schedule(tracker); + return tracker; + } + + @Override + protected void doStart() throws Exception + { + _cyclicTimeouts = new CyclicTimeouts<>(getServer().getScheduler()) + { + @Override + protected Iterator iterator() + { + return _trackers.values().iterator(); + } + + @Override + protected boolean onExpired(Tracker tracker) + { + return true; + } + }; + addBean(_cyclicTimeouts); + super.doStart(); + } + + @Override + protected void doStop() throws Exception + { + super.doStop(); + removeBean(_cyclicTimeouts); + _cyclicTimeouts.destroy(); + _cyclicTimeouts = null; + } + + /** + * A RateTracker is associated with an id, and stores request rate data. + */ + public interface Tracker extends CyclicTimeouts.Expirable + { + /** + * Add a request to the tracker and check the rate limit + * + * @param now The timestamp of the request + * @return {@code true} if the request is below the limit + */ + boolean onRequest(long now); + + interface Factory + { + Tracker newTracker(String id); + } + } + + /** + * The Tracker implements the classic Leaky Bucket Algorithm. + */ + public static class LeakingBucketTrackerFactory implements Tracker.Factory + { + private final int _maxRequestsPerSecond; + private final int _bucketSize; + private final long _nanosPerDrip; + private final long _idleTimeout; + + /** + * @param maxRequestsPerSecond the maximum requests per second allowed by this tracker + */ + public LeakingBucketTrackerFactory( + @Name("maxRequestsPerSecond") int maxRequestsPerSecond) + { + this(maxRequestsPerSecond, -1, null); + } + + /** + * @param maxRequestsPerSecond the maximum requests per second allowed by this tracker + * @param bucketSize the size of the bucket in request/drips, which is effectively the burst capacity, giving the number + * of request that can be handled in excess of the short term rate, before being rejected. + * Use -1 for a heuristic value. + * @param idleTimeout The period to keep an empty bucket before removal, or null to remove a bucket immediately when + * empty + */ + public LeakingBucketTrackerFactory( + @Name("maxRequestsPerSecond") int maxRequestsPerSecond, + @Name("bucketSize") int bucketSize, + @Name("idleTimeout") Duration idleTimeout) + { + _maxRequestsPerSecond = maxRequestsPerSecond; + _nanosPerDrip = TimeUnit.SECONDS.toNanos(1) / _maxRequestsPerSecond; + _bucketSize = (bucketSize < 0) + ? _maxRequestsPerSecond + : bucketSize; + _idleTimeout = idleTimeout == null ? 0 : idleTimeout.toNanos(); + } + + @Override + public Tracker newTracker(String id) + { + return new LeakingBucketTracker(id); + } + + private class LeakingBucketTracker implements Tracker + { + private final AutoLock _lock = new AutoLock(); + private final String _id; + private long _lastDripNanoTime; + private long _expireNanoTime; + private int _bucket; + + public LeakingBucketTracker(String id) + { + _id = id; + long now = NanoTime.now(); + _lastDripNanoTime = now; + _expireNanoTime = now + _nanosPerDrip + _idleTimeout; + } + + @Override + public long getExpireNanoTime() + { + try (AutoLock ignored = _lock.lock()) + { + return _expireNanoTime; + } + } + + @Override + public boolean onRequest(long now) + { + try (AutoLock ignored = _lock.lock()) + { + long elapsedSinceLastDrip = NanoTime.elapsed(_lastDripNanoTime, now); + long drips = elapsedSinceLastDrip / _nanosPerDrip; + _lastDripNanoTime = _lastDripNanoTime + drips * _nanosPerDrip; + _bucket = Math.min(_bucketSize, Math.toIntExact(Math.max(0L, _bucket - drips) + 1)); + _expireNanoTime = now + _bucket * _nanosPerDrip + _idleTimeout; + return _bucket < _bucketSize; + } + } + + @Override + public String toString() + { + try (AutoLock ignored = _lock.lock()) + { + return "%s@%s{%d/%d}".formatted(getClass().getSimpleName(), _id, _bucket, _maxRequestsPerSecond); + } + } + } + } + + /** + * A Handler to reject DoS requests with a status code or failure. + */ + public static class StatusRejectHandler implements Request.Handler + { + private final int _status; + + public StatusRejectHandler() + { + this(-1); + } + + /** + * @param status The status used to reject a request, or 0 to fail the request or -1 for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}. + */ + public StatusRejectHandler(int status) + { + _status = status >= 0 ? status : HttpStatus.TOO_MANY_REQUESTS_429; + if (_status != 0 && _status != HttpStatus.OK_200 && !HttpStatus.isClientError(_status) && !HttpStatus.isServerError(_status)) + throw new IllegalArgumentException("status must be a client or server error"); + } + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception + { + if (_status == 0) + callback.failed(new RejectedExecutionException()); + else + Response.writeError(request, response, callback, _status); + return true; + } + } + + /** + * A Handler to reject DoS requests after first delaying them. + */ + public static class DelayedRejectHandler extends Handler.Abstract + { + private record Exchange(Request request, Response response, Callback callback) + { + } + + private final AutoLock _lock = new AutoLock(); + private final Deque _delayQueue = new ArrayDeque<>(); + private final int _maxDelayQueue; + private final long _delayMs; + private final Request.Handler _reject; + private Scheduler _scheduler; + + public DelayedRejectHandler() + { + this(-1, -1, null); + } + + /** + * @param delayMs The delay in milliseconds to hold rejected requests before sending a response or -1 for a default (1000ms) + * @param maxDelayQueue The maximum number of delayed requests to hold or -1 for a default (1000ms). + * @param reject The {@link Request.Handler} used to reject {@link Request}s or null for a default ({@link HttpStatus#TOO_MANY_REQUESTS_429}). + */ + public DelayedRejectHandler( + @Name("delayMs") long delayMs, + @Name("maxDelayQueue") int maxDelayQueue, + @Name("reject") Request.Handler reject) + { + _delayMs = delayMs >= 0 ? delayMs : 1000; + _maxDelayQueue = maxDelayQueue >= 0 ? maxDelayQueue : 1000; + _reject = Objects.requireNonNullElseGet(reject, () -> new StatusRejectHandler(HttpStatus.TOO_MANY_REQUESTS_429)); + } + + @Override + protected void doStart() throws Exception + { + super.doStart(); + _scheduler = getServer().getScheduler(); + addBean(_scheduler); + } + + @Override + protected void doStop() throws Exception + { + super.doStop(); + removeBean(_scheduler); + _scheduler = null; + } + + @Override + public boolean handle(Request request, Response response, Callback callback) throws Exception + { + List rejects = null; + try (AutoLock ignored = _lock.lock()) + { + while (_delayQueue.size() >= _maxDelayQueue) + { + Exchange exchange = _delayQueue.removeFirst(); + if (rejects == null) + rejects = new ArrayList<>(); + rejects.add(exchange); + } + + if (_delayQueue.isEmpty()) + _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS); + _delayQueue.addLast(new Exchange(request, response, callback)); + } + + reject(rejects); + + return true; + } + + private void onTick() + { + long expired = NanoTime.now() - TimeUnit.MILLISECONDS.toNanos(_delayMs); + + List rejects = null; + try (AutoLock ignored = _lock.lock()) + { + Iterator iterator = _delayQueue.iterator(); + while (iterator.hasNext()) + { + Exchange exchange = iterator.next(); + if (NanoTime.isBeforeOrSame(exchange.request.getBeginNanoTime(), expired)) + { + iterator.remove(); + + if (rejects == null) + rejects = new ArrayList<>(); + rejects.add(exchange); + } + } + + if (!_delayQueue.isEmpty()) + _scheduler.schedule(this::onTick, _delayMs / 2, TimeUnit.MILLISECONDS); + } + + reject(rejects); + } + + private void reject(List rejects) + { + if (rejects != null) + { + for (Exchange exchange : rejects) + { + try + { + if (!_reject.handle(exchange.request, exchange.response, exchange.callback)) + exchange.callback.failed(new RejectedExecutionException()); + } + catch (Throwable t) + { + exchange.callback.failed(t); + } + } + } + } + } +} diff --git a/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java new file mode 100644 index 000000000000..8dd8e2ae249a --- /dev/null +++ b/jetty-core/jetty-server/src/test/java/org/eclipse/jetty/server/handler/DoSHandlerTest.java @@ -0,0 +1,322 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server.handler; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import org.awaitility.Awaitility; +import org.eclipse.jetty.server.LocalConnector; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.util.NanoTime; +import org.eclipse.jetty.util.component.LifeCycle; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DoSHandlerTest +{ + public static Stream factories() + { + return Stream.of( + Arguments.of(new DoSHandler.LeakingBucketTrackerFactory(100)) + ); + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerSteadyBelowRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + for (int sample = 0; sample < 400; sample++) + { + boolean exceeded = !tracker.onRequest(now); + assertFalse(exceeded); + now += TimeUnit.MILLISECONDS.toNanos(11); + } + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerSteadyAboveRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + boolean exceeded = false; + for (int sample = 0; sample < 2000; sample++) + { + exceeded = !tracker.onRequest(now); + if (exceeded) + break; + now += TimeUnit.MILLISECONDS.toNanos(9); + } + + assertTrue(exceeded); + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerUnevenBelowRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + for (int sample = 0; sample < 20; sample++) + { + for (int burst = 0; burst < 9; burst++) + { + boolean exceeded = !tracker.onRequest(now); + assertFalse(exceeded); + } + + now += TimeUnit.MILLISECONDS.toNanos(100); + } + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerUnevenAboveRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + boolean exceeded = false; + loop: for (int sample = 0; sample < 200; sample++) + { + for (int burst = 0; burst < 11; burst++) + { + exceeded = !tracker.onRequest(now); + if (exceeded) + break loop; + } + + now += TimeUnit.MILLISECONDS.toNanos(100); + } + + assertTrue(exceeded); + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerBurstBelowRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 99; burst++) + { + boolean exceeded = !tracker.onRequest(now++); + assertFalse(exceeded); + } + now += TimeUnit.MILLISECONDS.toNanos(1000); + } + } + + @ParameterizedTest + @MethodSource("factories") + public void testTrackerBurstAboveRate(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + boolean exceeded = false; + for (int seconds = 0; seconds < 2; seconds++) + { + for (int burst = 0; burst < 101; burst++) + { + if (!tracker.onRequest(now)) + { + exceeded = true; + break; + } + } + + now += TimeUnit.MILLISECONDS.toNanos(1000); + } + + assertTrue(exceeded); + } + + @ParameterizedTest + @MethodSource("factories") + public void testRecoveryAfterBursts(DoSHandler.Tracker.Factory factory) + { + Server server = new Server(); + DoSHandler handler = new DoSHandler(factory); + server.setHandler(handler); + LifeCycle.start(server); + DoSHandler.Tracker tracker = handler.newTracker("id"); + long now = System.nanoTime(); + + boolean exceeded = false; + for (int burst = 0; burst < 1000; burst++) + { + now += TimeUnit.MILLISECONDS.toNanos(75); + exceeded = !tracker.onRequest(now); + } + + for (int burst = 0; !exceeded && burst < 1000; burst++) + { + exceeded = !tracker.onRequest(now++); + } + assertTrue(exceeded); + + exceeded = false; + for (int burst = 0; burst < 1000; burst++) + { + now += TimeUnit.MILLISECONDS.toNanos(75); + exceeded = !tracker.onRequest(now); + } + assertFalse(exceeded); + } + + @ParameterizedTest + @MethodSource("factories") + public void testOKRequestRate(DoSHandler.Tracker.Factory factory) throws Exception + { + Server server = new Server(); + LocalConnector connector = new LocalConnector(server); + server.addConnector(connector); + + DoSHandler dosHandler = new DoSHandler(factory); + DumpHandler dumpHandler = new DumpHandler(); + server.setHandler(dosHandler); + dosHandler.setHandler(dumpHandler); + + server.start(); + + long now = System.nanoTime(); + long end = now + TimeUnit.SECONDS.toNanos(5); + CountDownLatch latch = new CountDownLatch(90); + for (int thread = 0; thread < 90; thread++) + { + server.getThreadPool().execute(() -> + { + try + { + while (NanoTime.isBefore(NanoTime.now(), end)) + { + String response = connector.getResponse(""" + GET / HTTP/1.1\r + Host: local\r + + """); + assertThat(response, containsString("200 OK")); + Thread.sleep(1000); + } + latch.countDown(); + } + catch (Throwable x) + { + throw new RuntimeException(x); + } + }); + } + + assertTrue(latch.await(10, TimeUnit.SECONDS)); + } + + @ParameterizedTest + @MethodSource("factories") + public void testHighRequestRate(DoSHandler.Tracker.Factory factory) throws Exception + { + Server server = new Server(); + LocalConnector connector = new LocalConnector(server); + server.addConnector(connector); + + DoSHandler dosHandler = new DoSHandler(factory); + DumpHandler dumpHandler = new DumpHandler(); + server.setHandler(dosHandler); + dosHandler.setHandler(dumpHandler); + + server.start(); + + long now = System.nanoTime(); + long end = now + TimeUnit.SECONDS.toNanos(5); + AtomicInteger outstanding = new AtomicInteger(0); + AtomicInteger calm = new AtomicInteger(); + for (int thread = 0; thread < 120; thread++) + { + server.getThreadPool().execute(() -> + { + try + { + while (NanoTime.isBefore(NanoTime.now(), end)) + { + try + { + outstanding.incrementAndGet(); + String response = connector.getResponse(""" + GET / HTTP/1.1\r + Host: local\r + + """); + if (response.contains(" 429 ")) + calm.incrementAndGet(); + Thread.sleep(1000); + } + finally + { + outstanding.decrementAndGet(); + } + } + } + catch (Throwable x) + { + throw new RuntimeException(x); + } + }); + } + + Awaitility.waitAtMost(10, TimeUnit.SECONDS).until(() -> outstanding.get() == 0); + assertThat(calm.get(), greaterThan(0)); + } +}