Skip to content

Commit

Permalink
Change AccessType to a class from an enum
Browse files Browse the repository at this point in the history
This will allow new AccessTypes to be created outside of this
project by third parties. Uses a technique where by the `ResourceSecurity`
instantiates the AccessType assuming a no-arg constructor.
  • Loading branch information
Randgalt committed Jan 27, 2025
1 parent 69f4d2c commit 03c68f8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
*/
package io.trino.aws.proxy.server.rest;

import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.SigV4Access;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Access.SigV4Access;
import io.trino.aws.proxy.spi.signing.SigningServiceType;

import java.lang.annotation.Retention;
Expand All @@ -29,32 +29,65 @@
@Target({TYPE, METHOD})
public @interface ResourceSecurity
{
enum AccessType
sealed interface Access
{
PUBLIC(new PublicAccess()),
S3(new SigV4Access(SigningServiceType.S3)),
STS(new SigV4Access(SigningServiceType.STS)),
LOGS(new SigV4Access(SigningServiceType.LOGS));
record PublicAccess()
implements Access {}

public sealed interface Access
record SigV4Access(SigningServiceType signingServiceType)
implements Access {}
}

abstract class AccessType
{
private final Access access;

public Access access()
{
record PublicAccess() implements Access {}
return access;
}

record SigV4Access(SigningServiceType signingServiceType) implements Access {}
protected AccessType(Access access)
{
this.access = requireNonNull(access, "access is null");
}
}

private final Access access;
class Public
extends AccessType
{
public Public()
{
super(new PublicAccess());
}
}

AccessType(Access access)
class S3
extends AccessType
{
public S3()
{
this.access = requireNonNull(access, "access is null");
super(new SigV4Access(SigningServiceType.S3));
}
}

public Access access()
class Sts
extends AccessType
{
public Sts()
{
return access;
super(new SigV4Access(SigningServiceType.STS));
}
}

class Logs
extends AccessType
{
public Logs()
{
super(new SigV4Access(SigningServiceType.LOGS));
}
}

AccessType value();
Class<? extends AccessType> value();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
package io.trino.aws.proxy.server.rest;

import com.google.inject.Inject;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Access.SigV4Access;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.PublicAccess;
import io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.Access.SigV4Access;
import io.trino.aws.proxy.spi.signing.SigningController;
import io.trino.aws.proxy.spi.signing.SigningServiceType;
import jakarta.ws.rs.container.DynamicFeature;
Expand Down Expand Up @@ -64,6 +64,14 @@ private static AccessType getAccessType(ResourceInfo resourceInfo)
private static Optional<AccessType> getAccessTypeFromAnnotation(AnnotatedElement annotatedElement)
{
return Optional.ofNullable(annotatedElement.getAnnotation(ResourceSecurity.class))
.map(ResourceSecurity::value);
.map(ResourceSecurity::value)
.map(accessClass -> {
try {
return accessClass.getConstructor().newInstance();
}
catch (Exception e) {
throw new IllegalArgumentException("Could not instantiate access type. Ensure it has a no-arg constructor. Class: " + accessClass, e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.trino.aws.proxy.server.rest.RequestLoggerController.SaveEntry;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Logs;
import io.trino.aws.proxy.server.rest.TrinoLogsResource.GetLogEventsResponse.Event;
import io.trino.aws.proxy.spi.rest.Request;
import jakarta.ws.rs.HeaderParam;
Expand All @@ -38,13 +39,12 @@
import java.util.function.Predicate;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.LOGS;
import static io.trino.aws.proxy.spi.signing.SigningServiceType.S3;
import static io.trino.aws.proxy.spi.signing.SigningServiceType.STS;
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(LOGS)
@ResourceSecurity(Logs.class)
public class TrinoLogsResource
{
private static final Set<String> DEFAULT_STREAMS = ImmutableSet.of(S3.serviceName(), STS.serviceName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.inject.Inject;
import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
import io.trino.aws.proxy.server.rest.ResourceSecurity.S3;
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
import io.trino.aws.proxy.spi.rest.Request;
import io.trino.aws.proxy.spi.signing.SigningMetadata;
Expand All @@ -33,10 +34,9 @@
import java.util.Optional;

import static io.trino.aws.proxy.server.rest.RequestBuilder.fromRequest;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.S3;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(S3)
@ResourceSecurity(S3.class)
public class TrinoS3Resource
{
private final TrinoS3ProxyClient proxyClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.inject.Inject;
import com.sun.management.OperatingSystemMXBean;
import io.airlift.node.NodeInfo;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Public;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.HEAD;
import jakarta.ws.rs.Produces;
Expand All @@ -25,11 +26,10 @@
import java.lang.management.MemoryMXBean;

import static io.airlift.units.Duration.nanosSince;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.PUBLIC;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(PUBLIC)
@ResourceSecurity(Public.class)
public class TrinoStatusResource
{
private final NodeInfo nodeInfo;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.airlift.log.Logger;
import io.trino.aws.proxy.server.rest.AssumeRoleResponse.AssumeRoleResult;
import io.trino.aws.proxy.server.rest.AssumeRoleResponse.AssumedRoleUser;
import io.trino.aws.proxy.server.rest.ResourceSecurity.Sts;
import io.trino.aws.proxy.spi.credentials.AssumedRoleProvider;
import io.trino.aws.proxy.spi.credentials.EmulatedAssumedRole;
import io.trino.aws.proxy.spi.rest.Request;
Expand All @@ -37,10 +38,9 @@
import java.util.Optional;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.aws.proxy.server.rest.ResourceSecurity.AccessType.STS;
import static java.util.Objects.requireNonNull;

@ResourceSecurity(STS)
@ResourceSecurity(Sts.class)
public class TrinoStsResource
{
private static final Logger log = Logger.get(TrinoStsResource.class);
Expand Down

0 comments on commit 03c68f8

Please sign in to comment.