Skip to content

Commit

Permalink
Change AccessType to a class from an enum
Browse files Browse the repository at this point in the history
This will allow new AccessTypes to be created outside of this
project by third parties. Uses a technique where by the `ResourceSecurity`
instantiates the AccessType assuming a no-arg constructor.
  • Loading branch information
Randgalt committed Jan 27, 2025
1 parent 925bd8e commit 21ef861
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
*/
package io.trino.aws.proxy.server.rest;

import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.SigV4Access;
import io.trino.aws.proxy.spi.signing.SigningServiceType;

import java.lang.annotation.Retention;
Expand All @@ -23,38 +21,59 @@
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static java.util.Objects.requireNonNull;

@Retention(RUNTIME)
@Target({TYPE, METHOD})
public @interface ResourceSecurity
{
enum AccessType
sealed interface AccessType
{
PUBLIC(new PublicAccess()),
S3(new SigV4Access(SigningServiceType.S3)),
STS(new SigV4Access(SigningServiceType.STS)),
LOGS(new SigV4Access(SigningServiceType.LOGS));
}

public sealed interface Access
{
record PublicAccess() implements Access {}
sealed interface PublicAccessType
extends AccessType
permits Public
{}

record SigV4Access(SigningServiceType signingServiceType) implements Access {}
}
non-sealed interface SigV4AccessType
extends AccessType
{
SigningServiceType signingServiceType();
}

private final Access access;
final class Public
implements PublicAccessType
{}

AccessType(Access access)
final class S3
implements SigV4AccessType
{
@Override
public SigningServiceType signingServiceType()
{
this.access = requireNonNull(access, "access is null");
return SigningServiceType.S3;
}
}

public Access access()
final class Sts
implements SigV4AccessType
{
@Override
public SigningServiceType signingServiceType()
{
return SigningServiceType.STS;
}
}

final class Logs
implements SigV4AccessType
{
@Override
public SigningServiceType signingServiceType()
{
return access;
return SigningServiceType.LOGS;
}
}

AccessType value();
Class<? extends AccessType> value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@

import com.google.inject.Inject;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.SigV4Access;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Public;
import io.trino.aws.proxy.server.rest.ResourceSecurity.SigV4AccessType;
import io.trino.aws.proxy.spi.signing.SigningController;
import io.trino.aws.proxy.spi.signing.SigningServiceType;
import jakarta.ws.rs.container.DynamicFeature;
import jakarta.ws.rs.container.ResourceInfo;
import jakarta.ws.rs.core.FeatureContext;
Expand Down Expand Up @@ -46,10 +45,10 @@ public void configure(ResourceInfo resourceInfo, FeatureContext context)
{
if (resourceInfo.getResourceClass().getPackageName().startsWith("io.trino.aws")) {
AccessType accessType = getAccessType(resourceInfo);
switch (accessType.access()) {
case PublicAccess _ -> {}
case SigV4Access(SigningServiceType signingServiceType) ->
context.register(new SecurityFilter(signingController, signingServiceType, requestLoggerController));
switch (accessType) {
case Public _ -> {}
case SigV4AccessType sigV4AccessType ->
context.register(new SecurityFilter(signingController, sigV4AccessType.signingServiceType(), requestLoggerController));
}
}
}
Expand All @@ -64,6 +63,14 @@ private static AccessType getAccessType(ResourceInfo resourceInfo)
private static Optional<AccessType> getAccessTypeFromAnnotation(AnnotatedElement annotatedElement)
{
return Optional.ofNullable(annotatedElement.getAnnotation(ResourceSecurity.class))
.map(ResourceSecurity::value);
.map(ResourceSecurity::value)
.map(accessClass -> {
try {
return accessClass.getConstructor().newInstance();
}
catch (Exception e) {
throw new IllegalArgumentException("Could not instantiate access type. Ensure it has a no-arg constructor. Class: " + accessClass, e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.trino.aws.proxy.server.rest.RequestLoggerController.SaveEntry;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Logs;
import io.trino.aws.proxy.server.rest.TrinoLogsResource.GetLogEventsResponse.Event;
import io.trino.aws.proxy.spi.rest.Request;
import jakarta.ws.rs.HeaderParam;
Expand All @@ -38,13 +39,12 @@
import java.util.function.Predicate;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.LOGS;
import static io.trino.aws.proxy.spi.signing.SigningServiceType.S3;
import static io.trino.aws.proxy.spi.signing.SigningServiceType.STS;
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(LOGS)
@ResourceSecurity(Logs.class)
public class TrinoLogsResource
{
private static final Set<String> DEFAULT_STREAMS = ImmutableSet.of(S3.serviceName(), STS.serviceName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.inject.Inject;
import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
import io.trino.aws.proxy.server.rest.ResourceSecurity.S3;
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
import io.trino.aws.proxy.spi.rest.Request;
import io.trino.aws.proxy.spi.signing.SigningMetadata;
Expand All @@ -33,10 +34,9 @@
import java.util.Optional;

import static io.trino.aws.proxy.server.rest.RequestBuilder.fromRequest;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.S3;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(S3)
@ResourceSecurity(S3.class)
public class TrinoS3Resource
{
private final TrinoS3ProxyClient proxyClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.inject.Inject;
import com.sun.management.OperatingSystemMXBean;
import io.airlift.node.NodeInfo;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Public;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.HEAD;
import jakarta.ws.rs.Produces;
Expand All @@ -25,11 +26,10 @@
import java.lang.management.MemoryMXBean;

import static io.airlift.units.Duration.nanosSince;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.PUBLIC;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(PUBLIC)
@ResourceSecurity(Public.class)
public class TrinoStatusResource
{
private final NodeInfo nodeInfo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.airlift.log.Logger;
import io.trino.aws.proxy.server.rest.AssumeRoleResponse.AssumeRoleResult;
import io.trino.aws.proxy.server.rest.AssumeRoleResponse.AssumedRoleUser;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Sts;
import io.trino.aws.proxy.spi.credentials.AssumedRoleProvider;
import io.trino.aws.proxy.spi.credentials.EmulatedAssumedRole;
import io.trino.aws.proxy.spi.rest.Request;
Expand All @@ -37,10 +38,9 @@
import java.util.Optional;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.STS;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(STS)
@ResourceSecurity(Sts.class)
public class TrinoStsResource
{
private static final Logger log = Logger.get(TrinoStsResource.class);
Expand Down

0 comments on commit 21ef861

Please sign in to comment.