1 package valhalla.vector;
   2 
   3 import jdk.experimental.value.ValueType;
   4 import jdk.internal.org.objectweb.asm.ClassWriter;
   5 
   6 import java.lang.invoke.MethodHandle;
   7 import java.lang.invoke.MethodHandles;
   8 
   9 import static java.lang.invoke.MethodType.methodType;
  10 import static valhalla.vector.Utils.assertEquals;
  11 
  12 public class VectorTest {
  13     static final ValueType<?> VT = ValueType.forClass(Long2.class);
  14 
  15     static /* QLong2[] */ Object initL2Array(int length) {
  16         try {
  17             Object arr = MethodHandles.arrayConstructor(VT.arrayValueClass()).invoke(length);
  18             for (int i = 0; i < length; i++) {
  19                 Long2 v = new Long2(2 * i, 2 * i + 1);
  20                 MethodHandles.arrayElementSetter(VT.arrayValueClass()).invoke(arr, i, v);
  21             }
  22             return arr;
  23         } catch (Throwable e) {
  24             throw new Error(e);
  25         }
  26     }
  27 
  28     static final MethodHandle SUM_ARRAY_L2 =
  29             VectorUtils.reducerLoop(VT, VectorLibrary.L2.ADD_L, VectorLibrary.L2.HADD_L, VT.defaultValueConstant()).
  30                     asType(methodType(long.class, Object.class));
  31 
  32     /**
  33      * long sum(QLong2[] a) {
  34      *   QLong2 v = QLong2.default; // (0,0)
  35      *   for (int i = 0; i < a.length; i++) {
  36      *     v = QLong2(v.lo + a[i].lo, v.hi + a[i].hi);
  37      *   }
  38      *   return v.lo + v.hi;
  39      * }
  40      */
  41     // @DontInline
  42     static long sumArrayL2(Object arr) {
  43         try {
  44             return (long) SUM_ARRAY_L2.invokeExact(arr);
  45         } catch (Throwable e) {
  46             throw new Error(e);
  47         }
  48     }
  49 
  50     static void testSumArray(int size) {
  51         Object arr = initL2Array(size); // QLong2[size]
  52         long expected = size * (2*size - 1);
  53         for (int i = 0; i < 20_000; i++) {
  54             long sum = sumArrayL2(arr);
  55             assertEquals(expected, sum);
  56         }
  57     }
  58 
  59     /*========================================================*/
  60 
  61     static MethodHandle createConditional() {
  62         // T target(A...,B...);
  63         // T fallback(A...,B...);
  64         // T adapter(A... a,B... b) {
  65         //   if (test(a...))
  66         //     return target(a..., b...);
  67         //   else
  68         //     return fallback(a..., b...);
  69         // }
  70         MethodHandle test = MethodHandles.identity(boolean.class);
  71 
  72         MethodHandle add = VectorLibrary.L2.ADD_L;
  73         MethodHandle addVL = MethodHandles.filterArguments(add, 1, VT.unbox());
  74         MethodHandle inc = MethodHandles.insertArguments(addVL, 1, new Long2(1L, 1L));
  75         MethodHandle incZ = MethodHandles.dropArguments(inc, 0, boolean.class);
  76 
  77         MethodHandle idZ = MethodHandles.dropArguments(VT.identity(), 0, boolean.class);
  78 
  79         // (boolean, QLong2)QLong2
  80         MethodHandle gwt = MethodHandles.guardWithTest(test, incZ, idZ);
  81         return gwt;
  82     }
  83 
  84     static final MethodHandle conditionalMH = createConditional().
  85             asType(methodType(Long2.class, boolean.class, Long2.class));
  86 
  87     // @DontInline
  88     static Long2 conditional(boolean b, Long2 v) {
  89         try {
  90             return (Long2) conditionalMH.invokeExact(b, v);
  91         } catch (Throwable e) {
  92             throw new Error(e);
  93         }
  94     }
  95 
  96     static void testConditional() {
  97         Long2 v = new Long2(1L, 2L);
  98         for (int i = 0; i < 20_000; i++) {
  99             conditional(true,  v);
 100             conditional(false, v);
 101         }
 102     }
 103 
 104     /*========================================================*/
 105 
 106     public static void main(String[] args) {
 107         if (args.length == 0) {
 108             args = new String[] { "1", "5", "10", "0"};
 109         }
 110         for (String arg : args) {
 111             int size = Integer.parseInt(arg);
 112             testSumArray(size);
 113         }
 114         testConditional();
 115     }
 116 }