1 /*
   2  * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.test;
  24 
  25 import java.io.PrintStream;
  26 import java.io.PrintWriter;
  27 import java.lang.reflect.Field;
  28 import java.lang.reflect.Method;
  29 import java.util.Arrays;
  30 
  31 import org.junit.Assert;
  32 import org.junit.internal.ComparisonCriteria;
  33 import org.junit.internal.ExactComparisonCriteria;
  34 
  35 import sun.misc.Unsafe;
  36 
  37 /**
  38  * Base class that contains common utility methods and classes useful in unit tests.
  39  */
  40 public class GraalTest {
  41 
  42     public static final Unsafe UNSAFE;
  43     static {
  44         try {
  45             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
  46             theUnsafe.setAccessible(true);
  47             UNSAFE = (Unsafe) theUnsafe.get(Unsafe.class);
  48         } catch (Exception e) {
  49             throw new RuntimeException("exception while trying to get Unsafe", e);
  50         }
  51     }
  52 
  53     public static final boolean Java8OrEarlier = System.getProperty("java.specification.version").compareTo("1.9") < 0;
  54 
  55     protected Method getMethod(String methodName) {
  56         return getMethod(getClass(), methodName);
  57     }
  58 
  59     protected Method getMethod(Class<?> clazz, String methodName) {
  60         Method found = null;
  61         for (Method m : clazz.getMethods()) {
  62             if (m.getName().equals(methodName)) {
  63                 Assert.assertNull(found);
  64                 found = m;
  65             }
  66         }
  67         if (found == null) {
  68             /* Now look for non-public methods (but this does not look in superclasses). */
  69             for (Method m : clazz.getDeclaredMethods()) {
  70                 if (m.getName().equals(methodName)) {
  71                     Assert.assertNull(found);
  72                     found = m;
  73                 }
  74             }
  75         }
  76         if (found != null) {
  77             return found;
  78         } else {
  79             throw new RuntimeException("method not found: " + methodName);
  80         }
  81     }
  82 
  83     protected Method getMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) {
  84         try {
  85             return clazz.getMethod(methodName, parameterTypes);
  86         } catch (NoSuchMethodException | SecurityException e) {
  87             throw new RuntimeException("method not found: " + methodName + "" + Arrays.toString(parameterTypes));
  88         }
  89     }
  90 
  91     /**
  92      * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
  93      * Does a deep copy equality comparison if {@code expected} is an array.
  94      */
  95     protected void assertDeepEquals(Object expected, Object actual) {
  96         assertDeepEquals(null, expected, actual);
  97     }
  98 
  99     /**
 100      * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
 101      * Does a deep copy equality comparison if {@code expected} is an array.
 102      *
 103      * @param message the identifying message for the {@link AssertionError}
 104      */
 105     protected void assertDeepEquals(String message, Object expected, Object actual) {
 106         if (ulpsDelta() > 0) {
 107             assertDeepEquals(message, expected, actual, ulpsDelta());
 108         } else {
 109             assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
 110         }
 111     }
 112 
 113     /**
 114      * Compares two given values for equality, doing a recursive test if both values are arrays of
 115      * the same type.
 116      *
 117      * @param message the identifying message for the {@link AssertionError}
 118      * @param delta the maximum delta between two doubles or floats for which both numbers are still
 119      *            considered equal.
 120      */
 121     protected void assertDeepEquals(String message, Object expected, Object actual, double delta) {
 122         if (expected != null && actual != null) {
 123             Class<?> expectedClass = expected.getClass();
 124             Class<?> actualClass = actual.getClass();
 125             if (expectedClass.isArray()) {
 126                 Assert.assertEquals(message, expectedClass, actual.getClass());
 127                 if (expected instanceof int[]) {
 128                     Assert.assertArrayEquals(message, (int[]) expected, (int[]) actual);
 129                 } else if (expected instanceof byte[]) {
 130                     Assert.assertArrayEquals(message, (byte[]) expected, (byte[]) actual);
 131                 } else if (expected instanceof char[]) {
 132                     Assert.assertArrayEquals(message, (char[]) expected, (char[]) actual);
 133                 } else if (expected instanceof short[]) {
 134                     Assert.assertArrayEquals(message, (short[]) expected, (short[]) actual);
 135                 } else if (expected instanceof float[]) {
 136                     Assert.assertArrayEquals(message, (float[]) expected, (float[]) actual, (float) delta);
 137                 } else if (expected instanceof long[]) {
 138                     Assert.assertArrayEquals(message, (long[]) expected, (long[]) actual);
 139                 } else if (expected instanceof double[]) {
 140                     Assert.assertArrayEquals(message, (double[]) expected, (double[]) actual, delta);
 141                 } else if (expected instanceof boolean[]) {
 142                     new ExactComparisonCriteria().arrayEquals(message, expected, actual);
 143                 } else if (expected instanceof Object[]) {
 144                     new ComparisonCriteria() {
 145                         @Override
 146                         protected void assertElementsEqual(Object e, Object a) {
 147                             assertDeepEquals(message, e, a, delta);
 148                         }
 149                     }.arrayEquals(message, expected, actual);
 150                 } else {
 151                     Assert.fail((message == null ? "" : message) + "non-array value encountered: " + expected);
 152                 }
 153             } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
 154                 Assert.assertEquals((double) expected, (double) actual, delta);
 155             } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
 156                 Assert.assertEquals((float) expected, (float) actual, delta);
 157             } else {
 158                 Assert.assertEquals(message, expected, actual);
 159             }
 160         } else {
 161             Assert.assertEquals(message, expected, actual);
 162         }
 163     }
 164 
 165     /**
 166      * Compares two given values for equality, doing a recursive test if both values are arrays of
 167      * the same type. Uses {@linkplain StrictMath#ulp(float) ULP}s for comparison of floats.
 168      *
 169      * @param message the identifying message for the {@link AssertionError}
 170      * @param ulpsDelta the maximum allowed ulps difference between two doubles or floats for which
 171      *            both numbers are still considered equal.
 172      */
 173     protected void assertDeepEquals(String message, Object expected, Object actual, int ulpsDelta) {
 174         ComparisonCriteria doubleUlpsDeltaCriteria = new ComparisonCriteria() {
 175             @Override
 176             protected void assertElementsEqual(Object e, Object a) {
 177                 assertTrue(message, e instanceof Double && a instanceof Double);
 178                 // determine acceptable error based on whether it is a normal number or a NaN/Inf
 179                 double de = (Double) e;
 180                 double epsilon = (!Double.isNaN(de) && Double.isFinite(de) ? ulpsDelta * Math.ulp(de) : 0);
 181                 Assert.assertEquals(message, (Double) e, (Double) a, epsilon);
 182             }
 183         };
 184 
 185         ComparisonCriteria floatUlpsDeltaCriteria = new ComparisonCriteria() {
 186             @Override
 187             protected void assertElementsEqual(Object e, Object a) {
 188                 assertTrue(message, e instanceof Float && a instanceof Float);
 189                 // determine acceptable error based on whether it is a normal number or a NaN/Inf
 190                 float fe = (Float) e;
 191                 float epsilon = (!Float.isNaN(fe) && Float.isFinite(fe) ? ulpsDelta * Math.ulp(fe) : 0);
 192                 Assert.assertEquals(message, (Float) e, (Float) a, epsilon);
 193             }
 194         };
 195 
 196         if (expected != null && actual != null) {
 197             Class<?> expectedClass = expected.getClass();
 198             Class<?> actualClass = actual.getClass();
 199             if (expectedClass.isArray()) {
 200                 Assert.assertEquals(message, expectedClass, actualClass);
 201                 if (expected instanceof double[] || expected instanceof Object[]) {
 202                     doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 203                     return;
 204                 } else if (expected instanceof float[] || expected instanceof Object[]) {
 205                     floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 206                     return;
 207                 }
 208             } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
 209                 doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 210                 return;
 211             } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
 212                 floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 213                 return;
 214             }
 215         }
 216         // anything else just use the non-ulps version
 217         assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
 218     }
 219 
 220     /**
 221      * Gets the value used by {@link #assertDeepEquals(Object, Object)} and
 222      * {@link #assertDeepEquals(String, Object, Object)} for the maximum delta between two doubles
 223      * or floats for which both numbers are still considered equal.
 224      */
 225     protected double equalFloatsOrDoublesDelta() {
 226         return 0.0D;
 227     }
 228 
 229     // unless overridden ulpsDelta is not used
 230     protected int ulpsDelta() {
 231         return 0;
 232     }
 233 
 234     @SuppressWarnings("serial")
 235     public static class MultiCauseAssertionError extends AssertionError {
 236 
 237         private Throwable[] causes;
 238 
 239         public MultiCauseAssertionError(String message, Throwable... causes) {
 240             super(message);
 241             this.causes = causes;
 242         }
 243 
 244         @Override
 245         public void printStackTrace(PrintStream out) {
 246             super.printStackTrace(out);
 247             int num = 0;
 248             for (Throwable cause : causes) {
 249                 if (cause != null) {
 250                     out.print("cause " + (num++));
 251                     cause.printStackTrace(out);
 252                 }
 253             }
 254         }
 255 
 256         @Override
 257         public void printStackTrace(PrintWriter out) {
 258             super.printStackTrace(out);
 259             int num = 0;
 260             for (Throwable cause : causes) {
 261                 if (cause != null) {
 262                     out.print("cause " + (num++) + ": ");
 263                     cause.printStackTrace(out);
 264                 }
 265             }
 266         }
 267     }
 268 
 269     /*
 270      * Overrides to the normal JUnit {@link Assert} routines that provide varargs style formatting
 271      * and produce an exception stack trace with the assertion frames trimmed out.
 272      */
 273 
 274     /**
 275      * Fails a test with the given message.
 276      *
 277      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 278      *            okay)
 279      * @see AssertionError
 280      */
 281     public static void fail(String message, Object... objects) {
 282         AssertionError e;
 283         if (message == null) {
 284             e = new AssertionError();
 285         } else {
 286             e = new AssertionError(String.format(message, objects));
 287         }
 288         // Trim the assert frames from the stack trace
 289         StackTraceElement[] trace = e.getStackTrace();
 290         int start = 1; // Skip this frame
 291         String thisClassName = GraalTest.class.getName();
 292         while (start < trace.length && trace[start].getClassName().equals(thisClassName) && (trace[start].getMethodName().equals("assertTrue") || trace[start].getMethodName().equals("assertFalse"))) {
 293             start++;
 294         }
 295         e.setStackTrace(Arrays.copyOfRange(trace, start, trace.length));
 296         throw e;
 297     }
 298 
 299     /**
 300      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
 301      * given message.
 302      *
 303      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 304      *            okay)
 305      * @param condition condition to be checked
 306      */
 307     public static void assertTrue(String message, boolean condition) {
 308         assertTrue(condition, message);
 309     }
 310 
 311     /**
 312      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} without a
 313      * message.
 314      *
 315      * @param condition condition to be checked
 316      */
 317     public static void assertTrue(boolean condition) {
 318         assertTrue(condition, null);
 319     }
 320 
 321     /**
 322      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
 323      * given message.
 324      *
 325      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 326      *            okay)
 327      * @param condition condition to be checked
 328      */
 329     public static void assertFalse(String message, boolean condition) {
 330         assertTrue(!condition, message);
 331     }
 332 
 333     /**
 334      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} without a
 335      * message.
 336      *
 337      * @param condition condition to be checked
 338      */
 339     public static void assertFalse(boolean condition) {
 340         assertTrue(!condition, null);
 341     }
 342 
 343     /**
 344      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
 345      * given message.
 346      *
 347      * @param condition condition to be checked
 348      * @param message the identifying message for the {@link AssertionError}
 349      * @param objects arguments to the format string
 350      */
 351     public static void assertTrue(boolean condition, String message, Object... objects) {
 352         if (!condition) {
 353             fail(message, objects);
 354         }
 355     }
 356 
 357     /**
 358      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
 359      * given message produced by {@link String#format}.
 360      *
 361      * @param condition condition to be checked
 362      * @param message the identifying message for the {@link AssertionError}
 363      * @param objects arguments to the format string
 364      */
 365     public static void assertFalse(boolean condition, String message, Object... objects) {
 366         assertTrue(!condition, message, objects);
 367     }
 368 }