Skip to content

Commit

Permalink
@BeforeParam/@AfterParam for Parameterized runner (junit-team#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
panchenko authored and kcooney committed Apr 21, 2017
1 parent 62ba19a commit b7e6d12
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,19 @@ public void evaluate() throws Throwable {
} finally {
for (FrameworkMethod each : afters) {
try {
each.invokeExplosively(target);
invokeMethod(each);
} catch (Throwable e) {
errors.add(e);
}
}
}
MultipleFailureException.assertEmpty(errors);
}

/**
* @since 4.13
*/
protected void invokeMethod(FrameworkMethod method) throws Throwable {
method.invokeExplosively(target);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ public RunBefores(Statement next, List<FrameworkMethod> befores, Object target)
@Override
public void evaluate() throws Throwable {
for (FrameworkMethod before : befores) {
before.invokeExplosively(target);
invokeMethod(before);
}
next.evaluate();
}

/**
* @since 4.13
*/
protected void invokeMethod(FrameworkMethod method) throws Throwable {
method.invokeExplosively(target);
}
}
132 changes: 110 additions & 22 deletions src/main/java/org/junit/runners/Parameterized.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.junit.runners;

import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
Expand All @@ -8,11 +9,13 @@
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.junit.runner.Runner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InvalidTestClassError;
import org.junit.runners.model.TestClass;
import org.junit.runners.parameterized.BlockJUnit4ClassRunnerWithParametersFactory;
import org.junit.runners.parameterized.ParametersRunnerFactory;
Expand Down Expand Up @@ -134,6 +137,19 @@
* }
* </pre>
*
* <h3>Executing code before/after executing tests for specific parameters</h3>
* <p>
* If your test needs to perform some preparation or cleanup based on the
* parameters, this can be done by adding public static methods annotated with
* {@code @BeforeParam}/{@code @AfterParam}. Such methods should either have no
* parameters or the same parameters as the test.
* <pre>
* &#064;BeforeParam
* public static void beforeTestsForParameter(String onlyParameter) {
* System.out.println("Testing " + onlyParameter);
* }
* </pre>
*
* <h3>Create different runners</h3>
* <p>
* By default the {@code Parameterized} runner creates a slightly modified
Expand Down Expand Up @@ -234,32 +250,91 @@ public class Parameterized extends Suite {
Class<? extends ParametersRunnerFactory> value() default BlockJUnit4ClassRunnerWithParametersFactory.class;
}

/**
* Annotation for {@code public static void} methods which should be executed before
* evaluating tests with particular parameters.
*
* @see org.junit.BeforeClass
* @see org.junit.Before
* @since 4.13
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface BeforeParam {
}

/**
* Annotation for {@code public static void} methods which should be executed after
* evaluating tests with particular parameters.
*
* @see org.junit.AfterClass
* @see org.junit.After
* @since 4.13
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface AfterParam {
}

/**
* Only called reflectively. Do not use programmatically.
*/
public Parameterized(Class<?> klass) throws Throwable {
super(klass, RunnersFactory.createRunnersForClass(klass));
this(klass, new RunnersFactory(klass));
}

private Parameterized(Class<?> klass, RunnersFactory runnersFactory) throws Exception {
super(klass, runnersFactory.createRunners());
validateBeforeParamAndAfterParamMethods(runnersFactory.parameterCount);
}

private void validateBeforeParamAndAfterParamMethods(Integer parameterCount)
throws InvalidTestClassError {
List<Throwable> errors = new ArrayList<Throwable>();
validatePublicStaticVoidMethods(Parameterized.BeforeParam.class, parameterCount, errors);
validatePublicStaticVoidMethods(Parameterized.AfterParam.class, parameterCount, errors);
if (!errors.isEmpty()) {
throw new InvalidTestClassError(getTestClass().getJavaClass(), errors);
}
}

private void validatePublicStaticVoidMethods(
Class<? extends Annotation> annotation, Integer parameterCount,
List<Throwable> errors) {
List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(annotation);
for (FrameworkMethod fm : methods) {
fm.validatePublicVoid(true, errors);
if (parameterCount != null) {
int methodParameterCount = fm.getMethod().getParameterTypes().length;
if (methodParameterCount != 0 && methodParameterCount != parameterCount) {
errors.add(new Exception("Method " + fm.getName()
+ "() should have 0 or " + parameterCount + " parameter(s)"));
}
}
}
}

