001    package org.junit.experimental.theories;
002    
003    import java.lang.reflect.Constructor;
004    import java.lang.reflect.Field;
005    import java.lang.reflect.Method;
006    import java.lang.reflect.Modifier;
007    import java.util.ArrayList;
008    import java.util.List;
009    
010    import org.junit.Assert;
011    import org.junit.Assume;
012    import org.junit.experimental.theories.internal.Assignments;
013    import org.junit.experimental.theories.internal.ParameterizedAssertionError;
014    import org.junit.internal.AssumptionViolatedException;
015    import org.junit.runners.BlockJUnit4ClassRunner;
016    import org.junit.runners.model.FrameworkMethod;
017    import org.junit.runners.model.InitializationError;
018    import org.junit.runners.model.Statement;
019    import org.junit.runners.model.TestClass;
020    
021    public class Theories extends BlockJUnit4ClassRunner {
022        public Theories(Class<?> klass) throws InitializationError {
023            super(klass);
024        }
025    
026        @Override
027        protected void collectInitializationErrors(List<Throwable> errors) {
028            super.collectInitializationErrors(errors);
029            validateDataPointFields(errors);
030            validateDataPointMethods(errors);
031        }
032    
033        private void validateDataPointFields(List<Throwable> errors) {
034            Field[] fields = getTestClass().getJavaClass().getDeclaredFields();
035    
036            for (Field field : fields) {
037                if (field.getAnnotation(DataPoint.class) == null && field.getAnnotation(DataPoints.class) == null) {
038                    continue;
039                }
040                if (!Modifier.isStatic(field.getModifiers())) {
041                    errors.add(new Error("DataPoint field " + field.getName() + " must be static"));
042                }
043                if (!Modifier.isPublic(field.getModifiers())) {
044                    errors.add(new Error("DataPoint field " + field.getName() + " must be public"));
045                }
046            }
047        }
048    
049        private void validateDataPointMethods(List<Throwable> errors) {
050            Method[] methods = getTestClass().getJavaClass().getDeclaredMethods();
051            
052            for (Method method : methods) {
053                if (method.getAnnotation(DataPoint.class) == null && method.getAnnotation(DataPoints.class) == null) {
054                    continue;
055                }
056                if (!Modifier.isStatic(method.getModifiers())) {
057                    errors.add(new Error("DataPoint method " + method.getName() + " must be static"));
058                }
059                if (!Modifier.isPublic(method.getModifiers())) {
060                    errors.add(new Error("DataPoint method " + method.getName() + " must be public"));
061                }
062            }
063        }
064    
065        @Override
066        protected void validateConstructor(List<Throwable> errors) {
067            validateOnlyOneConstructor(errors);
068        }
069    
070        @Override
071        protected void validateTestMethods(List<Throwable> errors) {
072            for (FrameworkMethod each : computeTestMethods()) {
073                if (each.getAnnotation(Theory.class) != null) {
074                    each.validatePublicVoid(false, errors);
075                    each.validateNoTypeParametersOnArgs(errors);
076                } else {
077                    each.validatePublicVoidNoArg(false, errors);
078                }
079                
080                for (ParameterSignature signature : ParameterSignature.signatures(each.getMethod())) {
081                    ParametersSuppliedBy annotation = signature.findDeepAnnotation(ParametersSuppliedBy.class);
082                    if (annotation != null) {
083                        validateParameterSupplier(annotation.value(), errors);
084                    }
085                }
086            }
087        }
088    
089        private void validateParameterSupplier(Class<? extends ParameterSupplier> supplierClass, List<Throwable> errors) {
090            Constructor<?>[] constructors = supplierClass.getConstructors();
091            
092            if (constructors.length != 1) {
093                errors.add(new Error("ParameterSupplier " + supplierClass.getName() + 
094                                     " must have only one constructor (either empty or taking only a TestClass)"));
095            } else {
096                Class<?>[] paramTypes = constructors[0].getParameterTypes();
097                if (!(paramTypes.length == 0) && !paramTypes[0].equals(TestClass.class)) {
098                    errors.add(new Error("ParameterSupplier " + supplierClass.getName() + 
099                                         " constructor must take either nothing or a single TestClass instance"));
100                }
101            }
102        }
103    
104        @Override
105        protected List<FrameworkMethod> computeTestMethods() {
106            List<FrameworkMethod> testMethods = new ArrayList<FrameworkMethod>(super.computeTestMethods());
107            List<FrameworkMethod> theoryMethods = getTestClass().getAnnotatedMethods(Theory.class);
108            testMethods.removeAll(theoryMethods);
109            testMethods.addAll(theoryMethods);
110            return testMethods;
111        }
112    
113        @Override
114        public Statement methodBlock(final FrameworkMethod method) {
115            return new TheoryAnchor(method, getTestClass());
116        }
117    
118        public static class TheoryAnchor extends Statement {
119            private int successes = 0;
120    
121            private final FrameworkMethod testMethod;
122            private final TestClass testClass;
123    
124            private List<AssumptionViolatedException> fInvalidParameters = new ArrayList<AssumptionViolatedException>();
125    
126            public TheoryAnchor(FrameworkMethod testMethod, TestClass testClass) {
127                this.testMethod = testMethod;
128                this.testClass = testClass;
129            }
130    
131            private TestClass getTestClass() {
132                return testClass;
133            }
134    
135            @Override
136            public void evaluate() throws Throwable {
137                runWithAssignment(Assignments.allUnassigned(
138                        testMethod.getMethod(), getTestClass()));
139                
140                //if this test method is not annotated with Theory, then no successes is a valid case
141                boolean hasTheoryAnnotation = testMethod.getAnnotation(Theory.class) != null;
142                if (successes == 0 && hasTheoryAnnotation) {
143                    Assert
144                            .fail("Never found parameters that satisfied method assumptions.  Violated assumptions: "
145                                    + fInvalidParameters);
146                }
147            }
148    
149            protected void runWithAssignment(Assignments parameterAssignment)
150                    throws Throwable {
151                if (!parameterAssignment.isComplete()) {
152                    runWithIncompleteAssignment(parameterAssignment);
153                } else {
154                    runWithCompleteAssignment(parameterAssignment);
155                }
156            }
157    
158            protected void runWithIncompleteAssignment(Assignments incomplete)
159                    throws Throwable {
160                for (PotentialAssignment source : incomplete
161                        .potentialsForNextUnassigned()) {
162                    runWithAssignment(incomplete.assignNext(source));
163                }
164            }
165    
166            protected void runWithCompleteAssignment(final Assignments complete)
167                    throws Throwable {
168                new BlockJUnit4ClassRunner(getTestClass().getJavaClass()) {
169                    @Override
170                    protected void collectInitializationErrors(
171                            List<Throwable> errors) {
172                        // do nothing
173                    }
174    
175                    @Override
176                    public Statement methodBlock(FrameworkMethod method) {
177                        final Statement statement = super.methodBlock(method);
178                        return new Statement() {
179                            @Override
180                            public void evaluate() throws Throwable {
181                                try {
182                                    statement.evaluate();
183                                    handleDataPointSuccess();
184                                } catch (AssumptionViolatedException e) {
185                                    handleAssumptionViolation(e);
186                                } catch (Throwable e) {
187                                    reportParameterizedError(e, complete
188                                            .getArgumentStrings(nullsOk()));
189                                }
190                            }
191    
192                        };
193                    }
194    
195                    @Override
196                    protected Statement methodInvoker(FrameworkMethod method, Object test) {
197                        return methodCompletesWithParameters(method, complete, test);
198                    }
199    
200                    @Override
201                    public Object createTest() throws Exception {
202                        Object[] params = complete.getConstructorArguments();
203                        
204                        if (!nullsOk()) {
205                            Assume.assumeNotNull(params);
206                        }
207                        
208                        return getTestClass().getOnlyConstructor().newInstance(params);
209                    }
210                }.methodBlock(testMethod).evaluate();
211            }
212    
213            private Statement methodCompletesWithParameters(
214                    final FrameworkMethod method, final Assignments complete, final Object freshInstance) {
215                return new Statement() {
216                    @Override
217                    public void evaluate() throws Throwable {
218                        final Object[] values = complete.getMethodArguments();
219                        
220                        if (!nullsOk()) {
221                            Assume.assumeNotNull(values);
222                        }
223                        
224                        method.invokeExplosively(freshInstance, values);
225                    }
226                };
227            }
228    
229            protected void handleAssumptionViolation(AssumptionViolatedException e) {
230                fInvalidParameters.add(e);
231            }
232    
233            protected void reportParameterizedError(Throwable e, Object... params)
234                    throws Throwable {
235                if (params.length == 0) {
236                    throw e;
237                }
238                throw new ParameterizedAssertionError(e, testMethod.getName(),
239                        params);
240            }
241    
242            private boolean nullsOk() {
243                Theory annotation = testMethod.getMethod().getAnnotation(
244                        Theory.class);
245                if (annotation == null) {
246                    return false;
247                }
248                return annotation.nullsAccepted();
249            }
250    
251            protected void handleDataPointSuccess() {
252                successes++;
253            }
254        }
255    }