diff --git a/spring-test/src/test/java/org/springframework/test/BeanMethodAnnotationLookupTests.java b/spring-test/src/test/java/org/springframework/test/BeanMethodAnnotationLookupTests.java index 4dc3215df6bc..08aa88a11cc8 100644 --- a/spring-test/src/test/java/org/springframework/test/BeanMethodAnnotationLookupTests.java +++ b/spring-test/src/test/java/org/springframework/test/BeanMethodAnnotationLookupTests.java @@ -8,6 +8,7 @@ import java.lang.annotation.Target; import java.lang.reflect.Method; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; @@ -16,14 +17,13 @@ import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; -import org.springframework.beans.factory.support.AbstractBeanDefinition; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; -import org.springframework.util.ClassUtils; -import org.springframework.util.ReflectionUtils; +import static java.util.stream.Collectors.*; import static org.junit.jupiter.api.Assertions.*; @SpringJUnitConfig @@ -70,33 +70,29 @@ static class MyBeanFactoryPostProcessor implements BeanFactoryPostProcessor { public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { Class annotationType = MyAnnotation.class; - Arrays.stream(beanFactory.getBeanDefinitionNames())// - .map(beanFactory::getBeanDefinition)// - .filter(bd -> !bd.isAbstract())// - .filter(bd -> bd.getFactoryMethodName() != null)// - .filter(bd -> bd.getFactoryBeanName() != null)// - .filter(bd -> isBeanMethodAnnotated(beanFactory, bd, annotationType))// + findAnnotatedBeanDefinitions(beanFactory, annotationType)// .forEach(bd -> System.out.println("@Bean method " + bd.getFactoryMethodName() + " is annotated with @" + annotationType.getSimpleName() + ".")); } - private boolean isBeanMethodAnnotated(ConfigurableListableBeanFactory beanFactory, - BeanDefinition beanDefinition, Class annotationType) { + private List findAnnotatedBeanDefinitions(ConfigurableListableBeanFactory beanFactory, + Class annotationType) { - BeanDefinition factoryBeanDefinition = beanFactory.getBeanDefinition(beanDefinition.getFactoryBeanName()); + return Arrays.stream(beanFactory.getBeanDefinitionNames())// + .map(beanFactory::getBeanDefinition)// + .filter(bd -> !bd.isAbstract())// + .filter(RootBeanDefinition.class::isInstance)// + .map(RootBeanDefinition.class::cast)// + .filter(bd -> isAnnotatedBeanMethod(bd, annotationType))// + .collect(toList()); + } - if (factoryBeanDefinition instanceof AbstractBeanDefinition) { - AbstractBeanDefinition abd = (AbstractBeanDefinition) factoryBeanDefinition; - if (abd.hasBeanClass()) { - Class factoryClass = ClassUtils.getUserClass(abd.getBeanClass()); - String factoryMethodName = beanDefinition.getFactoryMethodName(); - Method factoryMethod = ReflectionUtils.findMethod(factoryClass, factoryMethodName); - return AnnotatedElementUtils.isAnnotated(factoryMethod, Bean.class) - && AnnotatedElementUtils.isAnnotated(factoryMethod, annotationType); - } - } + private boolean isAnnotatedBeanMethod(RootBeanDefinition beanDefinition, + Class annotationType) { - return false; + Method factoryMethod = beanDefinition.getResolvedFactoryMethod(); + return factoryMethod != null && AnnotatedElementUtils.isAnnotated(factoryMethod, Bean.class) + && AnnotatedElementUtils.isAnnotated(factoryMethod, annotationType); } }