Skip to content

Commit

Permalink
Refactor RateLimiter (#534)
Browse files Browse the repository at this point in the history
This PR is motivated by the following issues:

* `RateLimiter` implementations are typically request-scoped, because they are sensitive to the realm context,
  but keeping a cache of token buckets per realm must be delegated to a separate, application-scoped bean: this
  is solved by introducing a new `TokenBucketFactory` bean.
* `TokenBucketRateLimiter` should NOT implement `RateLimiter`, because that would result in two available beans
  for this interface: the one selected by configuration (`realm-token-bucket` or `no-op`), and this one which
  is always available, thus triggering an unresolvable bean error when using the Quarkus runtime. Instead,
  this class is now just a general-purpose Token Bucket implementation.

This PR has one user-facing change:

* The configuration options for the `realm-token-bucket` rate limiter were moved from the `rateLimiter` section to
  the `tokenBucketFactory` section.
  • Loading branch information
adutra authored Dec 17, 2024
1 parent b7fa1f4 commit 1071aa2
Show file tree
Hide file tree
Showing 16 changed files with 170 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.polaris.service.context.CallContextResolver;
import org.apache.polaris.service.context.RealmContextResolver;
import org.apache.polaris.service.ratelimiter.RateLimiter;
import org.apache.polaris.service.ratelimiter.TokenBucketFactory;
import org.apache.polaris.service.storage.PolarisStorageIntegrationProviderImpl;
import org.apache.polaris.service.types.TokenType;
import org.glassfish.hk2.api.Factory;
Expand Down Expand Up @@ -90,6 +91,7 @@ public class PolarisApplicationConfig extends Configuration {
private String awsSecretKey;
private FileIOFactory fileIOFactory;
private RateLimiter rateLimiter;
private TokenBucketFactory tokenBucketFactory;
private TokenBrokerFactory tokenBrokerFactory;

private AccessToken gcpAccessToken;
Expand Down Expand Up @@ -144,6 +146,9 @@ protected void configure() {
bindFactory(SupplierFactory.create(serviceLocator, config::getRateLimiter))
.to(RateLimiter.class)
.ranked(OVERRIDE_BINDING_RANK);
bindFactory(SupplierFactory.create(serviceLocator, config::getTokenBucketFactory))
.to(TokenBucketFactory.class)
.ranked(OVERRIDE_BINDING_RANK);
}
};
}
Expand Down Expand Up @@ -332,6 +337,17 @@ public void setRateLimiter(@Nullable RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}

@JsonProperty("tokenBucketFactory")
private TokenBucketFactory getTokenBucketFactory() {
return tokenBucketFactory;
}

@JsonProperty("tokenBucketFactory")
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
public void setTokenBucketFactory(@Nullable TokenBucketFactory tokenBucketFactory) {
this.tokenBucketFactory = tokenBucketFactory;
}

public void setTaskHandler(TaskHandlerConfiguration taskHandler) {
this.taskHandler = taskHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=no-op
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.TokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=token-bucket
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=realm-token-bucket
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory]S
contract={org.apache.polaris.service.ratelimiter.TokenBucketFactory}
name=default
qualifier={io.smallrye.common.annotation.Identifier}

Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,20 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.smallrye.common.annotation.Identifier;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;
import org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter;
import org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory;
import org.threeten.extra.MutableClock;

