Skip to content

Commit

Permalink
Support @AuthenticationPrincipal on interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
kse-music committed Dec 6, 2024
1 parent dc82a6e commit 4bab2ed
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
Expand Down Expand Up @@ -83,6 +84,7 @@
*
* @param <A> the annotation to search for and synthesize
* @author Josh Cummings
* @author DingHao
* @since 6.4
*/
final class UniqueSecurityAnnotationScanner<A extends Annotation> extends AbstractSecurityAnnotationScanner<A> {
Expand All @@ -107,7 +109,7 @@ final class UniqueSecurityAnnotationScanner<A extends Annotation> extends Abstra
MergedAnnotation<A> merge(AnnotatedElement element, Class<?> targetClass) {
if (element instanceof Parameter parameter) {
return this.uniqueParameterAnnotationCache.computeIfAbsent(parameter, (p) -> {
List<MergedAnnotation<A>> annotations = findDirectAnnotations(p);
List<MergedAnnotation<A>> annotations = findParameterAnnotations(p);
return requireUnique(p, annotations);
});
}
Expand Down Expand Up @@ -137,6 +139,50 @@ private MergedAnnotation<A> requireUnique(AnnotatedElement element, List<MergedA
};
}

private List<MergedAnnotation<A>> findParameterAnnotations(Parameter current) {
List<MergedAnnotation<A>> directAnnotations = findDirectAnnotations(current);
if (!directAnnotations.isEmpty()) {
return directAnnotations;
}
directAnnotations = new ArrayList<>(findDirectAnnotations(current));
Executable executable = current.getDeclaringExecutable();
if (executable instanceof Method method) {
Class<?> clazz = method.getDeclaringClass();
Set<Class<?>> visited = new HashSet<>();
while (clazz != null && visited.add(clazz)) {
for (Class<?> ifc : clazz.getInterfaces()) {
directAnnotations.addAll(findParameterAnnotations(method, ifc, current));
}
clazz = clazz.getSuperclass();
if (clazz == Object.class) {
clazz = null;
}
if (clazz != null && visited.add(clazz)) {
directAnnotations.addAll(findParameterAnnotations(method, clazz, current));
}
}
}
return directAnnotations;
}

private List<MergedAnnotation<A>> findParameterAnnotations(Method method, Class<?> superOrIfc, Parameter current) {
try {
Method methodToUse = superOrIfc.getDeclaredMethod(method.getName(), method.getParameterTypes());
for (Parameter parameter : methodToUse.getParameters()) {
if (parameter.getName().equals(current.getName())) {
List<MergedAnnotation<A>> directAnnotations = findDirectAnnotations(parameter);
if (!directAnnotations.isEmpty()) {
return directAnnotations;
}
}
}
}
catch (NoSuchMethodException ex) {
// move on
}
return Collections.emptyList();
}

private List<MergedAnnotation<A>> findMethodAnnotations(Method method, Class<?> targetClass) {
// The method may be on an interface, but we need attributes from the target
// class.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

package org.springframework.security.core.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.List;

import org.junit.jupiter.api.Test;

Expand All @@ -34,6 +40,9 @@ public class UniqueSecurityAnnotationScannerTests {
private UniqueSecurityAnnotationScanner<PreAuthorize> scanner = new UniqueSecurityAnnotationScanner<>(
PreAuthorize.class);

private UniqueSecurityAnnotationScanner<CustomParameterAnnotation> parameterScanner = new UniqueSecurityAnnotationScanner<>(
CustomParameterAnnotation.class);

@Test
void scanWhenAnnotationOnInterfaceThenResolves() throws Exception {
Method method = AnnotationOnInterface.class.getDeclaredMethod("method");
Expand Down Expand Up @@ -251,6 +260,77 @@ void scanWhenClassInheritingAbstractClassNoAnnotationsThenNoAnnotation() throws
assertThat(preAuthorize).isNull();
}

@Test
void scanParameterAnnotationWhenAnnotationOnInterface() throws Exception {
Parameter parameter = UserService.class.getDeclaredMethod("add", String.class).getParameters()[0];
CustomParameterAnnotation customParameterAnnotation = this.parameterScanner.scan(parameter);
assertThat(customParameterAnnotation.value()).isEqualTo("one");
}

@Test
void scanParameterAnnotationWhenClassInheritingInterfaceAnnotation() throws Exception {
Parameter parameter = UserServiceImpl.class.getDeclaredMethod("add", String.class).getParameters()[0];
CustomParameterAnnotation customParameterAnnotation = this.parameterScanner.scan(parameter);
assertThat(customParameterAnnotation.value()).isEqualTo("one");
}

@Test
void scanParameterAnnotationWhenClassOverridingMethodOverridingInterface() throws Exception {
Parameter parameter = UserServiceImpl.class.getDeclaredMethod("get", String.class).getParameters()[0];
CustomParameterAnnotation customParameterAnnotation = this.parameterScanner.scan(parameter);
assertThat(customParameterAnnotation.value()).isEqualTo("five");
}

@Test
void scanParameterAnnotationWhenMultipleMethodInheritanceThenException() throws Exception {
Parameter parameter = UserServiceImpl.class.getDeclaredMethod("list", String.class).getParameters()[0];
assertThatExceptionOfType(AnnotationConfigurationException.class)
.isThrownBy(() -> this.parameterScanner.scan(parameter));
}

interface UserService {

void add(@CustomParameterAnnotation("one") String user);

List<String> list(@CustomParameterAnnotation("two") String user);

String get(@CustomParameterAnnotation("three") String user);

}

interface OtherUserService {

List<String> list(@CustomParameterAnnotation("four") String user);

}

static class UserServiceImpl implements UserService, OtherUserService {

@Override
public void add(String user) {

}

@Override
public List<String> list(String user) {
return List.of(user);
}

@Override
public String get(@CustomParameterAnnotation("five") String user) {
return user;
}

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@interface CustomParameterAnnotation {

String value();

}

@PreAuthorize("one")
private interface AnnotationOnInterface {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import java.lang.annotation.Annotation;

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.core.annotation.RepeatableContainers;
import org.springframework.expression.BeanResolver;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
Expand Down Expand Up @@ -98,8 +102,12 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet

private ExpressionParser parser = new SpelExpressionParser();

private final Class<AuthenticationPrincipal> annotationType = AuthenticationPrincipal.class;

private SecurityAnnotationScanner<AuthenticationPrincipal> scanner = SecurityAnnotationScanners
.requireUnique(AuthenticationPrincipal.class);
.requireUnique(this.annotationType);

private boolean useAnnotationTemplate = false;

private BeanResolver beanResolver;

Expand Down Expand Up @@ -164,7 +172,8 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
* @since 6.4
*/
public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDefaults) {
this.scanner = SecurityAnnotationScanners.requireUnique(AuthenticationPrincipal.class, templateDefaults);
this.useAnnotationTemplate = templateDefaults != null;
this.scanner = SecurityAnnotationScanners.requireUnique(this.annotationType, templateDefaults);
}

/**
Expand All @@ -173,9 +182,16 @@ public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDef
* @param parameter the {@link MethodParameter} to search for an {@link Annotation}
* @return the {@link Annotation} that was found or null.
*/
@SuppressWarnings("unchecked")
private <T extends Annotation> T findMethodAnnotation(MethodParameter parameter) {
return (T) this.scanner.scan(parameter.getParameter());
private AuthenticationPrincipal findMethodAnnotation(MethodParameter parameter) {
if (this.useAnnotationTemplate) {
return this.scanner.scan(parameter.getParameter());
}
return MergedAnnotations
.from(parameter.getParameter(), MergedAnnotations.SearchStrategy.INHERITED_ANNOTATIONS,
RepeatableContainers.none())
.get(this.annotationType)
.synthesize(MergedAnnotation::isPresent)
.orElse(null);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.AnnotationConfigurationException;
import org.springframework.expression.BeanResolver;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
Expand Down Expand Up @@ -214,6 +215,17 @@ public void resolveArgumentCustomMetaAnnotationTpl() throws Exception {
.isEqualTo(this.expectedPrincipal);
}

@Test
public void resolveArgumentAnnotationFromInterface() {
CustomUserPrincipal principal = new CustomUserPrincipal();
setAuthenticationPrincipal(principal);
this.resolver.setTemplateDefaults(new AnnotationTemplateExpressionDefaults());
assertThat(this.resolver.supportsParameter(getMethodParameter("getUserByInterface", CustomUserPrincipal.class)))
.isTrue();
assertThatExceptionOfType(AnnotationConfigurationException.class).isThrownBy(() -> this.resolver
.resolveArgument(getMethodParameter("username", CustomUserPrincipal.class), null, null, null));
}

private MethodParameter showUserNoAnnotation() {
return getMethodParameter("showUserNoAnnotation", String.class);
}
Expand Down Expand Up @@ -312,7 +324,31 @@ private void setAuthenticationPrincipal(Object principal) {

}

public static class TestController {
interface UserApi {

String getUserByInterface(@AuthenticationPrincipal CustomUserPrincipal user);

Object username(@AuthenticationPrincipal CustomUserPrincipal user);

}

interface UserPublicApi {

Object username(@AuthenticationPrincipal CustomUserPrincipal user);

}

public static class TestController implements UserApi, UserPublicApi {

@Override
public String getUserByInterface(CustomUserPrincipal user) {
return "";
}

@Override
public Object username(CustomUserPrincipal user) {
return user.getPrincipal();
}

public void showUserNoAnnotation(String user) {
}
Expand Down

0 comments on commit 4bab2ed

Please sign in to comment.