1 /* 2 * Copyright (c) 2009, 2012, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 package com.oracle.graal.compiler.hsail.test.infra; 24 25 import static org.junit.Assert.*; 26 import static org.junit.Assume.*; 27 28 import java.io.*; 29 import java.lang.annotation.*; 30 import java.lang.reflect.*; 31 import java.nio.file.*; 32 import java.util.*; 33 import java.util.concurrent.atomic.*; 34 import java.util.logging.*; 35 36 import com.amd.okra.*; 37 38 /** 39 * Abstract class on which the HSAIL unit tests are built. Executes a method or lambda on both the 40 * Java side and the Okra side and compares the results for fields that are annotated with 41 * {@link KernelTester.Result}. 42 */ 43 public abstract class KernelTester { 44 45 /** 46 * Denotes a field whose value is to be compared as part of computing the result of a test. 47 */ 48 @Retention(RetentionPolicy.RUNTIME) 49 @Target(ElementType.FIELD) 50 public @interface Result { 51 } 52 53 // Using these in case we want to compile with Java 7. 54 public interface MyIntConsumer { 55 56 void accept(int value); 57 } 58 59 public interface MyObjConsumer { 60 61 void accept(Object obj); 62 } 63 64 public enum DispatchMode { 65 SEQ, JTP, OKRA 66 } 67 68 public enum HsailMode { 69 COMPILED, INJECT_HSAIL, INJECT_OCL 70 } 71 72 private DispatchMode dispatchMode; 73 // Where the hsail comes from. 74 private HsailMode hsailMode; 75 private Method testMethod; 76 // What type of okra dispatch to use when client calls. 77 private boolean useLambdaMethod; 78 private Class<?>[] testMethodParams = null; 79 private int id = nextId.incrementAndGet(); 80 static AtomicInteger nextId = new AtomicInteger(0); 81 public static Logger logger; 82 private OkraContext okraContext; 83 private OkraKernel okraKernel; 84 private static final String propPkgName = KernelTester.class.getPackage().getName(); 85 private static Level logLevel; 86 private static ConsoleHandler consoleHandler; 87 88 static { 89 logger = Logger.getLogger(propPkgName); 90 logLevel = Level.parse(System.getProperty("kerneltester.logLevel", "SEVERE")); 91 92 // This block configure the logger with handler and formatter. 93 consoleHandler = new ConsoleHandler(); 94 logger.addHandler(consoleHandler); 95 logger.setUseParentHandlers(false); 96 SimpleFormatter formatter = new SimpleFormatter() { 97 98 @SuppressWarnings("sync-override") 99 @Override 100 public String format(LogRecord record) { 101 return (record.getMessage() + "\n"); 102 } 103 }; 104 consoleHandler.setFormatter(formatter); 105 setLogLevel(logLevel); 106 } 107 108 private static boolean gaveNoOkraWarning = false; 109 private boolean onSimulator; 110 private boolean okraLibExists; 111 112 public boolean runningOnSimulator() { 113 return onSimulator; 114 } 115 116 public KernelTester() { 117 okraLibExists = OkraUtil.okraLibExists(); 118 dispatchMode = DispatchMode.SEQ; 119 hsailMode = HsailMode.COMPILED; 120 useLambdaMethod = false; 121 } 122 123 public abstract void runTest(); 124 125 // Default comparison is to compare all things marked @Result. 126 public boolean compareResults(KernelTester base) { 127 Class<?> clazz = this.getClass(); 128 while (clazz != null && clazz != KernelTester.class) { 129 for (Field f : clazz.getDeclaredFields()) { 130 if (!Modifier.isStatic(f.getModifiers())) { 131 Result annos = f.getAnnotation(Result.class); 132 if (annos != null) { 133 logger.fine("@Result field = " + f); 134 Object myResult = getFieldFromObject(f, this); 135 Object otherResult = getFieldFromObject(f, base); 136 boolean same = compareObjects(myResult, otherResult); 137 logger.fine("comparing " + myResult + ", " + otherResult + ", match=" + same); 138 if (!same) { 139 logger.severe("mismatch comparing " + f + ", " + myResult + " vs. " + otherResult); 140 logSevere("FAILED!!! " + this.getClass()); 141 return false; 142 } 143 } 144 } 145 } 146 clazz = clazz.getSuperclass(); 147 } 148 logInfo("PASSED: " + this.getClass()); 149 return true; 150 } 151 152 private boolean compareObjects(Object first, Object second) { 153 Class<?> clazz = first.getClass(); 154 if (clazz != second.getClass()) { 155 return false; 156 } 157 if (!clazz.isArray()) { 158 // Non arrays. 159 if (clazz.equals(float.class) || clazz.equals(double.class)) { 160 return isEqualsFP((double) first, (double) second); 161 } else { 162 return first.equals(second); 163 } 164 } else { 165 // Handle the case where Objects are arrays. 166 ArrayComparer comparer; 167 if (clazz.equals(float[].class) || clazz.equals(double[].class)) { 168 comparer = new FPArrayComparer(); 169 } else if (clazz.equals(long[].class) || clazz.equals(int[].class) || clazz.equals(byte[].class)) { 170 comparer = new IntArrayComparer(); 171 } else if (clazz.equals(boolean[].class)) { 172 comparer = new BooleanArrayComparer(); 173 } else { 174 comparer = new ObjArrayComparer(); 175 } 176 return comparer.compareArrays(first, second); 177 } 178 } 179 180 static final int MISMATCHLIMIT = 10; 181 static final int ELEMENTDISPLAYLIMIT = 20; 182 183 public int getMisMatchLimit() { 184 return MISMATCHLIMIT; 185 } 186 187 public int getElementDisplayLimit() { 188 return ELEMENTDISPLAYLIMIT; 189 } 190 191 abstract class ArrayComparer { 192 193 abstract Object getElement(Object ary, int index); 194 195 // Equality test, can be overridden 196 boolean isEquals(Object firstElement, Object secondElement) { 197 return firstElement.equals(secondElement); 198 } 199 200 boolean compareArrays(Object first, Object second) { 201 int len = Array.getLength(first); 202 if (len != Array.getLength(second)) { 203 return false; 204 } 205 // If info logLevel, build string of first few elements from first array. 206 if (logLevel.intValue() <= Level.INFO.intValue()) { 207 StringBuilder sb = new StringBuilder(); 208 for (int i = 0; i < Math.min(len, getElementDisplayLimit()); i++) { 209 sb.append(getElement(first, i)); 210 sb.append(", "); 211 } 212 logger.info(sb.toString()); 213 } 214 boolean success = true; 215 int mismatches = 0; 216 for (int i = 0; i < len; i++) { 217 Object firstElement = getElement(first, i); 218 Object secondElement = getElement(second, i); 219 if (!isEquals(firstElement, secondElement)) { 220 logSevere("mismatch at index " + i + ", expected " + secondElement + ", saw " + firstElement); 221 success = false; 222 mismatches++; 223 if (mismatches >= getMisMatchLimit()) { 224 logSevere("...Truncated"); 225 break; 226 } 227 } 228 } 229 return success; 230 } 231 } 232 233 class FPArrayComparer extends ArrayComparer { 234 235 @Override 236 Object getElement(Object ary, int index) { 237 return Array.getDouble(ary, index); 238 } 239 240 @Override 241 boolean isEquals(Object firstElement, Object secondElement) { 242 return isEqualsFP((double) firstElement, (double) secondElement); 243 } 244 } 245 246 class IntArrayComparer extends ArrayComparer { 247 248 @Override 249 Object getElement(Object ary, int index) { 250 return Array.getLong(ary, index); 251 } 252 } 253 254 class BooleanArrayComparer extends ArrayComparer { 255 256 @Override 257 Object getElement(Object ary, int index) { 258 return Array.getBoolean(ary, index); 259 } 260 } 261 262 class ObjArrayComparer extends ArrayComparer { 263 264 @Override 265 Object getElement(Object ary, int index) { 266 return Array.get(ary, index); 267 } 268 } 269 270 /** 271 * This isEqualsFP method allows subclass to override what FP equality means for this particular 272 * unit test. 273 */ 274 public boolean isEqualsFP(double first, double second) { 275 return first == second; 276 } 277 278 public void setDispatchMode(DispatchMode dispatchMode) { 279 this.dispatchMode = dispatchMode; 280 } 281 282 public void setHsailMode(HsailMode hsailMode) { 283 this.hsailMode = hsailMode; 284 } 285 286 /** 287 * Return a clone of this instance unless overridden, we just call the null constructor. 288 */ 289 public KernelTester newInstance() { 290 try { 291 return this.getClass().getConstructor((Class<?>[]) null).newInstance(); 292 } catch (Throwable t) { 293 fail("Unexpected exception " + t); 294 return null; 295 } 296 } 297 298 public Method getMethodFromMethodName(String methName, Class<?> clazz) { 299 Class<?> clazz2 = clazz; 300 while (clazz2 != null) { 301 for (Method m : clazz2.getDeclaredMethods()) { 302 logger.fine(" in " + clazz2 + ", trying to match " + m); 303 if (m.getName().equals(methName)) { 304 testMethodParams = m.getParameterTypes(); 305 if (logLevel.intValue() <= Level.FINE.intValue()) { 306 logger.fine(" in " + clazz2 + ", matched " + m); 307 logger.fine("parameter types are..."); 308 int paramNum = 0; 309 for (Class<?> pclazz : testMethodParams) { 310 logger.fine(paramNum++ + ") " + pclazz.toString()); 311 } 312 } 313 return m; 314 } 315 } 316 // Didn't find it in current clazz, try superclass. 317 clazz2 = clazz2.getSuperclass(); 318 } 319 // If we got this far, no match. 320 return null; 321 } 322 323 private void setTestMethod(String methName, Class<?> inClazz) { 324 testMethod = getMethodFromMethodName(methName, inClazz); 325 if (testMethod == null) { 326 fail("cannot find method " + methName + " in class " + inClazz); 327 } else { 328 // Print info but only for first such class. 329 if (id == 1) { 330 logger.fine("testMethod to be compiled is \n " + testMethod); 331 } 332 } 333 } 334 335 // Default is method name "run", but could be overridden. 336 private final String defaultMethodName = "run"; 337 338 public String getTestMethodName() { 339 return defaultMethodName; 340 } 341 342 /** 343 * The dispatchMethodKernel dispatches a non-lambda method. All the parameters of the compiled 344 * method are supplied as parameters to this call. 345 */ 346 public void dispatchMethodKernel(int range, Object... args) { 347 if (testMethod == null) { 348 setTestMethod(getTestMethodName(), this.getClass()); 349 } 350 if (dispatchMode == DispatchMode.SEQ) { 351 dispatchMethodKernelSeq(range, args); 352 } else if (dispatchMode == DispatchMode.OKRA) { 353 dispatchMethodKernelOkra(range, args); 354 } 355 } 356 357 /** 358 * This dispatchLambdaMethodKernel dispatches the lambda version of a kernel where the "kernel" 359 * is for the lambda method itself (like lambda$0). 360 */ 361 public void dispatchLambdaMethodKernel(int range, MyIntConsumer consumer) { 362 if (testMethod == null) { 363 setTestMethod(findLambdaMethodName(), this.getClass()); 364 } 365 if (dispatchMode == DispatchMode.SEQ) { 366 dispatchLambdaKernelSeq(range, consumer); 367 } else if (dispatchMode == DispatchMode.OKRA) { 368 dispatchLambdaMethodKernelOkra(range, consumer); 369 } 370 } 371 372 public void dispatchLambdaMethodKernel(Object[] ary, MyObjConsumer consumer) { 373 if (testMethod == null) { 374 setTestMethod(findLambdaMethodName(), this.getClass()); 375 } 376 if (dispatchMode == DispatchMode.SEQ) { 377 dispatchLambdaKernelSeq(ary, consumer); 378 } else if (dispatchMode == DispatchMode.OKRA) { 379 dispatchLambdaMethodKernelOkra(ary, consumer); 380 } 381 } 382 383 /** 384 * The dispatchLambdaKernel dispatches the lambda version of a kernel where the "kernel" is for 385 * the xxx$$Lambda.accept method in the wrapper for the lambda. Note that the useLambdaMethod 386 * boolean provides a way of actually invoking dispatchLambdaMethodKernel from this API. 387 */ 388 public void dispatchLambdaKernel(int range, MyIntConsumer consumer) { 389 if (useLambdaMethod) { 390 dispatchLambdaMethodKernel(range, consumer); 391 return; 392 } 393 if (testMethod == null) { 394 setTestMethod("accept", consumer.getClass()); 395 } 396 if (dispatchMode == DispatchMode.SEQ) { 397 dispatchLambdaKernelSeq(range, consumer); 398 } else if (dispatchMode == DispatchMode.OKRA) { 399 dispatchLambdaKernelOkra(range, consumer); 400 } 401 } 402 403 public void dispatchLambdaKernel(Object[] ary, MyObjConsumer consumer) { 404 if (useLambdaMethod) { 405 dispatchLambdaMethodKernel(ary, consumer); 406 return; 407 } 408 if (testMethod == null) { 409 setTestMethod("accept", consumer.getClass()); 410 } 411 if (dispatchMode == DispatchMode.SEQ) { 412 dispatchLambdaKernelSeq(ary, consumer); 413 } else if (dispatchMode == DispatchMode.OKRA) { 414 dispatchLambdaKernelOkra(ary, consumer); 415 } 416 } 417 418 private ArrayList<String> getLambdaMethodNames() { 419 Class<?> clazz = this.getClass(); 420 ArrayList<String> lambdaNames = new ArrayList<>(); 421 while (clazz != null && (lambdaNames.size() == 0)) { 422 for (Method m : clazz.getDeclaredMethods()) { 423 logger.fine(" in " + clazz + ", trying to match " + m); 424 if (m.getName().startsWith("lambda$")) { 425 lambdaNames.add(m.getName()); 426 } 427 } 428 // Didn't find it in current clazz, try superclass. 429 clazz = clazz.getSuperclass(); 430 } 431 return lambdaNames; 432 } 433 434 /** 435 * findLambdaMethodName finds a name in the class starting with lambda$. If we find more than 436 * one, throw an error, and tell user to override explicitly 437 */ 438 private String findLambdaMethodName() { 439 // If user overrode getTestMethodName, use that name. 440 if (!getTestMethodName().equals(defaultMethodName)) { 441 return getTestMethodName(); 442 } else { 443 ArrayList<String> lambdaNames = getLambdaMethodNames(); 444 switch (lambdaNames.size()) { 445 case 1: 446 return lambdaNames.get(0); 447 case 0: 448 fail("No lambda method found in " + this.getClass()); 449 return null; 450 default: 451 // More than one lambda. 452 String msg = "Multiple lambda methods found in " + this.getClass() + "\nYou should override getTestMethodName with one of the following\n"; 453 for (String name : lambdaNames) { 454 msg = msg + name + "\n"; 455 } 456 fail(msg); 457 return null; 458 } 459 } 460 } 461 462 /** 463 * The getCompiledHSAILSource returns the string of HSAIL code for the compiled method. By 464 * default, throws an error. In graal for instance, this would be overridden in 465 * GraalKernelTester. 466 */ 467 public String getCompiledHSAILSource(Method testMethod1) { 468 fail("no compiler connected so unable to compile " + testMethod1 + "\nYou could try injecting HSAIL or OpenCL"); 469 return null; 470 } 471 472 public String getHSAILSource(Method testMethod1) { 473 switch (hsailMode) { 474 case COMPILED: 475 return getCompiledHSAILSource(testMethod1); 476 case INJECT_HSAIL: 477 return getHsailFromClassnameHsailFile(); 478 case INJECT_OCL: 479 return getHsailFromClassnameOclFile(); 480 default: 481 fail("unknown hsailMode = " + hsailMode); 482 return null; 483 } 484 } 485 486 /** 487 * The getHSAILKernelName returns the name of the hsail kernel. By default we use 'run'. unless 488 * coming from opencl injection. Could be overridden by the junit test. 489 */ 490 public String getHSAILKernelName() { 491 return (hsailMode != HsailMode.INJECT_OCL ? "&run" : "&__OpenCL_run_kernel"); 492 } 493 494 private void createOkraKernel() { 495 // Call routines in the derived class to get the hsail code and kernel name. 496 String hsailSource = getHSAILSource(testMethod); 497 if (!okraLibExists) { 498 if (!gaveNoOkraWarning) { 499 logger.fine("No Okra library detected, skipping all KernelTester tests in " + this.getClass().getPackage().getName()); 500 gaveNoOkraWarning = true; 501 } 502 } 503 // Ignore any kerneltester test if okra does not exist. 504 assumeTrue(okraLibExists); 505 // Control which okra instances can run the tests. 506 onSimulator = OkraContext.isSimulator(); 507 okraContext = new OkraContext(); 508 if (!okraContext.isValid()) { 509 fail("...unable to create context"); 510 } 511 // Control verbosity in okra from our logLevel. 512 if (logLevel.intValue() <= Level.INFO.intValue()) { 513 okraContext.setVerbose(true); 514 } 515 okraKernel = new OkraKernel(okraContext, hsailSource, getHSAILKernelName()); 516 if (!okraKernel.isValid()) { 517 fail("...unable to create kernel"); 518 } 519 } 520 521 private void dispatchKernelOkra(int range, Object... args) { 522 if (okraKernel == null) { 523 createOkraKernel(); 524 } 525 if (logLevel.intValue() <= Level.FINE.intValue()) { 526 logger.fine("Arguments passed to okra..."); 527 for (Object arg : args) { 528 logger.fine(" " + arg); 529 } 530 } 531 okraKernel.setLaunchAttributes(range); 532 okraKernel.dispatchWithArgs(args); 533 } 534 535 private void dispatchMethodKernelSeq(int range, Object... args) { 536 Object[] invokeArgs = new Object[args.length + 1]; 537 // Need space on the end for the gid parameter. 538 System.arraycopy(args, 0, invokeArgs, 0, args.length); 539 int gidArgIndex = invokeArgs.length - 1; 540 if (logLevel.intValue() <= Level.FINE.intValue()) { 541 for (Object arg : args) { 542 logger.fine(arg.toString()); 543 } 544 } 545 for (int rangeIndex = 0; rangeIndex < range; rangeIndex++) { 546 invokeArgs[gidArgIndex] = rangeIndex; 547 try { 548 testMethod.invoke(this, invokeArgs); 549 } catch (IllegalAccessException e) { 550 fail("could not invoke " + testMethod + ", make sure it is public"); 551 } catch (IllegalArgumentException e) { 552 fail("wrong arguments invoking " + testMethod + ", check number and type of args passed to dispatchMethodKernel"); 553 } catch (InvocationTargetException e) { 554 Throwable cause = e.getCause(); 555 /** 556 * We will ignore ArrayIndexOutOfBoundsException because the graal okra target 557 * doesn't really handle it yet (basically returns early if it sees one). 558 */ 559 if (cause instanceof ArrayIndexOutOfBoundsException) { 560 logger.severe("ignoring ArrayIndexOutOfBoundsException for index " + rangeIndex); 561 } else { 562 // Other exceptions. 563 String errstr = testMethod + " threw an exception on gid=" + rangeIndex + ", exception was " + cause; 564 fail(errstr); 565 } 566 } catch (Exception e) { 567 fail("Unknown exception " + e + " invoking " + testMethod); 568 } 569 } 570 } 571 572 private void dispatchMethodKernelOkra(int range, Object... args) { 573 Object[] fixedArgs = fixArgTypes(args); 574 if (Modifier.isStatic(testMethod.getModifiers())) { 575 dispatchKernelOkra(range, fixedArgs); 576 } else { 577 // If it is a non-static method we have to push "this" as the first argument. 578 Object[] newFixedArgs = new Object[fixedArgs.length + 1]; 579 System.arraycopy(fixedArgs, 0, newFixedArgs, 1, fixedArgs.length); 580 newFixedArgs[0] = this; 581 dispatchKernelOkra(range, newFixedArgs); 582 } 583 } 584 585 /** 586 * For primitive arg parameters, make sure arg types are cast to whatever the testMethod 587 * signature says they should be. 588 */ 589 private Object[] fixArgTypes(Object[] args) { 590 Object[] fixedArgs = new Object[args.length]; 591 for (int i = 0; i < args.length; i++) { 592 Class<?> paramClass = testMethodParams[i]; 593 if (paramClass.equals(Float.class) || paramClass.equals(float.class)) { 594 fixedArgs[i] = ((Number) args[i]).floatValue(); 595 } else if (paramClass.equals(Integer.class) || paramClass.equals(int.class)) { 596 fixedArgs[i] = ((Number) args[i]).intValue(); 597 } else if (paramClass.equals(Long.class) || paramClass.equals(long.class)) { 598 fixedArgs[i] = ((Number) args[i]).longValue(); 599 } else if (paramClass.equals(Double.class) || paramClass.equals(double.class)) { 600 fixedArgs[i] = ((Number) args[i]).doubleValue(); 601 } else if (paramClass.equals(Byte.class) || paramClass.equals(byte.class)) { 602 fixedArgs[i] = ((Number) args[i]).byteValue(); 603 } else if (paramClass.equals(Boolean.class) || paramClass.equals(boolean.class)) { 604 fixedArgs[i] = (boolean) args[i]; 605 } else { 606 // All others just move unchanged. 607 fixedArgs[i] = args[i]; 608 } 609 } 610 return fixedArgs; 611 } 612 613 /** 614 * Dispatching a lambda on the java side is simple. 615 */ 616 @SuppressWarnings("static-method") 617 private void dispatchLambdaKernelSeq(int range, MyIntConsumer consumer) { 618 for (int i = 0; i < range; i++) { 619 consumer.accept(i); 620 } 621 } 622 623 @SuppressWarnings("static-method") 624 private void dispatchLambdaKernelSeq(Object[] ary, MyObjConsumer consumer) { 625 for (Object obj : ary) { 626 consumer.accept(obj); 627 } 628 } 629 630 /** 631 * The dispatchLambdaMethodKernelOkra dispatches in the case where the hsail kernel implements 632 * the lambda method itself as opposed to the wrapper that calls the lambda method. From the 633 * consumer object, we need to find the fields and pass them to the kernel. 634 */ 635 private void dispatchLambdaMethodKernelOkra(int range, MyIntConsumer consumer) { 636 logger.info("To determine parameters to pass to hsail kernel, we will examine " + consumer.getClass()); 637 Field[] fields = consumer.getClass().getDeclaredFields(); 638 Object[] args = new Object[fields.length]; 639 int argIndex = 0; 640 for (Field f : fields) { 641 logger.info("... " + f); 642 args[argIndex++] = getFieldFromObject(f, consumer); 643 } 644 dispatchKernelOkra(range, args); 645 } 646 647 private void dispatchLambdaMethodKernelOkra(Object[] ary, MyObjConsumer consumer) { 648 logger.info("To determine parameters to pass to hsail kernel, we will examine " + consumer.getClass()); 649 Field[] fields = consumer.getClass().getDeclaredFields(); 650 Object[] args = new Object[fields.length]; 651 int argIndex = 0; 652 for (Field f : fields) { 653 logger.info("... " + f); 654 args[argIndex++] = getFieldFromObject(f, consumer); 655 } 656 dispatchKernelOkra(ary.length, args); 657 } 658 659 /** 660 * The dispatchLambdaKernelOkra dispatches in the case where the hsail kernel where the hsail 661 * kernel implements the accept method of the wrapper that calls the lambda method as opposed to 662 * the actual lambda method itself. 663 */ 664 private void dispatchLambdaKernelOkra(int range, MyIntConsumer consumer) { 665 // The "wrapper" method always has only one arg consisting of the consumer. 666 Object[] args = new Object[1]; 667 args[0] = consumer; 668 dispatchKernelOkra(range, args); 669 } 670 671 private void dispatchLambdaKernelOkra(Object[] ary, MyObjConsumer consumer) { 672 // The "wrapper" method always has only one arg consisting of the consumer. 673 Object[] args = new Object[1]; 674 args[0] = consumer; 675 dispatchKernelOkra(ary.length, args); 676 } 677 678 private void disposeKernelOkra() { 679 if (okraContext != null) { 680 okraContext.dispose(); 681 } 682 } 683 684 private void compareOkraToSeq(HsailMode hsailMode1) { 685 compareOkraToSeq(hsailMode1, false); 686 } 687 688 /** 689 * Runs this instance on OKRA, and as SEQ and compares the output of the two executions. 690 */ 691 private void compareOkraToSeq(HsailMode hsailMode1, boolean useLambda) { 692 // Create and run sequential instance. 693 KernelTester testerSeq = newInstance(); 694 testerSeq.setDispatchMode(DispatchMode.SEQ); 695 testerSeq.runTest(); 696 // Now do this object. 697 this.setHsailMode(hsailMode1); 698 this.setDispatchMode(DispatchMode.OKRA); 699 this.useLambdaMethod = useLambda; 700 this.runTest(); 701 this.disposeKernelOkra(); 702 assertTrue("failed comparison to SEQ", compareResults(testerSeq)); 703 } 704 705 public void testGeneratedHsail() { 706 compareOkraToSeq(HsailMode.COMPILED); 707 } 708 709 public void testGeneratedHsailUsingLambdaMethod() { 710 compareOkraToSeq(HsailMode.COMPILED, true); 711 } 712 713 public void testInjectedHsail() { 714 newInstance().compareOkraToSeq(HsailMode.INJECT_HSAIL); 715 } 716 717 public void testInjectedOpencl() { 718 newInstance().compareOkraToSeq(HsailMode.INJECT_OCL); 719 } 720 721 private static Object getFieldFromObject(Field f, Object fromObj) { 722 try { 723 f.setAccessible(true); 724 Type type = f.getType(); 725 logger.info("type = " + type); 726 if (type == double.class) { 727 return f.getDouble(fromObj); 728 } else if (type == float.class) { 729 return f.getFloat(fromObj); 730 } else if (type == long.class) { 731 return f.getLong(fromObj); 732 } else if (type == int.class) { 733 return f.getInt(fromObj); 734 } else if (type == byte.class) { 735 return f.getByte(fromObj); 736 } else if (type == boolean.class) { 737 return f.getBoolean(fromObj); 738 } else { 739 return f.get(fromObj); 740 } 741 } catch (Exception e) { 742 fail("unable to get field " + f + " from " + fromObj); 743 return null; 744 } 745 } 746 747 public static void checkFileExists(String fileName) { 748 assertTrue(fileName + " does not exist", fileExists(fileName)); 749 } 750 751 public static boolean fileExists(String fileName) { 752 return new File(fileName).exists(); 753 } 754 755 public static String getFileAsString(String sourceFileName) { 756 String source = null; 757 try { 758 checkFileExists(sourceFileName); 759 source = new String(Files.readAllBytes(FileSystems.getDefault().getPath(sourceFileName))); 760 } catch (IOException e) { 761 fail("could not open file " + sourceFileName); 762 return null; 763 } 764 return source; 765 } 766 767 public static String getHsailFromFile(String sourceFileName) { 768 logger.severe("... getting hsail from file " + sourceFileName); 769 return getFileAsString(sourceFileName); 770 } 771 772 private static void executeCmd(String... cmd) { 773 logger.info("spawning" + Arrays.toString(cmd)); 774 try { 775 ProcessBuilder pb = new ProcessBuilder(cmd); 776 Process p = pb.start(); 777 if (logLevel.intValue() <= Level.INFO.intValue()) { 778 InputStream in = p.getInputStream(); 779 BufferedInputStream buf = new BufferedInputStream(in); 780 InputStreamReader inread = new InputStreamReader(buf); 781 BufferedReader bufferedreader = new BufferedReader(inread); 782 String line; 783 while ((line = bufferedreader.readLine()) != null) { 784 logger.info(line); 785 } 786 } 787 p.waitFor(); 788 } catch (Exception e) { 789 fail("could not execute <" + Arrays.toString(cmd) + ">"); 790 } 791 } 792 793 public static String getHsailFromOpenCLFile(String openclFileName) { 794 String openclHsailFile = "opencl_out.hsail"; 795 String tmpTahitiFile = "_temp_0_Tahiti.txt"; 796 checkFileExists(openclFileName); 797 logger.severe("...converting " + openclFileName + " to HSAIL..."); 798 executeCmd("aoc2", "-m64", "-I./", "-march=hsail", openclFileName); 799 if (fileExists(tmpTahitiFile)) { 800 return getFileAsString(tmpTahitiFile); 801 } else { 802 executeCmd("HSAILasm", "-disassemble", "-o", openclHsailFile, openclFileName.replace(".cl", ".bin")); 803 checkFileExists(openclHsailFile); 804 return getFileAsString(openclHsailFile); 805 } 806 } 807 808 public String getHsailFromClassnameHsailFile() { 809 return (getHsailFromFile(this.getClass().getSimpleName() + ".hsail")); 810 } 811 812 public String getHsailFromClassnameOclFile() { 813 return (getHsailFromOpenCLFile(this.getClass().getSimpleName() + ".cl")); 814 } 815 816 public static void logInfo(String msg) { 817 logger.info(msg); 818 } 819 820 public static void logSevere(String msg) { 821 logger.severe(msg); 822 } 823 824 public static void setLogLevel(Level level) { 825 logLevel = level; 826 logger.setLevel(level); 827 consoleHandler.setLevel(level); 828 } 829 }