/** RealmTokenBucketRateLimiter with a mock clock */
@Identifier("mock-realm-token-bucket")
public class MockRealmTokenBucketRateLimiter extends RealmTokenBucketRateLimiter {
/** TokenBucketFactory with a mock clock */
@Identifier("mock")
public class MockTokenBucketFactory extends DefaultTokenBucketFactory {
public static MutableClock CLOCK = MutableClock.of(Instant.now(), ZoneOffset.UTC);

@JsonCreator
public MockRealmTokenBucketRateLimiter(
@JsonProperty("requestsPerSecond") final long requestsPerSecond,
@JsonProperty("windowSeconds") final long windowSeconds) {
super(requestsPerSecond, windowSeconds);
}

@Override
protected Clock getClock() {
return CLOCK;
public MockTokenBucketFactory(
@JsonProperty("requestsPerSecond") long requestsPerSecond,
@JsonProperty("windowSeconds") long windowSeconds) {
super(requestsPerSecond, windowSeconds, CLOCK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public class RateLimiterFilterTest {
"server.applicationConnectors[0].port",
"0"), // Bind to random port to support parallelism
ConfigOverride.config("server.adminConnectors[0].port", "0"),
ConfigOverride.config("rateLimiter.type", "mock-realm-token-bucket"),
ConfigOverride.config("tokenBucketFactory.type", "mock"),
ConfigOverride.config(
"rateLimiter.requestsPerSecond", String.valueOf(REQUESTS_PER_SECOND)),
ConfigOverride.config("rateLimiter.windowSeconds", String.valueOf(WINDOW_SECONDS)));
"tokenBucketFactory.requestsPerSecond", String.valueOf(REQUESTS_PER_SECOND)),
ConfigOverride.config(
"tokenBucketFactory.windowSeconds", String.valueOf(WINDOW_SECONDS)));

private static String userToken;
private static String realm;
private static MutableClock clock = MockRealmTokenBucketRateLimiter.CLOCK;
private static MutableClock clock = MockTokenBucketFactory.CLOCK;

@BeforeAll
public static void setup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,49 @@
*/
package org.apache.polaris.service.dropwizard.ratelimiter;

import static org.apache.polaris.service.dropwizard.ratelimiter.MockTokenBucketFactory.CLOCK;

import java.time.Duration;
import org.apache.polaris.core.context.CallContext;
import org.apache.polaris.service.ratelimiter.RateLimiter;
import org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory;
import org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.threeten.extra.MutableClock;

/** Main unit test class for TokenBucketRateLimiter */
public class RealmTokenBucketRateLimiterTest {
@Test
void testDifferentBucketsDontTouch() {
RateLimiter rateLimiter = new MockRealmTokenBucketRateLimiter(10, 10);
RateLimitResultAsserter asserter = new RateLimitResultAsserter(rateLimiter);
MutableClock clock = MockRealmTokenBucketRateLimiter.CLOCK;
RealmTokenBucketRateLimiter rateLimiter = new RealmTokenBucketRateLimiter();
rateLimiter.setTokenBucketFactory(new DefaultTokenBucketFactory(10, 10, CLOCK));

for (int i = 0; i < 202; i++) {
String realm = (i % 2 == 0) ? "realm1" : "realm2";
CallContext.setCurrentContext(CallContext.of(() -> realm, null));

if (i < 200) {
asserter.canAcquire(1);
Assertions.assertTrue(rateLimiter.canProceed());
} else {
asserter.cantAcquire();
assertCannotProceed(rateLimiter);
}
}

clock.add(Duration.ofSeconds(1));
CLOCK.add(Duration.ofSeconds(1));
for (int i = 0; i < 22; i++) {
String realm = (i % 2 == 0) ? "realm1" : "realm2";
CallContext.setCurrentContext(CallContext.of(() -> realm, null));

if (i < 20) {
asserter.canAcquire(1);
Assertions.assertTrue(rateLimiter.canProceed());
} else {
asserter.cantAcquire();
assertCannotProceed(rateLimiter);
}
}
}

private void assertCannotProceed(RealmTokenBucketRateLimiter rateLimiter) {
for (int i = 0; i < 5; i++) {
Assertions.assertFalse(rateLimiter.canProceed());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.polaris.service.ratelimiter.TokenBucketRateLimiter;
import org.apache.polaris.service.ratelimiter.TokenBucket;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.threeten.extra.MutableClock;
Expand All @@ -38,19 +38,18 @@ void testBasic() {
MutableClock clock = MutableClock.of(Instant.now(), ZoneOffset.UTC);
clock.add(Duration.ofSeconds(5));

RateLimitResultAsserter asserter =
new RateLimitResultAsserter(new TokenBucketRateLimiter(10, 100, clock));
TokenBucket tokenBucket = new TokenBucket(10, 100, clock);

asserter.canAcquire(100);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 100);
assertCannotAcquire(tokenBucket);

clock.add(Duration.ofSeconds(1));
asserter.canAcquire(10);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 10);
assertCannotAcquire(tokenBucket);

clock.add(Duration.ofSeconds(10));
asserter.canAcquire(100);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 100);
assertCannotAcquire(tokenBucket);
}

/**
Expand All @@ -63,9 +62,8 @@ void testConcurrent() throws InterruptedException {
int numTasks = 50000;
int tokensPerSecond = 10; // Can be anything above 0

TokenBucketRateLimiter rl =
new TokenBucketRateLimiter(
tokensPerSecond, maxTokens, Clock.fixed(Instant.now(), ZoneOffset.UTC));
TokenBucket rl =
new TokenBucket(tokensPerSecond, maxTokens, Clock.fixed(Instant.now(), ZoneOffset.UTC));
AtomicInteger numAcquired = new AtomicInteger();
CountDownLatch startLatch = new CountDownLatch(numTasks);
CountDownLatch endLatch = new CountDownLatch(numTasks);
Expand Down Expand Up @@ -95,4 +93,16 @@ void testConcurrent() throws InterruptedException {
endLatch.await();
Assertions.assertEquals(maxTokens, numAcquired.get());
}

private void assertCanAcquire(TokenBucket tokenBucket, int times) {
for (int i = 0; i < times; i++) {
Assertions.assertTrue(tokenBucket.tryAcquire());
}
}

private void assertCannotAcquire(TokenBucket tokenBucket) {
for (int i = 0; i < 5; i++) {
Assertions.assertFalse(tokenBucket.tryAcquire());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ contract={org.apache.polaris.service.catalog.io.FileIOFactory}
name=test
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.dropwizard.ratelimiter.MockRealmTokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=mock-realm-token-bucket
[org.apache.polaris.service.dropwizard.ratelimiter.MockTokenBucketFactory]S
contract={org.apache.polaris.service.ratelimiter.TokenBucketFactory}
name=mock
qualifier={io.smallrye.common.annotation.Identifier}
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,12 @@ logging:
# Limits the size of request bodies sent to Polaris. -1 means no limit.
maxRequestBodyBytes: 1000000

# Limits the request rate per realm
# Limits the request rate per realm.
rateLimiter:
type: realm-token-bucket

# The token bucket factory to use when using the realm-token-bucket rate limiter.
tokenBucketFactory:
type: default
requestsPerSecond: 9999
windowSeconds: 10
8 changes: 8 additions & 0 deletions polaris-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,11 @@ maxRequestBodyBytes: -1
# Optional, not specifying a "rateLimiter" section also means no rate limiter
rateLimiter:
type: no-op
# Uncomment to use the realm-token-bucket rate limiter
# type: realm-token-bucket

# The token bucket factory to use when using the realm-token-bucket rate limiter.
tokenBucketFactory:
type: default
requestsPerSecond: 9999
windowSeconds: 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.polaris.service.ratelimiter;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.smallrye.common.annotation.Identifier;
import java.time.Clock;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.polaris.core.context.RealmContext;

@Identifier("default")
public class DefaultTokenBucketFactory implements TokenBucketFactory {

private final long requestsPerSecond;
private final long windowSeconds;
private final Clock clock;
private final Map<String, TokenBucket> perRealmBuckets = new ConcurrentHashMap<>();

@JsonCreator
public DefaultTokenBucketFactory(
@JsonProperty("requestsPerSecond") long requestsPerSecond,
@JsonProperty("windowSeconds") long windowSeconds) {
this(requestsPerSecond, windowSeconds, Clock.systemUTC());
}

public DefaultTokenBucketFactory(long requestsPerSecond, long windowSeconds, Clock clock) {
this.requestsPerSecond = requestsPerSecond;
this.windowSeconds = windowSeconds;
this.clock = clock;
}

@Override
public TokenBucket getOrCreateTokenBucket(RealmContext realmContext) {
String realmId = realmContext.getRealmIdentifier();
return perRealmBuckets.computeIfAbsent(
realmId,
k ->
new TokenBucket(
requestsPerSecond, Math.multiplyExact(requestsPerSecond, windowSeconds), clock));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@Identifier("no-op")
public class NoOpRateLimiter implements RateLimiter {
@Override
public boolean tryAcquire() {
public boolean canProceed() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ public interface RateLimiter {
*
* @return Whether the request is allowed to proceed by the rate limiter
*/
boolean tryAcquire();
boolean canProceed();
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public RateLimiterFilter(RateLimiter rateLimiter) {
/** Returns a 429 if the rate limiter says so. Otherwise, forwards the request along. */
@Override
public void filter(ContainerRequestContext ctx) throws IOException {
if (!rateLimiter.tryAcquire()) {
if (!rateLimiter.canProceed()) {
ctx.abortWith(Response.status(Response.Status.TOO_MANY_REQUESTS).build());
LOGGER.atDebug().log("Rate limiting request");
}
Expand Down
Loading

0 comments on commit 1071aa2

Please sign in to comment.