001/*
002 * Copyright (C) 2012 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.testing;
018
019import static com.google.common.base.Preconditions.checkArgument;
020import static com.google.common.base.Preconditions.checkNotNull;
021import static com.google.common.base.Throwables.throwIfUnchecked;
022import static junit.framework.Assert.assertEquals;
023import static junit.framework.Assert.fail;
024
025import com.google.common.annotations.GwtIncompatible;
026import com.google.common.annotations.J2ktIncompatible;
027import com.google.common.base.Function;
028import com.google.common.base.Throwables;
029import com.google.common.collect.Lists;
030import com.google.common.reflect.AbstractInvocationHandler;
031import com.google.common.reflect.Reflection;
032import com.google.errorprone.annotations.CanIgnoreReturnValue;
033import java.lang.reflect.AccessibleObject;
034import java.lang.reflect.InvocationTargetException;
035import java.lang.reflect.Method;
036import java.lang.reflect.Modifier;
037import java.util.List;
038import java.util.concurrent.atomic.AtomicInteger;
039import org.jspecify.annotations.NullMarked;
040import org.checkerframework.checker.nullness.qual.Nullable;
041
042/**
043 * Tester to ensure forwarding wrapper works by delegating calls to the corresponding method with
044 * the same parameters forwarded and return value forwarded back or exception propagated as is.
045 *
046 * <p>For example:
047 *
048 * <pre>{@code
049 * new ForwardingWrapperTester().testForwarding(Foo.class, new Function<Foo, Foo>() {
050 *   public Foo apply(Foo foo) {
051 *     return new ForwardingFoo(foo);
052 *   }
053 * });
054 * }</pre>
055 *
056 * @author Ben Yu
057 * @since 14.0
058 */
059@GwtIncompatible
060@J2ktIncompatible
061@NullMarked
062public final class ForwardingWrapperTester {
063
064  private boolean testsEquals = false;
065
066  /**
067   * Asks for {@link Object#equals} and {@link Object#hashCode} to be tested. That is, forwarding
068   * wrappers of equal instances should be equal.
069   */
070  @CanIgnoreReturnValue
071  public ForwardingWrapperTester includingEquals() {
072    this.testsEquals = true;
073    return this;
074  }
075
076  /**
077   * Tests that the forwarding wrapper returned by {@code wrapperFunction} properly forwards method
078   * calls with parameters passed as is, return value returned as is, and exceptions propagated as
079   * is.
080   */
081  public <T> void testForwarding(
082      Class<T> interfaceType, Function<? super T, ? extends T> wrapperFunction) {
083    checkNotNull(wrapperFunction);
084    checkArgument(interfaceType.isInterface(), "%s isn't an interface", interfaceType);
085    Method[] methods = getMostConcreteMethods(interfaceType);
086    AccessibleObject.setAccessible(methods, true);
087    for (Method method : methods) {
088      // Interfaces can have default methods that aren't abstract.
089      // No need to verify them.
090      // Can't check isDefault() for Android compatibility.
091      if (!Modifier.isAbstract(method.getModifiers())) {
092        continue;
093      }
094      // The interface could be package-private or private.
095      // filter out equals/hashCode/toString
096      if (method.getName().equals("equals")
097          && method.getParameterTypes().length == 1
098          && method.getParameterTypes()[0] == Object.class) {
099        continue;
100      }
101      if (method.getName().equals("hashCode") && method.getParameterTypes().length == 0) {
102        continue;
103      }
104      if (method.getName().equals("toString") && method.getParameterTypes().length == 0) {
105        continue;
106      }
107      testSuccessfulForwarding(interfaceType, method, wrapperFunction);
108      testExceptionPropagation(interfaceType, method, wrapperFunction);
109    }
110    if (testsEquals) {
111      testEquals(interfaceType, wrapperFunction);
112    }
113    testToString(interfaceType, wrapperFunction);
114  }
115
116  /** Returns the most concrete public methods from {@code type}. */
117  private static Method[] getMostConcreteMethods(Class<?> type) {
118    Method[] methods = type.getMethods();
119    for (int i = 0; i < methods.length; i++) {
120      try {
121        methods[i] = type.getMethod(methods[i].getName(), methods[i].getParameterTypes());
122      } catch (Exception e) {
123        throwIfUnchecked(e);
124        throw new RuntimeException(e);
125      }
126    }
127    return methods;
128  }
129
130  private static <T> void testSuccessfulForwarding(
131      Class<T> interfaceType, Method method, Function<? super T, ? extends T> wrapperFunction) {
132    new InteractionTester<T>(interfaceType, method).testInteraction(wrapperFunction);
133  }
134
135  private static <T> void testExceptionPropagation(
136      Class<T> interfaceType, Method method, Function<? super T, ? extends T> wrapperFunction) {
137    RuntimeException exception = new RuntimeException();
138    T proxy =
139        Reflection.newProxy(
140            interfaceType,
141            new AbstractInvocationHandler() {
142              @Override
143              protected Object handleInvocation(Object p, Method m, @Nullable Object[] args)
144                  throws Throwable {
145                throw exception;
146              }
147            });
148    T wrapper = wrapperFunction.apply(proxy);
149    try {
150      method.invoke(wrapper, getParameterValues(method));
151      fail(method + " failed to throw exception as is.");
152    } catch (InvocationTargetException e) {
153      if (exception != e.getCause()) {
154        throw new RuntimeException(e);
155      }
156    } catch (IllegalAccessException e) {
157      throw new AssertionError(e);
158    }
159  }
160
161  private static <T> void testEquals(
162      Class<T> interfaceType, Function<? super T, ? extends T> wrapperFunction) {
163    FreshValueGenerator generator = new FreshValueGenerator();
164    T instance = generator.newFreshProxy(interfaceType);
165    new EqualsTester()
166        .addEqualityGroup(wrapperFunction.apply(instance), wrapperFunction.apply(instance))
167        .addEqualityGroup(wrapperFunction.apply(generator.newFreshProxy(interfaceType)))
168        // TODO: add an overload to EqualsTester to print custom error message?
169        .testEquals();
170  }
171
172  private static <T> void testToString(
173      Class<T> interfaceType, Function<? super T, ? extends T> wrapperFunction) {
174    T proxy = new FreshValueGenerator().newFreshProxy(interfaceType);
175    assertEquals(
176        "toString() isn't properly forwarded",
177        proxy.toString(),
178        wrapperFunction.apply(proxy).toString());
179  }
180
181  private static @Nullable Object[] getParameterValues(Method method) {
182    FreshValueGenerator paramValues = new FreshValueGenerator();
183    List<@Nullable Object> passedArgs = Lists.newArrayList();
184    for (Class<?> paramType : method.getParameterTypes()) {
185      passedArgs.add(paramValues.generateFresh(paramType));
186    }
187    return passedArgs.toArray();
188  }
189
190  /** Tests a single interaction against a method. */
191  private static final class InteractionTester<T> extends AbstractInvocationHandler {
192
193    private final Class<T> interfaceType;
194    private final Method method;
195    private final @Nullable Object[] passedArgs;
196    private final @Nullable Object returnValue;
197    private final AtomicInteger called = new AtomicInteger();
198
199    InteractionTester(Class<T> interfaceType, Method method) {
200      this.interfaceType = interfaceType;
201      this.method = method;
202      this.passedArgs = getParameterValues(method);
203      this.returnValue = new FreshValueGenerator().generateFresh(method.getReturnType());
204    }
205
206    @Override
207    protected @Nullable Object handleInvocation(
208        Object p, Method calledMethod, @Nullable Object[] args) throws Throwable {
209      assertEquals(method, calledMethod);
210      assertEquals(method + " invoked more than once.", 0, called.get());
211      for (int i = 0; i < passedArgs.length; i++) {
212        assertEquals(
213            "Parameter #" + i + " of " + method + " not forwarded", passedArgs[i], args[i]);
214      }
215      called.getAndIncrement();
216      return returnValue;
217    }
218
219    void testInteraction(Function<? super T, ? extends T> wrapperFunction) {
220      T proxy = Reflection.newProxy(interfaceType, this);
221      T wrapper = wrapperFunction.apply(proxy);
222      boolean isPossibleChainingCall = interfaceType.isAssignableFrom(method.getReturnType());
223      try {
224        Object actualReturnValue = method.invoke(wrapper, passedArgs);
225        // If we think this might be a 'chaining' call then we allow the return value to either
226        // be the wrapper or the returnValue.
227        if (!isPossibleChainingCall || wrapper != actualReturnValue) {
228          assertEquals(
229              "Return value of " + method + " not forwarded", returnValue, actualReturnValue);
230        }
231      } catch (IllegalAccessException e) {
232        throw new RuntimeException(e);
233      } catch (InvocationTargetException e) {
234        throw Throwables.propagate(e.getCause());
235      }
236      assertEquals("Failed to forward to " + method, 1, called.get());
237    }
238
239    @Override
240    public String toString() {
241      return "dummy " + interfaceType.getSimpleName();
242    }
243  }
244}