1 /*
   2  * Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.test;
  24 
  25 import static org.graalvm.compiler.debug.DebugContext.DEFAULT_LOG_STREAM;
  26 import static org.graalvm.compiler.debug.DebugContext.NO_DESCRIPTION;
  27 import static org.graalvm.compiler.debug.DebugContext.NO_GLOBAL_METRIC_VALUES;
  28 
  29 import java.io.PrintStream;
  30 import java.io.PrintWriter;
  31 import java.lang.reflect.Field;
  32 import java.lang.reflect.Method;
  33 import java.util.ArrayList;
  34 import java.util.Arrays;
  35 import java.util.Collection;
  36 import java.util.Collections;
  37 import java.util.List;
  38 
  39 import org.graalvm.compiler.debug.DebugHandlersFactory;
  40 import org.graalvm.compiler.debug.DebugContext;
  41 import org.graalvm.compiler.debug.DebugDumpHandler;
  42 import org.graalvm.compiler.options.OptionValues;
  43 import org.junit.After;
  44 import org.junit.Assert;
  45 import org.junit.internal.ComparisonCriteria;
  46 import org.junit.internal.ExactComparisonCriteria;
  47 
  48 import sun.misc.Unsafe;
  49 
  50 /**
  51  * Base class that contains common utility methods and classes useful in unit tests.
  52  */
  53 public class GraalTest {
  54 
  55     public static final Unsafe UNSAFE;
  56     static {
  57         try {
  58             Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
  59             theUnsafe.setAccessible(true);
  60             UNSAFE = (Unsafe) theUnsafe.get(Unsafe.class);
  61         } catch (Exception e) {
  62             throw new RuntimeException("exception while trying to get Unsafe", e);
  63         }
  64     }
  65 
  66     public static final boolean Java8OrEarlier = System.getProperty("java.specification.version").compareTo("1.9") < 0;
  67 
  68     protected Method getMethod(String methodName) {
  69         return getMethod(getClass(), methodName);
  70     }
  71 
  72     protected Method getMethod(Class<?> clazz, String methodName) {
  73         Method found = null;
  74         for (Method m : clazz.getMethods()) {
  75             if (m.getName().equals(methodName)) {
  76                 Assert.assertNull(found);
  77                 found = m;
  78             }
  79         }
  80         if (found == null) {
  81             /* Now look for non-public methods (but this does not look in superclasses). */
  82             for (Method m : clazz.getDeclaredMethods()) {
  83                 if (m.getName().equals(methodName)) {
  84                     Assert.assertNull(found);
  85                     found = m;
  86                 }
  87             }
  88         }
  89         if (found != null) {
  90             return found;
  91         } else {
  92             throw new RuntimeException("method not found: " + methodName);
  93         }
  94     }
  95 
  96     protected Method getMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) {
  97         try {
  98             return clazz.getMethod(methodName, parameterTypes);
  99         } catch (NoSuchMethodException | SecurityException e) {
 100             throw new RuntimeException("method not found: " + methodName + "" + Arrays.toString(parameterTypes));
 101         }
 102     }
 103 
 104     /**
 105      * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
 106      * Does a deep copy equality comparison if {@code expected} is an array.
 107      */
 108     protected void assertDeepEquals(Object expected, Object actual) {
 109         assertDeepEquals(null, expected, actual);
 110     }
 111 
 112     /**
 113      * Compares two given objects for {@linkplain Assert#assertEquals(Object, Object) equality}.
 114      * Does a deep copy equality comparison if {@code expected} is an array.
 115      *
 116      * @param message the identifying message for the {@link AssertionError}
 117      */
 118     protected void assertDeepEquals(String message, Object expected, Object actual) {
 119         if (ulpsDelta() > 0) {
 120             assertDeepEquals(message, expected, actual, ulpsDelta());
 121         } else {
 122             assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
 123         }
 124     }
 125 
 126     /**
 127      * Compares two given values for equality, doing a recursive test if both values are arrays of
 128      * the same type.
 129      *
 130      * @param message the identifying message for the {@link AssertionError}
 131      * @param delta the maximum delta between two doubles or floats for which both numbers are still
 132      *            considered equal.
 133      */
 134     protected void assertDeepEquals(String message, Object expected, Object actual, double delta) {
 135         if (expected != null && actual != null) {
 136             Class<?> expectedClass = expected.getClass();
 137             Class<?> actualClass = actual.getClass();
 138             if (expectedClass.isArray()) {
 139                 Assert.assertEquals(message, expectedClass, actual.getClass());
 140                 if (expected instanceof int[]) {
 141                     Assert.assertArrayEquals(message, (int[]) expected, (int[]) actual);
 142                 } else if (expected instanceof byte[]) {
 143                     Assert.assertArrayEquals(message, (byte[]) expected, (byte[]) actual);
 144                 } else if (expected instanceof char[]) {
 145                     Assert.assertArrayEquals(message, (char[]) expected, (char[]) actual);
 146                 } else if (expected instanceof short[]) {
 147                     Assert.assertArrayEquals(message, (short[]) expected, (short[]) actual);
 148                 } else if (expected instanceof float[]) {
 149                     Assert.assertArrayEquals(message, (float[]) expected, (float[]) actual, (float) delta);
 150                 } else if (expected instanceof long[]) {
 151                     Assert.assertArrayEquals(message, (long[]) expected, (long[]) actual);
 152                 } else if (expected instanceof double[]) {
 153                     Assert.assertArrayEquals(message, (double[]) expected, (double[]) actual, delta);
 154                 } else if (expected instanceof boolean[]) {
 155                     new ExactComparisonCriteria().arrayEquals(message, expected, actual);
 156                 } else if (expected instanceof Object[]) {
 157                     new ComparisonCriteria() {
 158                         @Override
 159                         protected void assertElementsEqual(Object e, Object a) {
 160                             assertDeepEquals(message, e, a, delta);
 161                         }
 162                     }.arrayEquals(message, expected, actual);
 163                 } else {
 164                     Assert.fail((message == null ? "" : message) + "non-array value encountered: " + expected);
 165                 }
 166             } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
 167                 Assert.assertEquals((double) expected, (double) actual, delta);
 168             } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
 169                 Assert.assertEquals((float) expected, (float) actual, delta);
 170             } else {
 171                 Assert.assertEquals(message, expected, actual);
 172             }
 173         } else {
 174             Assert.assertEquals(message, expected, actual);
 175         }
 176     }
 177 
 178     /**
 179      * Compares two given values for equality, doing a recursive test if both values are arrays of
 180      * the same type. Uses {@linkplain StrictMath#ulp(float) ULP}s for comparison of floats.
 181      *
 182      * @param message the identifying message for the {@link AssertionError}
 183      * @param ulpsDelta the maximum allowed ulps difference between two doubles or floats for which
 184      *            both numbers are still considered equal.
 185      */
 186     protected void assertDeepEquals(String message, Object expected, Object actual, int ulpsDelta) {
 187         ComparisonCriteria doubleUlpsDeltaCriteria = new ComparisonCriteria() {
 188             @Override
 189             protected void assertElementsEqual(Object e, Object a) {
 190                 assertTrue(message, e instanceof Double && a instanceof Double);
 191                 // determine acceptable error based on whether it is a normal number or a NaN/Inf
 192                 double de = (Double) e;
 193                 double epsilon = (!Double.isNaN(de) && Double.isFinite(de) ? ulpsDelta * Math.ulp(de) : 0);
 194                 Assert.assertEquals(message, (Double) e, (Double) a, epsilon);
 195             }
 196         };
 197 
 198         ComparisonCriteria floatUlpsDeltaCriteria = new ComparisonCriteria() {
 199             @Override
 200             protected void assertElementsEqual(Object e, Object a) {
 201                 assertTrue(message, e instanceof Float && a instanceof Float);
 202                 // determine acceptable error based on whether it is a normal number or a NaN/Inf
 203                 float fe = (Float) e;
 204                 float epsilon = (!Float.isNaN(fe) && Float.isFinite(fe) ? ulpsDelta * Math.ulp(fe) : 0);
 205                 Assert.assertEquals(message, (Float) e, (Float) a, epsilon);
 206             }
 207         };
 208 
 209         if (expected != null && actual != null) {
 210             Class<?> expectedClass = expected.getClass();
 211             Class<?> actualClass = actual.getClass();
 212             if (expectedClass.isArray()) {
 213                 Assert.assertEquals(message, expectedClass, actualClass);
 214                 if (expected instanceof double[] || expected instanceof Object[]) {
 215                     doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 216                     return;
 217                 } else if (expected instanceof float[] || expected instanceof Object[]) {
 218                     floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 219                     return;
 220                 }
 221             } else if (expectedClass.equals(double.class) && actualClass.equals(double.class)) {
 222                 doubleUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 223                 return;
 224             } else if (expectedClass.equals(float.class) && actualClass.equals(float.class)) {
 225                 floatUlpsDeltaCriteria.arrayEquals(message, expected, actual);
 226                 return;
 227             }
 228         }
 229         // anything else just use the non-ulps version
 230         assertDeepEquals(message, expected, actual, equalFloatsOrDoublesDelta());
 231     }
 232 
 233     /**
 234      * Gets the value used by {@link #assertDeepEquals(Object, Object)} and
 235      * {@link #assertDeepEquals(String, Object, Object)} for the maximum delta between two doubles
 236      * or floats for which both numbers are still considered equal.
 237      */
 238     protected double equalFloatsOrDoublesDelta() {
 239         return 0.0D;
 240     }
 241 
 242     // unless overridden ulpsDelta is not used
 243     protected int ulpsDelta() {
 244         return 0;
 245     }
 246 
 247     @SuppressWarnings("serial")
 248     public static class MultiCauseAssertionError extends AssertionError {
 249 
 250         private Throwable[] causes;
 251 
 252         public MultiCauseAssertionError(String message, Throwable... causes) {
 253             super(message);
 254             this.causes = causes;
 255         }
 256 
 257         @Override
 258         public void printStackTrace(PrintStream out) {
 259             super.printStackTrace(out);
 260             int num = 0;
 261             for (Throwable cause : causes) {
 262                 if (cause != null) {
 263                     out.print("cause " + (num++));
 264                     cause.printStackTrace(out);
 265                 }
 266             }
 267         }
 268 
 269         @Override
 270         public void printStackTrace(PrintWriter out) {
 271             super.printStackTrace(out);
 272             int num = 0;
 273             for (Throwable cause : causes) {
 274                 if (cause != null) {
 275                     out.print("cause " + (num++) + ": ");
 276                     cause.printStackTrace(out);
 277                 }
 278             }
 279         }
 280     }
 281 
 282     /*
 283      * Overrides to the normal JUnit {@link Assert} routines that provide varargs style formatting
 284      * and produce an exception stack trace with the assertion frames trimmed out.
 285      */
 286 
 287     /**
 288      * Fails a test with the given message.
 289      *
 290      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 291      *            okay)
 292      * @see AssertionError
 293      */
 294     public static void fail(String message, Object... objects) {
 295         AssertionError e;
 296         if (message == null) {
 297             e = new AssertionError();
 298         } else {
 299             e = new AssertionError(String.format(message, objects));
 300         }
 301         // Trim the assert frames from the stack trace
 302         StackTraceElement[] trace = e.getStackTrace();
 303         int start = 1; // Skip this frame
 304         String thisClassName = GraalTest.class.getName();
 305         while (start < trace.length && trace[start].getClassName().equals(thisClassName) && (trace[start].getMethodName().equals("assertTrue") || trace[start].getMethodName().equals("assertFalse"))) {
 306             start++;
 307         }
 308         e.setStackTrace(Arrays.copyOfRange(trace, start, trace.length));
 309         throw e;
 310     }
 311 
 312     /**
 313      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
 314      * given message.
 315      *
 316      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 317      *            okay)
 318      * @param condition condition to be checked
 319      */
 320     public static void assertTrue(String message, boolean condition) {
 321         assertTrue(condition, message);
 322     }
 323 
 324     /**
 325      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} without a
 326      * message.
 327      *
 328      * @param condition condition to be checked
 329      */
 330     public static void assertTrue(boolean condition) {
 331         assertTrue(condition, null);
 332     }
 333 
 334     /**
 335      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
 336      * given message.
 337      *
 338      * @param message the identifying message for the {@link AssertionError} (<code>null</code>
 339      *            okay)
 340      * @param condition condition to be checked
 341      */
 342     public static void assertFalse(String message, boolean condition) {
 343         assertTrue(!condition, message);
 344     }
 345 
 346     /**
 347      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} without a
 348      * message.
 349      *
 350      * @param condition condition to be checked
 351      */
 352     public static void assertFalse(boolean condition) {
 353         assertTrue(!condition, null);
 354     }
 355 
 356     /**
 357      * Asserts that a condition is true. If it isn't it throws an {@link AssertionError} with the
 358      * given message.
 359      *
 360      * @param condition condition to be checked
 361      * @param message the identifying message for the {@link AssertionError}
 362      * @param objects arguments to the format string
 363      */
 364     public static void assertTrue(boolean condition, String message, Object... objects) {
 365         if (!condition) {
 366             fail(message, objects);
 367         }
 368     }
 369 
 370     /**
 371      * Asserts that a condition is false. If it isn't it throws an {@link AssertionError} with the
 372      * given message produced by {@link String#format}.
 373      *
 374      * @param condition condition to be checked
 375      * @param message the identifying message for the {@link AssertionError}
 376      * @param objects arguments to the format string
 377      */
 378     public static void assertFalse(boolean condition, String message, Object... objects) {
 379         assertTrue(!condition, message, objects);
 380     }
 381 
 382     /**
 383      * Gets the {@link DebugHandlersFactory}s available for a {@link DebugContext}.
 384      */
 385     protected Collection<DebugHandlersFactory> getDebugHandlersFactories() {
 386         return Collections.emptyList();
 387     }
 388 
 389     /**
 390      * Gets a {@link DebugContext} object corresponding to {@code options}, creating a new one if
 391      * none currently exists. Debug contexts created by this method will have their
 392      * {@link DebugDumpHandler}s closed in {@link #afterTest()}.
 393      */
 394     protected DebugContext getDebugContext(OptionValues options) {
 395         List<DebugContext> cached = cachedDebugs.get();
 396         if (cached == null) {
 397             cached = new ArrayList<>();
 398             cachedDebugs.set(cached);
 399         }
 400         for (DebugContext debug : cached) {
 401             if (debug.getOptions() == options) {
 402                 return debug;
 403             }
 404         }
 405         DebugContext debug = DebugContext.create(options, NO_DESCRIPTION, NO_GLOBAL_METRIC_VALUES, DEFAULT_LOG_STREAM, getDebugHandlersFactories());
 406         cached.add(debug);
 407         return debug;
 408     }
 409 
 410     private final ThreadLocal<List<DebugContext>> cachedDebugs = new ThreadLocal<>();
 411 
 412     @After
 413     public void afterTest() {
 414         List<DebugContext> cached = cachedDebugs.get();
 415         if (cached != null) {
 416             for (DebugContext debug : cached) {
 417                 debug.closeDumpHandlers(true);
 418             }
 419         }
 420     }
 421 }