private static class RunnersFactory {
private static final ParametersRunnerFactory DEFAULT_FACTORY = new BlockJUnit4ClassRunnerWithParametersFactory();

private final TestClass testClass;
private final FrameworkMethod parametersMethod;
private final List<Object> allParameters;
private final int parameterCount;

static List<Runner> createRunnersForClass(Class<?> klass)
throws Throwable {
return new RunnersFactory(klass).createRunners();
}

private RunnersFactory(Class<?> klass) {
private RunnersFactory(Class<?> klass) throws Throwable {
testClass = new TestClass(klass);
parametersMethod = getParametersMethod(testClass);
allParameters = allParameters(testClass, parametersMethod);
parameterCount =
allParameters.isEmpty() ? 0 : normalizeParameters(allParameters.get(0)).length;
}

private List<Runner> createRunners() throws Throwable {
Parameters parameters = getParametersMethod().getAnnotation(
Parameters.class);
private List<Runner> createRunners() throws Exception {
Parameters parameters = parametersMethod.getAnnotation(Parameters.class);
return Collections.unmodifiableList(createRunnersForParameters(
allParameters(), parameters.name(),
allParameters, parameters.name(),
getParametersRunnerFactory()));
}

Expand All @@ -278,25 +353,37 @@ private ParametersRunnerFactory getParametersRunnerFactory()

private TestWithParameters createTestWithNotNormalizedParameters(
String pattern, int index, Object parametersOrSingleParameter) {
Object[] parameters = (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
Object[] parameters = normalizeParameters(parametersOrSingleParameter);
return createTestWithParameters(testClass, pattern, index, parameters);
}

private static Object[] normalizeParameters(Object parametersOrSingleParameter) {
return (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
: new Object[] { parametersOrSingleParameter };
return createTestWithParameters(testClass, pattern, index,
parameters);
}

@SuppressWarnings("unchecked")
private Iterable<Object> allParameters() throws Throwable {
Object parameters = getParametersMethod().invokeExplosively(null);
if (parameters instanceof Iterable) {
return (Iterable<Object>) parameters;
private static List<Object> allParameters(
TestClass testClass, FrameworkMethod parametersMethod) throws Throwable {
Object parameters = parametersMethod.invokeExplosively(null);
if (parameters instanceof List) {
return (List<Object>) parameters;
} else if (parameters instanceof Collection) {
return new ArrayList<Object>((Collection<Object>) parameters);
} else if (parameters instanceof Iterable) {
List<Object> result = new ArrayList<Object>();
for (Object entry : ((Iterable<Object>) parameters)) {
result.add(entry);
}
return result;
} else if (parameters instanceof Object[]) {
return Arrays.asList((Object[]) parameters);
} else {
throw parametersMethodReturnedWrongType();
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
}
}

private FrameworkMethod getParametersMethod() throws Exception {
private static FrameworkMethod getParametersMethod(TestClass testClass) throws Exception {
List<FrameworkMethod> methods = testClass
.getAnnotatedMethods(Parameters.class);
for (FrameworkMethod each : methods) {
Expand All @@ -322,7 +409,7 @@ private List<Runner> createRunnersForParameters(
}
return runners;
} catch (ClassCastException e) {
throw parametersMethodReturnedWrongType();
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
}
}

Expand All @@ -338,9 +425,10 @@ private List<TestWithParameters> createTestsForParameters(
return children;
}

private Exception parametersMethodReturnedWrongType() throws Exception {
private static Exception parametersMethodReturnedWrongType(
TestClass testClass, FrameworkMethod parametersMethod) throws Exception {
String className = testClass.getName();
String methodName = getParametersMethod().getName();
String methodName = parametersMethod.getName();
String message = MessageFormat.format(
"{0}.{1}() must return an Iterable of arrays.", className,
methodName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import java.lang.reflect.Field;
import java.util.List;

import org.junit.internal.runners.statements.RunAfters;
import org.junit.internal.runners.statements.RunBefores;
import org.junit.runner.RunWith;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.model.FrameworkField;
import org.junit.runners.model.FrameworkMethod;
Expand Down Expand Up @@ -135,7 +138,46 @@ protected void validateFields(List<Throwable> errors) {

@Override
protected Statement classBlock(RunNotifier notifier) {
return childrenInvoker(notifier);
Statement statement = childrenInvoker(notifier);
statement = withBeforeParams(statement);
statement = withAfterParams(statement);
return statement;
}

private Statement withBeforeParams(Statement statement) {
List<FrameworkMethod> befores = getTestClass()
.getAnnotatedMethods(Parameterized.BeforeParam.class);
return befores.isEmpty() ? statement : new RunBeforeParams(statement, befores);
}

private class RunBeforeParams extends RunBefores {
RunBeforeParams(Statement next, List<FrameworkMethod> befores) {
super(next, befores, null);
}

@Override
protected void invokeMethod(FrameworkMethod method) throws Throwable {
int paramCount = method.getMethod().getParameterTypes().length;
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
}
}

private Statement withAfterParams(Statement statement) {
List<FrameworkMethod> afters = getTestClass()
.getAnnotatedMethods(Parameterized.AfterParam.class);
return afters.isEmpty() ? statement : new RunAfterParams(statement, afters);
}

private class RunAfterParams extends RunAfters {
RunAfterParams(Statement next, List<FrameworkMethod> afters) {
super(next, afters, null);
}

@Override
protected void invokeMethod(FrameworkMethod method) throws Throwable {
int paramCount = method.getMethod().getParameterTypes().length;
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
}
}

@Override
Expand Down
Loading

0 comments on commit b7e6d12

Please sign in to comment.