1 /*
2 * Copyright (c) 2009, 2012, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23 package com.oracle.graal.compiler.hsail.test.infra;
24
import static org.junit.Assert.*;
import static org.junit.Assume.*;

import java.io.*;
import java.lang.annotation.*;
import java.lang.reflect.*;
import java.nio.charset.*;
import java.nio.file.*;
import java.util.*;
import java.util.concurrent.atomic.*;
import java.util.logging.*;

import com.amd.okra.*;
37
38 /**
39 * Abstract class on which the HSAIL unit tests are built. Executes a method or lambda on both the
40 * Java side and the Okra side and compares the results for fields that are annotated with
41 * {@link KernelTester.Result}.
42 */
43 public abstract class KernelTester {
44
45 /**
46 * Denotes a field whose value is to be compared as part of computing the result of a test.
47 */
48 @Retention(RetentionPolicy.RUNTIME)
49 @Target(ElementType.FIELD)
50 public @interface Result {
51 }
52
53 // Using these in case we want to compile with Java 7.
54 public interface MyIntConsumer {
55
56 void accept(int value);
57 }
58
59 public interface MyObjConsumer {
60
61 void accept(Object obj);
62 }
63
64 public enum DispatchMode {
65 SEQ, JTP, OKRA
66 }
67
68 public enum HsailMode {
69 COMPILED, INJECT_HSAIL, INJECT_OCL
70 }
71
72 private DispatchMode dispatchMode;
73 // Where the hsail comes from.
74 private HsailMode hsailMode;
75 private Method testMethod;
76 // What type of okra dispatch to use when client calls.
77 private boolean useLambdaMethod;
78 private Class<?>[] testMethodParams = null;
79 private int id = nextId.incrementAndGet();
80 static AtomicInteger nextId = new AtomicInteger(0);
81 public static Logger logger;
82 private OkraContext okraContext;
83 private OkraKernel okraKernel;
84 private static final String propPkgName = KernelTester.class.getPackage().getName();
85 private static Level logLevel;
86 private static ConsoleHandler consoleHandler;
87
88 static {
89 logger = Logger.getLogger(propPkgName);
90 logLevel = Level.parse(System.getProperty("kerneltester.logLevel", "SEVERE"));
91
92 // This block configure the logger with handler and formatter.
93 consoleHandler = new ConsoleHandler();
94 logger.addHandler(consoleHandler);
95 logger.setUseParentHandlers(false);
96 SimpleFormatter formatter = new SimpleFormatter() {
97
98 @SuppressWarnings("sync-override")
99 @Override
100 public String format(LogRecord record) {
101 return (record.getMessage() + "\n");
102 }
103 };
104 consoleHandler.setFormatter(formatter);
105 setLogLevel(logLevel);
106 }
107
108 private static boolean gaveNoOkraWarning = false;
109 private boolean onSimulator;
110 private boolean okraLibExists;
111
112 public boolean runningOnSimulator() {
113 return onSimulator;
114 }
115
116 public KernelTester() {
117 okraLibExists = OkraUtil.okraLibExists();
118 dispatchMode = DispatchMode.SEQ;
119 hsailMode = HsailMode.COMPILED;
120 useLambdaMethod = false;
121 }
122
123 public abstract void runTest();
124
125 // Default comparison is to compare all things marked @Result.
126 public boolean compareResults(KernelTester base) {
127 Class<?> clazz = this.getClass();
128 while (clazz != null && clazz != KernelTester.class) {
129 for (Field f : clazz.getDeclaredFields()) {
130 if (!Modifier.isStatic(f.getModifiers())) {
131 Result annos = f.getAnnotation(Result.class);
132 if (annos != null) {
133 logger.fine("@Result field = " + f);
134 Object myResult = getFieldFromObject(f, this);
135 Object otherResult = getFieldFromObject(f, base);
136 boolean same = compareObjects(myResult, otherResult);
137 logger.fine("comparing " + myResult + ", " + otherResult + ", match=" + same);
138 if (!same) {
139 logger.severe("mismatch comparing " + f + ", " + myResult + " vs. " + otherResult);
140 logSevere("FAILED!!! " + this.getClass());
141 return false;
142 }
143 }
144 }
145 }
146 clazz = clazz.getSuperclass();
147 }
148 logInfo("PASSED: " + this.getClass());
149 return true;
150 }
151
152 private boolean compareObjects(Object first, Object second) {
153 Class<?> clazz = first.getClass();
154 if (clazz != second.getClass()) {
155 return false;
156 }
157 if (!clazz.isArray()) {
158 // Non arrays.
159 if (clazz.equals(float.class) || clazz.equals(double.class)) {
160 return isEqualsFP((double) first, (double) second);
161 } else {
162 return first.equals(second);
163 }
164 } else {
165 // Handle the case where Objects are arrays.
166 ArrayComparer comparer;
167 if (clazz.equals(float[].class) || clazz.equals(double[].class)) {
168 comparer = new FPArrayComparer();
169 } else if (clazz.equals(long[].class) || clazz.equals(int[].class) || clazz.equals(byte[].class)) {
170 comparer = new IntArrayComparer();
171 } else if (clazz.equals(boolean[].class)) {
172 comparer = new BooleanArrayComparer();
173 } else {
174 comparer = new ObjArrayComparer();
175 }
176 return comparer.compareArrays(first, second);
177 }
178 }
179
180 static final int MISMATCHLIMIT = 10;
181 static final int ELEMENTDISPLAYLIMIT = 20;
182
183 public int getMisMatchLimit() {
184 return MISMATCHLIMIT;
185 }
186
187 public int getElementDisplayLimit() {
188 return ELEMENTDISPLAYLIMIT;
189 }
190
191 abstract class ArrayComparer {
192
193 abstract Object getElement(Object ary, int index);
194
195 // Equality test, can be overridden
196 boolean isEquals(Object firstElement, Object secondElement) {
197 return firstElement.equals(secondElement);
198 }
199
200 boolean compareArrays(Object first, Object second) {
201 int len = Array.getLength(first);
202 if (len != Array.getLength(second)) {
203 return false;
204 }
205 // If info logLevel, build string of first few elements from first array.
206 if (logLevel.intValue() <= Level.INFO.intValue()) {
207 StringBuilder sb = new StringBuilder();
208 for (int i = 0; i < Math.min(len, getElementDisplayLimit()); i++) {
209 sb.append(getElement(first, i));
210 sb.append(", ");
211 }
212 logger.info(sb.toString());
213 }
214 boolean success = true;
215 int mismatches = 0;
216 for (int i = 0; i < len; i++) {
217 Object firstElement = getElement(first, i);
218 Object secondElement = getElement(second, i);
219 if (!isEquals(firstElement, secondElement)) {
220 logSevere("mismatch at index " + i + ", expected " + secondElement + ", saw " + firstElement);
221 success = false;
222 mismatches++;
223 if (mismatches >= getMisMatchLimit()) {
224 logSevere("...Truncated");
225 break;
226 }
227 }
228 }
229 return success;
230 }
231 }
232
233 class FPArrayComparer extends ArrayComparer {
234
235 @Override
236 Object getElement(Object ary, int index) {
237 return Array.getDouble(ary, index);
238 }
239
240 @Override
241 boolean isEquals(Object firstElement, Object secondElement) {
242 return isEqualsFP((double) firstElement, (double) secondElement);
243 }
244 }
245
246 class IntArrayComparer extends ArrayComparer {
247
248 @Override
249 Object getElement(Object ary, int index) {
250 return Array.getLong(ary, index);
251 }
252 }
253
254 class BooleanArrayComparer extends ArrayComparer {
255
256 @Override
257 Object getElement(Object ary, int index) {
258 return Array.getBoolean(ary, index);
259 }
260 }
261
262 class ObjArrayComparer extends ArrayComparer {
263
264 @Override
265 Object getElement(Object ary, int index) {
266 return Array.get(ary, index);
267 }
268 }
269
270 /**
271 * This isEqualsFP method allows subclass to override what FP equality means for this particular
272 * unit test.
273 */
274 public boolean isEqualsFP(double first, double second) {
275 return first == second;
276 }
277
278 public void setDispatchMode(DispatchMode dispatchMode) {
279 this.dispatchMode = dispatchMode;
280 }
281
282 public void setHsailMode(HsailMode hsailMode) {
283 this.hsailMode = hsailMode;
284 }
285
286 /**
287 * Return a clone of this instance unless overridden, we just call the null constructor.
288 */
289 public KernelTester newInstance() {
290 try {
291 return this.getClass().getConstructor((Class<?>[]) null).newInstance();
292 } catch (Throwable t) {
293 fail("Unexpected exception " + t);
294 return null;
295 }
296 }
297
298 public Method getMethodFromMethodName(String methName, Class<?> clazz) {
299 Class<?> clazz2 = clazz;
300 while (clazz2 != null) {
301 for (Method m : clazz2.getDeclaredMethods()) {
302 logger.fine(" in " + clazz2 + ", trying to match " + m);
303 if (m.getName().equals(methName)) {
304 testMethodParams = m.getParameterTypes();
305 if (logLevel.intValue() <= Level.FINE.intValue()) {
306 logger.fine(" in " + clazz2 + ", matched " + m);
307 logger.fine("parameter types are...");
308 int paramNum = 0;
309 for (Class<?> pclazz : testMethodParams) {
310 logger.fine(paramNum++ + ") " + pclazz.toString());
311 }
312 }
313 return m;
314 }
315 }
316 // Didn't find it in current clazz, try superclass.
317 clazz2 = clazz2.getSuperclass();
318 }
319 // If we got this far, no match.
320 return null;
321 }
322
323 private void setTestMethod(String methName, Class<?> inClazz) {
324 testMethod = getMethodFromMethodName(methName, inClazz);
325 if (testMethod == null) {
326 fail("cannot find method " + methName + " in class " + inClazz);
327 } else {
328 // Print info but only for first such class.
329 if (id == 1) {
330 logger.fine("testMethod to be compiled is \n " + testMethod);
331 }
332 }
333 }
334
335 // Default is method name "run", but could be overridden.
336 private final String defaultMethodName = "run";
337
338 public String getTestMethodName() {
339 return defaultMethodName;
340 }
341
342 /**
343 * The dispatchMethodKernel dispatches a non-lambda method. All the parameters of the compiled
344 * method are supplied as parameters to this call.
345 */
346 public void dispatchMethodKernel(int range, Object... args) {
347 if (testMethod == null) {
348 setTestMethod(getTestMethodName(), this.getClass());
349 }
350 if (dispatchMode == DispatchMode.SEQ) {
351 dispatchMethodKernelSeq(range, args);
352 } else if (dispatchMode == DispatchMode.OKRA) {
353 dispatchMethodKernelOkra(range, args);
354 }
355 }
356
357 /**
358 * This dispatchLambdaMethodKernel dispatches the lambda version of a kernel where the "kernel"
359 * is for the lambda method itself (like lambda$0).
360 */
361 public void dispatchLambdaMethodKernel(int range, MyIntConsumer consumer) {
362 if (testMethod == null) {
363 setTestMethod(findLambdaMethodName(), this.getClass());
364 }
365 if (dispatchMode == DispatchMode.SEQ) {
366 dispatchLambdaKernelSeq(range, consumer);
367 } else if (dispatchMode == DispatchMode.OKRA) {
368 dispatchLambdaMethodKernelOkra(range, consumer);
369 }
370 }
371
372 public void dispatchLambdaMethodKernel(Object[] ary, MyObjConsumer consumer) {
373 if (testMethod == null) {
374 setTestMethod(findLambdaMethodName(), this.getClass());
375 }
376 if (dispatchMode == DispatchMode.SEQ) {
377 dispatchLambdaKernelSeq(ary, consumer);
378 } else if (dispatchMode == DispatchMode.OKRA) {
379 dispatchLambdaMethodKernelOkra(ary, consumer);
380 }
381 }
382
383 /**
384 * The dispatchLambdaKernel dispatches the lambda version of a kernel where the "kernel" is for
385 * the xxx$$Lambda.accept method in the wrapper for the lambda. Note that the useLambdaMethod
386 * boolean provides a way of actually invoking dispatchLambdaMethodKernel from this API.
387 */
388 public void dispatchLambdaKernel(int range, MyIntConsumer consumer) {
389 if (useLambdaMethod) {
390 dispatchLambdaMethodKernel(range, consumer);
391 return;
392 }
393 if (testMethod == null) {
394 setTestMethod("accept", consumer.getClass());
395 }
396 if (dispatchMode == DispatchMode.SEQ) {
397 dispatchLambdaKernelSeq(range, consumer);
398 } else if (dispatchMode == DispatchMode.OKRA) {
399 dispatchLambdaKernelOkra(range, consumer);
400 }
401 }
402
403 public void dispatchLambdaKernel(Object[] ary, MyObjConsumer consumer) {
404 if (useLambdaMethod) {
405 dispatchLambdaMethodKernel(ary, consumer);
406 return;
407 }
408 if (testMethod == null) {
409 setTestMethod("accept", consumer.getClass());
410 }
411 if (dispatchMode == DispatchMode.SEQ) {
412 dispatchLambdaKernelSeq(ary, consumer);
413 } else if (dispatchMode == DispatchMode.OKRA) {
414 dispatchLambdaKernelOkra(ary, consumer);
415 }
416 }
417
418 private ArrayList<String> getLambdaMethodNames() {
419 Class<?> clazz = this.getClass();
420 ArrayList<String> lambdaNames = new ArrayList<>();
421 while (clazz != null && (lambdaNames.size() == 0)) {
422 for (Method m : clazz.getDeclaredMethods()) {
423 logger.fine(" in " + clazz + ", trying to match " + m);
424 if (m.getName().startsWith("lambda$")) {
425 lambdaNames.add(m.getName());
426 }
427 }
428 // Didn't find it in current clazz, try superclass.
429 clazz = clazz.getSuperclass();
430 }
431 return lambdaNames;
432 }
433
434 /**
435 * findLambdaMethodName finds a name in the class starting with lambda$. If we find more than
436 * one, throw an error, and tell user to override explicitly
437 */
438 private String findLambdaMethodName() {
439 // If user overrode getTestMethodName, use that name.
440 if (!getTestMethodName().equals(defaultMethodName)) {
441 return getTestMethodName();
442 } else {
443 ArrayList<String> lambdaNames = getLambdaMethodNames();
444 switch (lambdaNames.size()) {
445 case 1:
446 return lambdaNames.get(0);
447 case 0:
448 fail("No lambda method found in " + this.getClass());
449 return null;
450 default:
451 // More than one lambda.
452 String msg = "Multiple lambda methods found in " + this.getClass() + "\nYou should override getTestMethodName with one of the following\n";
453 for (String name : lambdaNames) {
454 msg = msg + name + "\n";
455 }
456 fail(msg);
457 return null;
458 }
459 }
460 }
461
462 /**
463 * The getCompiledHSAILSource returns the string of HSAIL code for the compiled method. By
464 * default, throws an error. In graal for instance, this would be overridden in
465 * GraalKernelTester.
466 */
467 public String getCompiledHSAILSource(Method testMethod1) {
468 fail("no compiler connected so unable to compile " + testMethod1 + "\nYou could try injecting HSAIL or OpenCL");
469 return null;
470 }
471
472 public String getHSAILSource(Method testMethod1) {
473 switch (hsailMode) {
474 case COMPILED:
475 return getCompiledHSAILSource(testMethod1);
476 case INJECT_HSAIL:
477 return getHsailFromClassnameHsailFile();
478 case INJECT_OCL:
479 return getHsailFromClassnameOclFile();
480 default:
481 fail("unknown hsailMode = " + hsailMode);
482 return null;
483 }
484 }
485
486 /**
487 * The getHSAILKernelName returns the name of the hsail kernel. By default we use 'run'. unless
488 * coming from opencl injection. Could be overridden by the junit test.
489 */
490 public String getHSAILKernelName() {
491 return (hsailMode != HsailMode.INJECT_OCL ? "&run" : "&__OpenCL_run_kernel");
492 }
493
494 private void createOkraKernel() {
495 // Call routines in the derived class to get the hsail code and kernel name.
496 String hsailSource = getHSAILSource(testMethod);
497 if (!okraLibExists) {
498 if (!gaveNoOkraWarning) {
499 logger.fine("No Okra library detected, skipping all KernelTester tests in " + this.getClass().getPackage().getName());
500 gaveNoOkraWarning = true;
501 }
502 }
503 // Ignore any kerneltester test if okra does not exist.
504 assumeTrue(okraLibExists);
505 // Control which okra instances can run the tests.
506 onSimulator = OkraContext.isSimulator();
507 okraContext = new OkraContext();
508 if (!okraContext.isValid()) {
509 fail("...unable to create context");
510 }
511 // Control verbosity in okra from our logLevel.
512 if (logLevel.intValue() <= Level.INFO.intValue()) {
513 okraContext.setVerbose(true);
514 }
515 okraKernel = new OkraKernel(okraContext, hsailSource, getHSAILKernelName());
516 if (!okraKernel.isValid()) {
517 fail("...unable to create kernel");
518 }
519 }
520
521 private void dispatchKernelOkra(int range, Object... args) {
522 if (okraKernel == null) {
523 createOkraKernel();
524 }
525 if (logLevel.intValue() <= Level.FINE.intValue()) {
526 logger.fine("Arguments passed to okra...");
527 for (Object arg : args) {
528 logger.fine(" " + arg);
529 }
530 }
531 okraKernel.setLaunchAttributes(range);
532 okraKernel.dispatchWithArgs(args);
533 }
534
535 private void dispatchMethodKernelSeq(int range, Object... args) {
536 Object[] invokeArgs = new Object[args.length + 1];
537 // Need space on the end for the gid parameter.
538 System.arraycopy(args, 0, invokeArgs, 0, args.length);
539 int gidArgIndex = invokeArgs.length - 1;
540 if (logLevel.intValue() <= Level.FINE.intValue()) {
541 for (Object arg : args) {
542 logger.fine(arg.toString());
543 }
544 }
545 for (int rangeIndex = 0; rangeIndex < range; rangeIndex++) {
546 invokeArgs[gidArgIndex] = rangeIndex;
547 try {
548 testMethod.invoke(this, invokeArgs);
549 } catch (IllegalAccessException e) {
550 fail("could not invoke " + testMethod + ", make sure it is public");
551 } catch (IllegalArgumentException e) {
552 fail("wrong arguments invoking " + testMethod + ", check number and type of args passed to dispatchMethodKernel");
553 } catch (InvocationTargetException e) {
554 Throwable cause = e.getCause();
555 /**
556 * We will ignore ArrayIndexOutOfBoundsException because the graal okra target
557 * doesn't really handle it yet (basically returns early if it sees one).
558 */
559 if (cause instanceof ArrayIndexOutOfBoundsException) {
560 logger.severe("ignoring ArrayIndexOutOfBoundsException for index " + rangeIndex);
561 } else {
562 // Other exceptions.
563 String errstr = testMethod + " threw an exception on gid=" + rangeIndex + ", exception was " + cause;
564 fail(errstr);
565 }
566 } catch (Exception e) {
567 fail("Unknown exception " + e + " invoking " + testMethod);
568 }
569 }
570 }
571
572 private void dispatchMethodKernelOkra(int range, Object... args) {
573 Object[] fixedArgs = fixArgTypes(args);
574 if (Modifier.isStatic(testMethod.getModifiers())) {
575 dispatchKernelOkra(range, fixedArgs);
576 } else {
577 // If it is a non-static method we have to push "this" as the first argument.
578 Object[] newFixedArgs = new Object[fixedArgs.length + 1];
579 System.arraycopy(fixedArgs, 0, newFixedArgs, 1, fixedArgs.length);
580 newFixedArgs[0] = this;
581 dispatchKernelOkra(range, newFixedArgs);
582 }
583 }
584
585 /**
586 * For primitive arg parameters, make sure arg types are cast to whatever the testMethod
587 * signature says they should be.
588 */
589 private Object[] fixArgTypes(Object[] args) {
590 Object[] fixedArgs = new Object[args.length];
591 for (int i = 0; i < args.length; i++) {
592 Class<?> paramClass = testMethodParams[i];
593 if (paramClass.equals(Float.class) || paramClass.equals(float.class)) {
594 fixedArgs[i] = ((Number) args[i]).floatValue();
595 } else if (paramClass.equals(Integer.class) || paramClass.equals(int.class)) {
596 fixedArgs[i] = ((Number) args[i]).intValue();
597 } else if (paramClass.equals(Long.class) || paramClass.equals(long.class)) {
598 fixedArgs[i] = ((Number) args[i]).longValue();
599 } else if (paramClass.equals(Double.class) || paramClass.equals(double.class)) {
600 fixedArgs[i] = ((Number) args[i]).doubleValue();
601 } else if (paramClass.equals(Byte.class) || paramClass.equals(byte.class)) {
602 fixedArgs[i] = ((Number) args[i]).byteValue();
603 } else if (paramClass.equals(Boolean.class) || paramClass.equals(boolean.class)) {
604 fixedArgs[i] = (boolean) args[i];
605 } else {
606 // All others just move unchanged.
607 fixedArgs[i] = args[i];
608 }
609 }
610 return fixedArgs;
611 }
612
613 /**
614 * Dispatching a lambda on the java side is simple.
615 */
616 @SuppressWarnings("static-method")
617 private void dispatchLambdaKernelSeq(int range, MyIntConsumer consumer) {
618 for (int i = 0; i < range; i++) {
619 consumer.accept(i);
620 }
621 }
622
623 @SuppressWarnings("static-method")
624 private void dispatchLambdaKernelSeq(Object[] ary, MyObjConsumer consumer) {
625 for (Object obj : ary) {
626 consumer.accept(obj);
627 }
628 }
629
630 /**
631 * The dispatchLambdaMethodKernelOkra dispatches in the case where the hsail kernel implements
632 * the lambda method itself as opposed to the wrapper that calls the lambda method. From the
633 * consumer object, we need to find the fields and pass them to the kernel.
634 */
635 private void dispatchLambdaMethodKernelOkra(int range, MyIntConsumer consumer) {
636 logger.info("To determine parameters to pass to hsail kernel, we will examine " + consumer.getClass());
637 Field[] fields = consumer.getClass().getDeclaredFields();
638 Object[] args = new Object[fields.length];
639 int argIndex = 0;
640 for (Field f : fields) {
641 logger.info("... " + f);
642 args[argIndex++] = getFieldFromObject(f, consumer);
643 }
644 dispatchKernelOkra(range, args);
645 }
646
647 private void dispatchLambdaMethodKernelOkra(Object[] ary, MyObjConsumer consumer) {
648 logger.info("To determine parameters to pass to hsail kernel, we will examine " + consumer.getClass());
649 Field[] fields = consumer.getClass().getDeclaredFields();
650 Object[] args = new Object[fields.length];
651 int argIndex = 0;
652 for (Field f : fields) {
653 logger.info("... " + f);
654 args[argIndex++] = getFieldFromObject(f, consumer);
655 }
656 dispatchKernelOkra(ary.length, args);
657 }
658
659 /**
660 * The dispatchLambdaKernelOkra dispatches in the case where the hsail kernel where the hsail
661 * kernel implements the accept method of the wrapper that calls the lambda method as opposed to
662 * the actual lambda method itself.
663 */
664 private void dispatchLambdaKernelOkra(int range, MyIntConsumer consumer) {
665 // The "wrapper" method always has only one arg consisting of the consumer.
666 Object[] args = new Object[1];
667 args[0] = consumer;
668 dispatchKernelOkra(range, args);
669 }
670
671 private void dispatchLambdaKernelOkra(Object[] ary, MyObjConsumer consumer) {
672 // The "wrapper" method always has only one arg consisting of the consumer.
673 Object[] args = new Object[1];
674 args[0] = consumer;
675 dispatchKernelOkra(ary.length, args);
676 }
677
678 private void disposeKernelOkra() {
679 if (okraContext != null) {
680 okraContext.dispose();
681 }
682 }
683
684 private void compareOkraToSeq(HsailMode hsailMode1) {
685 compareOkraToSeq(hsailMode1, false);
686 }
687
688 /**
689 * Runs this instance on OKRA, and as SEQ and compares the output of the two executions.
690 */
691 private void compareOkraToSeq(HsailMode hsailMode1, boolean useLambda) {
692 // Create and run sequential instance.
693 KernelTester testerSeq = newInstance();
694 testerSeq.setDispatchMode(DispatchMode.SEQ);
695 testerSeq.runTest();
696 // Now do this object.
697 this.setHsailMode(hsailMode1);
698 this.setDispatchMode(DispatchMode.OKRA);
699 this.useLambdaMethod = useLambda;
700 this.runTest();
701 this.disposeKernelOkra();
702 assertTrue("failed comparison to SEQ", compareResults(testerSeq));
703 }
704
705 public void testGeneratedHsail() {
706 compareOkraToSeq(HsailMode.COMPILED);
707 }
708
709 public void testGeneratedHsailUsingLambdaMethod() {
710 compareOkraToSeq(HsailMode.COMPILED, true);
711 }
712
713 public void testInjectedHsail() {
714 newInstance().compareOkraToSeq(HsailMode.INJECT_HSAIL);
715 }
716
717 public void testInjectedOpencl() {
718 newInstance().compareOkraToSeq(HsailMode.INJECT_OCL);
719 }
720
721 private static Object getFieldFromObject(Field f, Object fromObj) {
722 try {
723 f.setAccessible(true);
724 Type type = f.getType();
725 logger.info("type = " + type);
726 if (type == double.class) {
727 return f.getDouble(fromObj);
728 } else if (type == float.class) {
729 return f.getFloat(fromObj);
730 } else if (type == long.class) {
731 return f.getLong(fromObj);
732 } else if (type == int.class) {
733 return f.getInt(fromObj);
734 } else if (type == byte.class) {
735 return f.getByte(fromObj);
736 } else if (type == boolean.class) {
737 return f.getBoolean(fromObj);
738 } else {
739 return f.get(fromObj);
740 }
741 } catch (Exception e) {
742 fail("unable to get field " + f + " from " + fromObj);
743 return null;
744 }
745 }
746
747 public static void checkFileExists(String fileName) {
748 assertTrue(fileName + " does not exist", fileExists(fileName));
749 }
750
751 public static boolean fileExists(String fileName) {
752 return new File(fileName).exists();
753 }
754
755 public static String getFileAsString(String sourceFileName) {
756 String source = null;
757 try {
758 checkFileExists(sourceFileName);
759 source = new String(Files.readAllBytes(FileSystems.getDefault().getPath(sourceFileName)));
760 } catch (IOException e) {
761 fail("could not open file " + sourceFileName);
762 return null;
763 }
764 return source;
765 }
766
767 public static String getHsailFromFile(String sourceFileName) {
768 logger.severe("... getting hsail from file " + sourceFileName);
769 return getFileAsString(sourceFileName);
770 }
771
772 private static void executeCmd(String... cmd) {
773 logger.info("spawning" + Arrays.toString(cmd));
774 try {
775 ProcessBuilder pb = new ProcessBuilder(cmd);
776 Process p = pb.start();
777 if (logLevel.intValue() <= Level.INFO.intValue()) {
778 InputStream in = p.getInputStream();
779 BufferedInputStream buf = new BufferedInputStream(in);
780 InputStreamReader inread = new InputStreamReader(buf);
781 BufferedReader bufferedreader = new BufferedReader(inread);
782 String line;
783 while ((line = bufferedreader.readLine()) != null) {
784 logger.info(line);
785 }
786 }
787 p.waitFor();
788 } catch (Exception e) {
789 fail("could not execute <" + Arrays.toString(cmd) + ">");
790 }
791 }
792
793 public static String getHsailFromOpenCLFile(String openclFileName) {
794 String openclHsailFile = "opencl_out.hsail";
795 String tmpTahitiFile = "_temp_0_Tahiti.txt";
796 checkFileExists(openclFileName);
797 logger.severe("...converting " + openclFileName + " to HSAIL...");
798 executeCmd("aoc2", "-m64", "-I./", "-march=hsail", openclFileName);
799 if (fileExists(tmpTahitiFile)) {
800 return getFileAsString(tmpTahitiFile);
801 } else {
802 executeCmd("HSAILasm", "-disassemble", "-o", openclHsailFile, openclFileName.replace(".cl", ".bin"));
803 checkFileExists(openclHsailFile);
804 return getFileAsString(openclHsailFile);
805 }
806 }
807
808 public String getHsailFromClassnameHsailFile() {
809 return (getHsailFromFile(this.getClass().getSimpleName() + ".hsail"));
810 }
811
812 public String getHsailFromClassnameOclFile() {
813 return (getHsailFromOpenCLFile(this.getClass().getSimpleName() + ".cl"));
814 }
815
816 public static void logInfo(String msg) {
817 logger.info(msg);
818 }
819
820 public static void logSevere(String msg) {
821 logger.severe(msg);
822 }
823
824 public static void setLogLevel(Level level) {
825 logLevel = level;
826 logger.setLevel(level);
827 consoleHandler.setLevel(level);
828 }
829 }
--- EOF ---