1 package valhalla.vector;
   2 
   3 import jdk.experimental.value.ValueType;
   4 
   5 import java.lang.invoke.MethodHandle;
   6 import java.lang.invoke.MethodHandles;
   7 import java.lang.invoke.MethodType;
   8 import java.lang.reflect.Field;
   9 import java.util.Arrays;
  10 import java.util.stream.IntStream;
  11 
  12 import static java.lang.invoke.MethodType.methodType;
  13 import static valhalla.vector.Utils.compute;
  14 import static valhalla.vector.Utils.valueFields;
  15 
  16 public class VectorUtils {
  17     private static final Class<?> THIS_KLASS = VectorUtils.class;
  18     private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
  19 
  20     /**
  21      * liftedOp(VT v1,v2,...,va) =
  22      *   let opFi(v,w) = op(v1.fi,v2.fi,...,va.fi),
  23      *       factory(f1,...,fn) = zeroVT().withF1(f1)...withFn(fn) in
  24      *   factory(opF1(v1,v2,...,va),...,opFn(v1,v2,...,va))
  25      */
  26     static MethodHandle lift(ValueType<?> vt, MethodHandle op, MethodHandle factory) {
  27         Class<?> elementType = op.type().returnType();
  28         int arity = op.type().parameterCount();
  29         int count = 0;
  30         MethodHandle liftedOp = factory;
  31         for (Field f : valueFields(vt)) {
  32             MethodHandle getter = compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType));
  33             MethodHandle[] getters = new MethodHandle[arity];
  34             Arrays.fill(getters, getter);
  35             MethodHandle fieldOp = MethodHandles.filterArguments(op, 0, getters);
  36             liftedOp = MethodHandles.collectArguments(liftedOp, arity*count, fieldOp);
  37             count++;
  38         }
  39         Class<?> dvt = vt.valueClass();
  40         MethodType liftedOpType = naryType(arity, dvt); // (dvt,...,dvt)dvt
  41         int[] reorder = IntStream.range(0, arity*count).map(i -> i % arity).toArray(); // [0, 1, ..., arity, ..., 0, 1, ..., arity]
  42         liftedOp = MethodHandles.permuteArguments(liftedOp, liftedOpType, reorder);
  43         return liftedOp;
  44     }
  45 
  46     /**
  47      * T reducer(VT v) = op(...op(op(zero, v.f1), v.f2), ..., v.fn)
  48      */
  49     static MethodHandle reducer(ValueType<?> vt, MethodHandle op, MethodHandle zero) {
  50         Class<?> elementType = op.type().returnType();
  51         MethodHandle reducer = zero;
  52         Field[] valueFields = valueFields(vt);
  53         for (Field f : valueFields) {
  54             MethodHandle getter = compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType));
  55             MethodHandle reduceOp = MethodHandles.filterArguments(op, 1, getter);
  56             reducer = MethodHandles.collectArguments(reduceOp, 0, reducer); // (x, ...) -> factory(...).with(x));
  57         }
  58         int[] reorder = new int[valueFields.length]; // [0, ..., 0]
  59         reducer = MethodHandles.permuteArguments(reducer, methodType(long.class, vt.valueClass()), reorder);
  60         return reducer;
  61     }
  62 
  63     /**
  64      * T reducerLoop(DVT[] a) {
  65      *   DVT v = zero();
  66      *   for (int i = 0; i < a.length; i++) {
  67      *     v = op(v, a[i]);
  68      *   }
  69      *   return reducer(v);
  70      * }
  71      */
  72     public static MethodHandle reducerLoop(ValueType<?> vt, MethodHandle op, MethodHandle reducer, MethodHandle zero) {
  73         MethodHandle iterations = MethodHandles.arrayLength(vt.arrayValueClass());        // (DVT[]) int
  74         MethodHandle init = MethodHandles.dropArguments(zero, 0, vt.arrayValueClass());   // (DVT[]) DVT
  75 
  76         // DVT body(DVT v, int i, DVT[] arr) { return op(v, arr[i])); }
  77         MethodHandle getElement = MethodHandles.arrayElementGetter(vt.arrayValueClass()); // (DVT[], int) DVT
  78         MethodHandle body = MethodHandles.permuteArguments( // (0:DVT, 2:int, 1:DVT[]) DVT
  79                 MethodHandles.collectArguments(op, 1, getElement), // (DVT, DVT[], int) DVT
  80                 methodType(vt.valueClass(), vt.valueClass(), int.class, vt.arrayValueClass()),
  81                 0, 2, 1);
  82 
  83         // A = [ DVT[] ], V = [ DVT ]
  84         //
  85         // int iterations(A...);
  86         // V init(A...);
  87         // V body(V, int, A...);
  88         //
  89         // V countedLoop(A...) {
  90         //   int end = iterations(A...);
  91         //   V v = init(A...);
  92         //   for (int i = 0; i < end; ++i) {
  93         //     v = body(v, i, A...);
  94         //   }
  95         //   return v;
  96         // }
  97         MethodHandle loop = MethodHandles.countedLoop(iterations, init, body); // (DVT[])DVT
  98 
  99         // Debug
 100 //        loop = MethodHandles.filterReturnValue(loop,
 101 //                compute(() -> LOOKUP.findStatic(THIS_KLASS, "id", methodType(vt.boxClass(), vt.boxClass()))))
 102 //                      .asType(methodType(vt.valueClass(), vt.valueClass()));
 103 
 104         //MH(DVT[])DVT => MH(DVT[])T
 105         MethodHandle result = MethodHandles.filterReturnValue(loop, reducer);
 106         return result;
 107     }
 108 
 109     static MethodType naryType(int arity, Class<?> c) {
 110         Class<?>[] parameterTypes = new Class<?>[arity];
 111         Arrays.fill(parameterTypes, c);
 112         return methodType(c, parameterTypes);
 113     }
 114 
 115     static MethodHandle valueFactory(Class<?> vcc, MethodHandles.Lookup lookup) {
 116         ValueType<?> vt = ValueType.forClass(vcc);
 117         try {
 118             MethodHandle factory = vt.defaultValueConstant(); // empty
 119             for (Field f : valueFields(vt)) {
 120                 MethodHandle wither = vt.findWither(lookup, f.getName(), f.getType());
 121                 factory = MethodHandles.collectArguments(wither, 0, factory); // (x, ...) -> factory(...).with(x));
 122             }
 123             return factory;
 124         } catch (Throwable e) {
 125             throw new Error(e);
 126         }
 127     }
 128 
 129     static Long2 id(Long2 v) {
 130         return v;
 131     }
 132 }