1 /*
   2  * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 package jdk.internal.foreign.abi;
  25 
  26 import java.foreign.NativeMethodType;
  27 import java.foreign.memory.Callback;
  28 import java.foreign.memory.LayoutType;
  29 import java.foreign.memory.Pointer;
  30 import java.foreign.memory.Struct;
  31 import java.lang.invoke.MethodHandle;
  32 import java.lang.invoke.MethodHandles;
  33 import java.lang.invoke.MethodType;
  34 import java.util.ArrayList;
  35 import java.util.Collections;
  36 import java.util.List;
  37 import java.util.function.UnaryOperator;
  38 import java.util.stream.Collectors;
  39 import java.util.stream.Stream;
  40 import jdk.internal.foreign.Util;
  41 import jdk.internal.foreign.memory.BoundedPointer;
  42 import jdk.internal.foreign.memory.CallbackImpl;
  43 import jdk.internal.foreign.memory.LayoutTypeImpl;
  44 
/**
 * This class implements the shuffling logic that is required to adapt a method handle modelling a Java method into the
 * corresponding 'direct' native adapter (and vice versa, for upcalls). The shuffling is generally composed of two steps:
 * first we have to adapt incoming Java arguments into native values (or vice versa, in case of upcalls). Once that's done
 * a final permutation step is needed in order to push all the long arguments in front.
 */
  51 public class DirectSignatureShuffler {
  52 
  53     private static final MethodHandle LONG_TO_BOOLEAN;
  54     private static final MethodHandle BOOLEAN_TO_LONG;
  55     private static final MethodHandle LONG_TO_POINTER;
  56     private static final MethodHandle CALLBACK_TO_LONG;
  57     private static final MethodHandle LONG_TO_CALLBACK;
  58     private static final MethodHandle POINTER_TO_LONG;
  59     private static final MethodHandle STRUCT_TO_LONG;
  60     private static final MethodHandle STRUCT_TO_DOUBLE;
  61     private static final MethodHandle LONG_TO_STRUCT;
  62     private static final MethodHandle DOUBLE_TO_STRUCT;
  63 
  64     private static final int RET_POS = -1;
  65 
  66     static {
  67         try {
  68             LONG_TO_BOOLEAN = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "longToBoolean", MethodType.methodType(boolean.class, long.class));
  69             BOOLEAN_TO_LONG = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "booleanToLong", MethodType.methodType(long.class, boolean.class));
  70             LONG_TO_POINTER = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "longToPointer", MethodType.methodType(Pointer.class, LayoutType.class, long.class));
  71             POINTER_TO_LONG = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "pointerToLong", MethodType.methodType(long.class, Pointer.class));
  72             LONG_TO_CALLBACK = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "longToCallback", MethodType.methodType(Callback.class, Class.class, long.class));
  73             CALLBACK_TO_LONG = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "callbackToLong", MethodType.methodType(long.class, Callback.class));
  74             STRUCT_TO_LONG = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "structToLong", MethodType.methodType(long.class, Struct.class));
  75             STRUCT_TO_DOUBLE = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "structToDouble", MethodType.methodType(double.class, Struct.class));
  76             LONG_TO_STRUCT = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "longToStruct", MethodType.methodType(Struct.class, Class.class, long.class));
  77             DOUBLE_TO_STRUCT = MethodHandles.lookup().findStatic(DirectSignatureShuffler.class, "doubleToStruct", MethodType.methodType(Struct.class, Class.class, double.class));
  78         } catch (ReflectiveOperationException e) {
  79             throw new RuntimeException(e);
  80         }
  81     }
  82 
  83     private final ShuffleDirection direction;
  84     private final MethodType javaMethodType;
  85     private MethodType erasedMethodType = MethodType.methodType(void.class);
  86     private List<UnaryOperator<MethodHandle>> adapters = new ArrayList<>();
  87     private List<Integer> longPerms = new ArrayList<>();
  88     private List<Integer> doublePerms = new ArrayList<>();
  89 
  90     private DirectSignatureShuffler(CallingSequence callingSequence, NativeMethodType nmt, ShuffleDirection direction) {
  91         checkCallingSequence(callingSequence);
  92         this.direction = direction;
  93         this.javaMethodType = nmt.methodType();
  94         processType(RET_POS, nmt.returnType(), callingSequence.returnBindings(), direction.flip());
  95         for (int i = 0 ; i < javaMethodType.parameterCount() ; i++) {
  96             processType(i, nmt.parameterType(i), callingSequence.argumentBindings(i), direction);
  97         }
  98     }
  99 
 100     static DirectSignatureShuffler javaToNativeShuffler(CallingSequence callingSequence, NativeMethodType nmt) {
 101         return new DirectSignatureShuffler(callingSequence, nmt, ShuffleDirection.JAVA_TO_NATIVE);
 102     }
 103 
 104     static DirectSignatureShuffler nativeToJavaShuffler(CallingSequence callingSequence, NativeMethodType nmt) {
 105         return new DirectSignatureShuffler(callingSequence, nmt, ShuffleDirection.NATIVE_TO_JAVA);
 106     }
 107 
 108     MethodHandle adapt(MethodHandle mh) {
 109         if (direction == ShuffleDirection.JAVA_TO_NATIVE) {
 110             mh = MethodHandles.permuteArguments(mh, erasedMethodType, forwardPermutations());
 111         }
 112 
 113         for (UnaryOperator<MethodHandle> adapter : adapters) {
 114             mh = adapter.apply(mh);
 115         }
 116 
 117         if (direction == ShuffleDirection.NATIVE_TO_JAVA) {
 118             mh = MethodHandles.permuteArguments(mh, nativeMethodType(), reversePermutations());
 119         }
 120         return mh;
 121     }
 122 
 123     MethodType nativeMethodType() {
 124         MethodType mt = MethodType.methodType(erasedMethodType.returnType());
 125         mt = mt.appendParameterTypes(Collections.nCopies(longPerms.size(), long.class));
 126         return mt.appendParameterTypes(Collections.nCopies(doublePerms.size(), double.class));
 127     }
 128 
 129     MethodType javaMethodType() {
 130         return javaMethodType;
 131     }
 132 
 133     String nativeSigSuffix() {
 134         MethodType mt = nativeMethodType();
 135         return String.format("%s_%s",
 136                 desc(mt.returnType()),
 137                 mt.parameterCount() > 0 ?
 138                         mt.parameterList().stream().map(this::desc).collect(Collectors.joining()) :
 139                         "V");
 140 
 141     }
 142 
 143     private static void checkCallingSequence(CallingSequence callingSequence) {
 144         if (callingSequence.returnsInMemory() ||
 145                 !callingSequence.bindings(StorageClass.STACK_ARGUMENT_SLOT).isEmpty()) {
 146             throw new IllegalArgumentException("Unsupported non-scalarized calling sequence!");
 147         }
 148     }
 149 
 150     private void processType(int sigPos, LayoutType<?> lt, List<ArgumentBinding> bindings, ShuffleDirection direction) {
 151         Class<?> carrier = (Class<?>) Util.unboxIfNeeded(((LayoutTypeImpl<?>)lt).carrier());
 152         if (carrier.isPrimitive()) {
 153             if (carrier == long.class) {
 154                 updateNativeMethodType(sigPos, long.class);
 155             } else if (carrier == double.class) {
 156                 updateNativeMethodType(sigPos, double.class);
 157             } else if (carrier == float.class) {
 158                 updateNativeMethodType(sigPos, double.class);
 159                 adapters.add(direction.doubleAdapter(sigPos, carrier));
 160             } else if (carrier == boolean.class) {
 161                 updateNativeMethodType(sigPos, long.class);
 162                 adapters.add(direction.booleanAdapter(sigPos));
 163             } else if (carrier == void.class) {
 164                 assert sigPos == -1;
 165             } else {
 166                 updateNativeMethodType(sigPos, long.class);
 167                 adapters.add(direction.longAdapter(sigPos, carrier));
 168             }
 169         } else if (carrier == Pointer.class) {
 170             updateNativeMethodType(sigPos, long.class);
 171             adapters.add(direction.pointerAdapter(sigPos, lt));
 172         } else if (carrier == Callback.class) {
 173             updateNativeMethodType(sigPos, long.class);
 174             adapters.add(direction.callbackAdapter(sigPos, lt));
 175         } else if (Util.isCStruct(carrier)) {
 176             if (bindings.size() == 1) {
 177                 ArgumentBinding binding = bindings.get(0);
 178                 switch (binding.storage().getStorageClass()) {
 179                     case INTEGER_ARGUMENT_REGISTER:
 180                     case INTEGER_RETURN_REGISTER:
 181                         updateNativeMethodType(sigPos, long.class);
 182                         adapters.add(direction.longStructAdapter(sigPos, carrier));
 183                         break;
 184                     case VECTOR_ARGUMENT_REGISTER:
 185                     case VECTOR_RETURN_REGISTER:
 186                         updateNativeMethodType(sigPos, double.class);
 187                         adapters.add(direction.doubleStructAdapter(sigPos, carrier));
 188                         break;
 189                     default:
 190                         //non-register bindings should have already been discarded by now
 191                         throw new IllegalStateException("Cannot get here!");
 192                 }
 193             } else {
 194                 throw new IllegalArgumentException("Multi-value struct!");
 195             }
 196         } else {
 197             throw new IllegalArgumentException("Unsupported carrier: " + carrier);
 198         }
 199     }
 200 
 201     private void updateNativeMethodType(int sigPos, Class<?> carrier) {
 202         if (sigPos == -1) {
 203             erasedMethodType = erasedMethodType.changeReturnType(carrier);
 204         } else {
 205             erasedMethodType = erasedMethodType.appendParameterTypes(carrier);
 206             if (carrier == long.class) {
 207                 longPerms.add(sigPos);
 208             } else {
 209                 doublePerms.add(sigPos);
 210             }
 211         }
 212     }
 213 
 214     private int[] forwardPermutations() {
 215         return Stream.concat(longPerms.stream(), doublePerms.stream())
 216                 .mapToInt(x -> x)
 217                 .toArray();
 218     }
 219 
 220     private int[] reversePermutations() {
 221         int[] forward = forwardPermutations();
 222         int[] reverse = new int[forward.length];
 223         for (int i = 0 ; i < forward.length ; i++) {
 224             reverse[i] = lookup(forward, i);
 225         }
 226         return reverse;
 227     }
 228 
 229     private int lookup(int[] arr, int v) {
 230         for (int i = 0 ; i < arr.length ; i++) {
 231             if (arr[i] == v) return i;
 232         }
 233         throw new IllegalStateException();
 234     }
 235 
 236     private String desc(Class<?> clazz) {
 237         if (clazz == long.class) {
 238             return "J";
 239         } else if (clazz == double.class) {
 240             return "D";
 241         } else if (clazz == void.class) {
 242             return "V";
 243         } else {
 244             throw new IllegalStateException("Unexpected class: " + clazz);
 245         }
 246     }
 247 
 248     // adapter helpers
 249 
 250     enum ShuffleDirection {
 251         JAVA_TO_NATIVE(5),
 252         NATIVE_TO_JAVA(4);
 253 
 254         ShuffleDirection(int maxArity) {
 255             this.maxArity = maxArity;
 256         }
 257 
 258         int maxArity;
 259 
 260         ShuffleDirection flip() {
 261             return this == JAVA_TO_NATIVE ? NATIVE_TO_JAVA : JAVA_TO_NATIVE;
 262         }
 263 
 264         MethodHandle filterHandle(MethodHandle mh, int pos, MethodHandle filter) {
 265             return pos == RET_POS ?
 266                     MethodHandles.filterReturnValue(mh, filter) :
 267                     MethodHandles.filterArguments(mh, pos, filter);
 268         }
 269 
 270         UnaryOperator<MethodHandle> longAdapter(int pos, Class<?> carrier) {
 271             return mh -> filterHandle(mh, pos,
 272                     (this == JAVA_TO_NATIVE) ? primitiveToLong(carrier) : longToPrimitive(carrier));
 273         }
 274 
 275         UnaryOperator<MethodHandle> doubleAdapter(int pos, Class<?> carrier) {
 276             return mh -> filterHandle(mh, pos,
 277                     (this == JAVA_TO_NATIVE) ? primitiveToDouble(carrier) : doubleToPrimitive(carrier));
 278         }
 279 
 280         UnaryOperator<MethodHandle> booleanAdapter(int pos) {
 281             return mh -> filterHandle(mh, pos,
 282                     (this == JAVA_TO_NATIVE) ? BOOLEAN_TO_LONG : LONG_TO_BOOLEAN);
 283         }
 284 
 285         UnaryOperator<MethodHandle> pointerAdapter(int pos, LayoutType<?> type) {
 286             return mh -> filterHandle(mh, pos,
 287                     (this == JAVA_TO_NATIVE) ? POINTER_TO_LONG : LONG_TO_POINTER.bindTo(((LayoutTypeImpl<?>)type).pointeeType()));
 288         }
 289 
 290         UnaryOperator<MethodHandle> callbackAdapter(int pos, LayoutType<?> type) {
 291             return mh -> filterHandle(mh, pos,
 292                     (this == JAVA_TO_NATIVE) ? CALLBACK_TO_LONG : LONG_TO_CALLBACK.bindTo(((LayoutTypeImpl<?>)type).getFuncIntf()));
 293         }
 294 
 295         UnaryOperator<MethodHandle> longStructAdapter(int pos, Class<?> carrier) {
 296             return mh -> filterHandle(mh, pos,
 297                     (this == JAVA_TO_NATIVE)  ?
 298                             STRUCT_TO_LONG.asType(MethodType.methodType(long.class, carrier)) :
 299                             LONG_TO_STRUCT.bindTo(carrier).asType(MethodType.methodType(carrier, long.class)));
 300         }
 301 
 302         UnaryOperator<MethodHandle> doubleStructAdapter(int pos, Class<?> carrier) {
 303             return mh -> filterHandle(mh, pos,
 304                     (this == JAVA_TO_NATIVE)  ?
 305                             STRUCT_TO_DOUBLE.asType(MethodType.methodType(double.class, carrier)) :
 306                             DOUBLE_TO_STRUCT.bindTo(carrier).asType(MethodType.methodType(carrier, double.class)));
 307         }
 308     }
 309 
 310     private static boolean longToBoolean(long value) {
 311         return value != 0;
 312     }
 313 
 314     private static long booleanToLong(boolean value) {
 315         return value ? 1 : 0;
 316     }
 317 
 318     private static long pointerToLong(Pointer<?> value) throws IllegalAccessException {
 319         return value.addr();
 320     }
 321 
 322     private static Pointer<?> longToPointer(LayoutType<?> lt, long addr) {
 323         return addr == 0L ?
 324                 BoundedPointer.ofNull() :
 325                 BoundedPointer.createNativeVoidPointer(addr).cast(lt);
 326     }
 327 
 328     private static long callbackToLong(Callback<?> value) throws IllegalAccessException {
 329         return value.entryPoint().addr();
 330     }
 331 
 332     @SuppressWarnings({"unchecked", "rawtypes"})
 333     private static Callback<?> longToCallback(Class<?> funcClass, long addr) {
 334         if (addr == 0) {
 335             return Callback.ofNull();
 336         } else {
 337             return new CallbackImpl(BoundedPointer.createNativeVoidPointer(addr),
 338                 funcClass);
 339         }
 340     }
 341 
 342     private static long structToLong(Struct<?> struct) {
 343         return ((BoundedPointer<?>)struct.ptr()).getBits();
 344     }
 345 
 346     private static double structToDouble(Struct<?> struct) {
 347         return Double.longBitsToDouble(structToLong(struct));
 348     }
 349 
 350     @SuppressWarnings({"rawtypes", "unchecked"})
 351     private static Struct<?> longToStruct(Class carrier, long value) {
 352         Pointer<? extends Struct<?>> ptr = BoundedPointer.fromArray(LayoutType.ofStruct(carrier), new long[] { value });
 353         return ptr.get();
 354     }
 355 
 356     @SuppressWarnings({"rawtypes", "unchecked"})
 357     private static Struct<?> doubleToStruct(Class carrier, double value) {
 358         return longToStruct(carrier, Double.doubleToLongBits(value));
 359     }
 360 
 361     private static MethodHandle primitiveToLong(Class<?> carrier) {
 362         MethodHandle mh = MethodHandles.identity(long.class);
 363         return MethodHandles.explicitCastArguments(mh, MethodType.methodType(long.class, carrier));
 364     }
 365 
 366     private static MethodHandle longToPrimitive(Class<?> carrier) {
 367         MethodHandle mh = MethodHandles.identity(long.class);
 368         return MethodHandles.explicitCastArguments(mh, MethodType.methodType(carrier, long.class));
 369     }
 370 
 371     private static MethodHandle primitiveToDouble(Class<?> carrier) {
 372         MethodHandle mh = MethodHandles.identity(double.class);
 373         return MethodHandles.explicitCastArguments(mh, MethodType.methodType(double.class, carrier));
 374     }
 375 
 376     private static MethodHandle doubleToPrimitive(Class<?> carrier) {
 377         MethodHandle mh = MethodHandles.identity(double.class);
 378         return MethodHandles.explicitCastArguments(mh, MethodType.methodType(carrier, double.class));
 379     }
 380 
 381     // predicate: is fast path applicable?
 382 
 383     public static boolean acceptDowncall(NativeMethodType nmt, CallingSequence callingSequence) {
 384         return accept(nmt, callingSequence, ShuffleDirection.JAVA_TO_NATIVE);
 385     }
 386 
 387     public static boolean acceptUpcall(NativeMethodType nmt, CallingSequence callingSequence) {
 388         return accept(nmt, callingSequence, ShuffleDirection.NATIVE_TO_JAVA);
 389     }
 390 
 391     private static boolean accept(NativeMethodType nmt, CallingSequence callingSequence, ShuffleDirection direction) {
 392         if (nmt.isVarArgs() ||
 393                 callingSequence.returnsInMemory() ||
 394                 nmt.parameterCount() > direction.maxArity) return false;
 395 
 396         for (int i = 0 ; i < nmt.parameterCount(); i++) {
 397             List<ArgumentBinding> argumentBindings = callingSequence.argumentBindings(i);
 398             if (argumentBindings.size() != 1 ||
 399                     !isDirectBinding(argumentBindings.get(0))) {
 400                 return false;
 401             }
 402         }
 403 
 404         List<ArgumentBinding> returnBindings = callingSequence.returnBindings();
 405         return returnBindings.isEmpty() ||
 406                 (returnBindings.size() == 1 && isDirectBinding(returnBindings.get(0)));
 407     }
 408 
 409     private static boolean isDirectBinding(ArgumentBinding binding) {
 410         switch (binding.storage().getStorageClass()) {
 411             case X87_RETURN_REGISTER:
 412             case STACK_ARGUMENT_SLOT:
 413                 //arguments passed in memory not supported
 414                 return false;
 415             case VECTOR_ARGUMENT_REGISTER:
 416             case VECTOR_RETURN_REGISTER:
 417                 //avoid passing around floats as doubles as that leads to trouble
 418                 return (binding.argument().layout().bitsSize() / 8) == binding.storage().getSize();
 419             default:
 420                 return true;
 421         }
 422     }
 423 }