1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 import jdk.experimental.value.ValueType;
  27 
  28 import java.lang.invoke.MethodHandle;
  29 import java.lang.invoke.MethodHandles;
  30 import java.lang.invoke.MethodType;
  31 import java.lang.reflect.Field;
  32 import java.util.Arrays;
  33 import java.util.stream.IntStream;
  34 
  35 import static java.lang.invoke.MethodType.methodType;
  36 
  37 public class VectorUtils {
  38     private static final Class<?> THIS_KLASS = VectorUtils.class;
  39     private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
  40 
  41     /**
  42      * liftedOp(VT v1,v2,...,va) =
  43      *   let opFi(v,w) = op(v1.fi,v2.fi,...,va.fi),
  44      *       factory(f1,...,fn) = zeroVT().withF1(f1)...withFn(fn) in
  45      *   factory(opF1(v1,v2,...,va),...,opFn(v1,v2,...,va))
  46      */
  47     static MethodHandle lift(ValueType<?> vt, MethodHandle op, MethodHandle factory) {
  48         Class<?> elementType = op.type().returnType();
  49         int arity = op.type().parameterCount();
  50         int count = 0;
  51         MethodHandle liftedOp = factory;
  52         for (Field f : vt.valueFields()) {
  53             MethodHandle getter = Utils.compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType));
  54             MethodHandle[] getters = new MethodHandle[arity];
  55             Arrays.fill(getters, getter);
  56             MethodHandle fieldOp = MethodHandles.filterArguments(op, 0, getters);
  57             liftedOp = MethodHandles.collectArguments(liftedOp, arity*count, fieldOp);
  58             count++;
  59         }
  60         Class<?> dvt = vt.valueClass();
  61         MethodType liftedOpType = naryType(arity, dvt); // (dvt,...,dvt)dvt
  62         int[] reorder = IntStream.range(0, arity*count).map(i -> i % arity).toArray(); // [0, 1, ..., arity, ..., 0, 1, ..., arity]
  63         liftedOp = MethodHandles.permuteArguments(liftedOp, liftedOpType, reorder);
  64         return liftedOp;
  65     }
  66 
  67     /**
  68      * T reducer(VT v) = op(...op(op(zero, v.f1), v.f2), ..., v.fn)
  69      */
  70     static MethodHandle reducer(ValueType<?> vt, MethodHandle op, MethodHandle zero) {
  71         Class<?> elementType = op.type().returnType();
  72         MethodHandle reducer = zero;
  73         Field[] valueFields = vt.valueFields();
  74         for (Field f : valueFields) {
  75             MethodHandle getter = Utils.compute(() -> vt.findGetter(LOOKUP, f.getName(), elementType));
  76             MethodHandle reduceOp = MethodHandles.filterArguments(op, 1, getter);
  77             reducer = MethodHandles.collectArguments(reduceOp, 0, reducer); // (x, ...) -> factory(...).with(x));
  78         }
  79         int[] reorder = new int[valueFields.length]; // [0, ..., 0]
  80         reducer = MethodHandles.permuteArguments(reducer, methodType(long.class, vt.valueClass()), reorder);
  81         return reducer;
  82     }
  83 
  84     /**
  85      * T reducerLoop(DVT[] a) {
  86      *   DVT v = zero();
  87      *   for (int i = 0; i < a.length; i++) {
  88      *     v = op(v, a[i]);
  89      *   }
  90      *   return reducer(v);
  91      * }
  92      */
  93     public static MethodHandle reducerLoop(ValueType<?> vt, MethodHandle op, MethodHandle reducer, MethodHandle zero) {
  94         MethodHandle iterations = MethodHandles.arrayLength(vt.arrayValueClass());        // (DVT[]) int
  95         MethodHandle init = MethodHandles.dropArguments(zero, 0, vt.arrayValueClass());   // (DVT[]) DVT
  96 
  97         // DVT body(DVT v, int i, DVT[] arr) { return op(v, arr[i])); }
  98         MethodHandle getElement = MethodHandles.arrayElementGetter(vt.arrayValueClass()); // (DVT[], int) DVT
  99         MethodHandle body = MethodHandles.permuteArguments( // (0:DVT, 2:int, 1:DVT[]) DVT
 100                 MethodHandles.collectArguments(op, 1, getElement), // (DVT, DVT[], int) DVT
 101                 methodType(vt.valueClass(), vt.valueClass(), int.class, vt.arrayValueClass()),
 102                 0, 2, 1);
 103 
 104         // A = [ DVT[] ], V = [ DVT ]
 105         //
 106         // int iterations(A...);
 107         // V init(A...);
 108         // V body(V, int, A...);
 109         //
 110         // V countedLoop(A...) {
 111         //   int end = iterations(A...);
 112         //   V v = init(A...);
 113         //   for (int i = 0; i < end; ++i) {
 114         //     v = body(v, i, A...);
 115         //   }
 116         //   return v;
 117         // }
 118         MethodHandle loop = MethodHandles.countedLoop(iterations, init, body); // (DVT[])DVT
 119 
 120         //MH(DVT[])DVT => MH(DVT[])T
 121         MethodHandle result = MethodHandles.filterReturnValue(loop, reducer);
 122         return result;
 123     }
 124 
 125     static MethodType naryType(int arity, Class<?> c) {
 126         Class<?>[] parameterTypes = new Class<?>[arity];
 127         Arrays.fill(parameterTypes, c);
 128         return methodType(c, parameterTypes);
 129     }
 130 
 131     static MethodHandle valueFactory(Class<?> vcc, MethodHandles.Lookup lookup) {
 132         ValueType<?> vt = ValueType.forClass(vcc);
 133         return Utils.compute(() -> vt.unreflectWithers(lookup, true, vt.valueFields()));
 134     }
 135 }
 136