1 package valhalla.vector; 2 3 import jdk.experimental.value.ValueType; 4 5 import java.lang.invoke.MethodHandle; 6 import java.lang.invoke.MethodHandles; 7 import java.lang.invoke.MethodType; 8 import java.lang.reflect.Field; 9 import java.util.Arrays; 10 import java.util.stream.IntStream; 11 12 import static java.lang.invoke.MethodType.methodType; 13 import static valhalla.vector.Utils.compute; 14 import static valhalla.vector.Utils.valueFields; 15 16 public class VectorUtils { 17 private static final Class<?> THIS_KLASS = VectorUtils.class; 18 private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); 19 20 /** 21 * liftedOp(VT v1,v2,...,va) = 22 * let opFi(v,w) = op(v1.fi,v2.fi,...,va.fi), 23 * factory(f1,...,fn) = zeroVT().withF1(f1)...withFn(fn) in 24 * factory(opF1(v1,v2,...,va),...,opFn(v1,v2,...,va)) 25 */ 26 static MethodHandle lift(ValueType<?> vt, MethodHandle op, MethodHandle factory) { 27 Class<?> elementType = op.type().returnType(); 28 int arity = op.type().parameterCount(); 29 int count = 0; 30 MethodHandle liftedOp = factory; 31 for (Field f : valueFields(vt)) { 32 MethodHandle getter = compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType)); 33 MethodHandle[] getters = new MethodHandle[arity]; 34 Arrays.fill(getters, getter); 35 MethodHandle fieldOp = MethodHandles.filterArguments(op, 0, getters); 36 liftedOp = MethodHandles.collectArguments(liftedOp, arity*count, fieldOp); 37 count++; 38 } 39 Class<?> dvt = vt.valueClass(); 40 MethodType liftedOpType = naryType(arity, dvt); // (dvt,...,dvt)dvt 41 int[] reorder = IntStream.range(0, arity*count).map(i -> i % arity).toArray(); // [0, 1, ..., arity, ..., 0, 1, ..., arity] 42 liftedOp = MethodHandles.permuteArguments(liftedOp, liftedOpType, reorder); 43 return liftedOp; 44 } 45 46 /** 47 * T reducer(VT v) = op(...op(op(zero, v.f1), v.f2), ..., v.fn) 48 */ 49 static MethodHandle reducer(ValueType<?> vt, MethodHandle op, MethodHandle zero) { 50 Class<?> elementType = op.type().returnType(); 51 MethodHandle reducer = zero; 52 Field[] valueFields = valueFields(vt); 53 for (Field f : valueFields) { 54 MethodHandle getter = compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType)); 55 MethodHandle reduceOp = MethodHandles.filterArguments(op, 1, getter); 56 reducer = MethodHandles.collectArguments(reduceOp, 0, reducer); // (x, ...) -> factory(...).with(x)); 57 } 58 int[] reorder = new int[valueFields.length]; // [0, ..., 0] 59 reducer = MethodHandles.permuteArguments(reducer, methodType(long.class, vt.valueClass()), reorder); 60 return reducer; 61 } 62 63 /** 64 * T reducerLoop(DVT[] a) { 65 * DVT v = zero(); 66 * for (int i = 0; i < a.length; i++) { 67 * v = op(v, a[i]); 68 * } 69 * return reducer(v); 70 * } 71 */ 72 public static MethodHandle reducerLoop(ValueType<?> vt, MethodHandle op, MethodHandle reducer, MethodHandle zero) { 73 MethodHandle iterations = MethodHandles.arrayLength(vt.arrayValueClass()); // (DVT[]) int 74 MethodHandle init = MethodHandles.dropArguments(zero, 0, vt.arrayValueClass()); // (DVT[]) DVT 75 76 // DVT body(DVT v, int i, DVT[] arr) { return op(v, arr[i])); } 77 MethodHandle getElement = MethodHandles.arrayElementGetter(vt.arrayValueClass()); // (DVT[], int) DVT 78 MethodHandle body = MethodHandles.permuteArguments( // (0:DVT, 2:int, 1:DVT[]) DVT 79 MethodHandles.collectArguments(op, 1, getElement), // (DVT, DVT[], int) DVT 80 methodType(vt.valueClass(), vt.valueClass(), int.class, vt.arrayValueClass()), 81 0, 2, 1); 82 83 // A = [ DVT[] ], V = [ DVT ] 84 // 85 // int iterations(A...); 86 // V init(A...); 87 // V body(V, int, A...); 88 // 89 // V countedLoop(A...) { 90 // int end = iterations(A...); 91 // V v = init(A...); 92 // for (int i = 0; i < end; ++i) { 93 // v = body(v, i, A...); 94 // } 95 // return v; 96 // } 97 MethodHandle loop = MethodHandles.countedLoop(iterations, init, body); // (DVT[])DVT 98 99 // Debug 100 // loop = MethodHandles.filterReturnValue(loop, 101 // compute(() -> LOOKUP.findStatic(THIS_KLASS, "id", methodType(vt.boxClass(), vt.boxClass())))) 102 // .asType(methodType(vt.valueClass(), vt.valueClass())); 103 104 //MH(DVT[])DVT => MH(DVT[])T 105 MethodHandle result = MethodHandles.filterReturnValue(loop, reducer); 106 return result; 107 } 108 109 static MethodType naryType(int arity, Class<?> c) { 110 Class<?>[] parameterTypes = new Class<?>[arity]; 111 Arrays.fill(parameterTypes, c); 112 return methodType(c, parameterTypes); 113 } 114 115 static MethodHandle valueFactory(Class<?> vcc, MethodHandles.Lookup lookup) { 116 ValueType<?> vt = ValueType.forClass(vcc); 117 try { 118 MethodHandle factory = vt.defaultValueConstant(); // empty 119 for (Field f : valueFields(vt)) { 120 MethodHandle wither = vt.findWither(lookup, f.getName(), f.getType()); 121 factory = MethodHandles.collectArguments(wither, 0, factory); // (x, ...) -> factory(...).with(x)); 122 } 123 return factory; 124 } catch (Throwable e) { 125 throw new Error(e); 126 } 127 } 128 129 static Long2 id(Long2 v) { 130 return v; 131 } 132 }