1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package jdk.internal.nicl;
  24 
  25 import jdk.internal.nicl.abi.ArgumentBinding;
  26 import jdk.internal.nicl.abi.CallingSequence;
  27 import jdk.internal.nicl.abi.Storage;
  28 import jdk.internal.nicl.abi.StorageClass;
  29 import jdk.internal.nicl.abi.SystemABI;
  30 import jdk.internal.nicl.abi.sysv.x64.Constants;
  31 import jdk.internal.nicl.types.*;
  32 
  33 import java.lang.invoke.MethodHandle;
  34 import java.lang.invoke.MethodHandles;
  35 import java.lang.invoke.MethodType;
  36 import java.lang.reflect.Method;
  37 import java.nicl.NativeTypes;
  38 import java.nicl.Scope;
  39 import java.nicl.layout.Function;
  40 import java.nicl.layout.Layout;
  41 import java.nicl.types.*;
  42 import java.nicl.types.Pointer;
  43 import java.util.ArrayList;
  44 import static sun.security.action.GetPropertyAction.privilegedGetProperty;
  45 
  46 public class UpcallHandler {
  47 
  48     private static final boolean DEBUG = Boolean.parseBoolean(
  49         privilegedGetProperty("jdk.internal.nicl.UpcallHandler.DEBUG"));
  50     private static final LayoutType<Long> LONG_LAYOUT_TYPE = NativeTypes.UINT64;
  51 
  52     private static final long MAX_STACK_ARG_BYTES = 64 * 1024; // FIXME: Arbitrary limitation for now...
  53 
  54     private static final Object HANDLERS_LOCK = new Object();
  55     private static final ArrayList<UpcallHandler> ID2HANDLER = new ArrayList<>();
  56 
  57     private final MethodHandle mh;
  58     private final Function ftype;
  59 
  60     private final UpcallStub stub;
  61 
  62     static {
  63         long size = 8;
  64         if (Constants.STACK_SLOT_SIZE != size) {
  65             throw new Error("Invalid size: " + Constants.STACK_SLOT_SIZE);
  66         }
  67         if (Constants.INTEGER_REGISTER_SIZE != size) {
  68             throw new Error("Invalid size: " + Constants.INTEGER_REGISTER_SIZE);
  69         }
  70         if ((Constants.VECTOR_REGISTER_SIZE % size) != 0) {
  71             throw new Error("Invalid size: " + Constants.VECTOR_REGISTER_SIZE);
  72         }
  73     }
  74 
  75     public static UpcallHandler make(Class<?> c, Object o) throws Throwable {
  76         if (!Util.isFunctionalInterface(c)) {
  77             throw new IllegalArgumentException("Class is not a @FunctionalInterface: " + c.getName());
  78         }
  79         if (o == null) {
  80             throw new NullPointerException();
  81         }
  82 
  83         if (!c.isInstance(o)) {
  84             throw new IllegalArgumentException("Object must implement FunctionalInterface class: " + c.getName());
  85         }
  86 
  87         Method ficMethod = Util.findFunctionalInterfaceMethod(c);
  88         Function ftype = Util.functionof(c);
  89 
  90         MethodType mt = MethodHandles.publicLookup().unreflect(ficMethod).type().dropParameterTypes(0, 1);
  91 
  92         MethodHandle mh = MethodHandles.publicLookup().findVirtual(c, "fn", mt);
  93 
  94         return UpcallHandler.make(mh.bindTo(o), ftype);
  95     }
  96 
  97     private static UpcallHandler make(MethodHandle mh, Function ftype) throws Throwable {
  98         synchronized (HANDLERS_LOCK) {
  99             int id = ID2HANDLER.size();
 100             UpcallHandler handler = new UpcallHandler(mh, ftype, id);
 101             ID2HANDLER.add(handler);
 102 
 103             if (DEBUG) {
 104                 System.err.println("Allocated upcall handler with id " + id);
 105             }
 106 
 107             return handler;
 108         }
 109     }
 110 
 111     public static void invoke(int id, long integers, long vectors, long stack, long integerReturn, long vectorReturn) {
 112         UpcallHandler handler;
 113 
 114         if (DEBUG) {
 115             System.err.println("UpcallHandler.invoke(" + id + ", ...) with " + ID2HANDLER.size() + " stubs allocated");
 116         }
 117 
 118         synchronized (HANDLERS_LOCK) {
 119             handler = ID2HANDLER.get(id);
 120         }
 121 
 122         try (Scope scope = Scope.newNativeScope()) {
 123             UpcallContext context = new UpcallContext(scope, integers, vectors, stack, integerReturn, vectorReturn);
 124             handler.invoke(context);
 125         }
 126     }
 127 
 128     private UpcallHandler(MethodHandle mh, Function ftype, int id) throws Throwable {
 129         this.mh = mh;
 130         this.ftype = ftype;
 131         this.stub = new UpcallStub(id);
 132     }
 133 
 134     public Pointer<?> getNativeEntryPoint() {
 135         return stub.getEntryPoint();
 136     }
 137 
 138     static class UpcallContext {
 139 
 140         private final Pointer<Long> integers;
 141         private final Pointer<Long> vectors;
 142         private final Pointer<Long> stack;
 143         private final Pointer<Long> integerReturns;
 144         private final Pointer<Long> vectorReturns;
 145 
 146         UpcallContext(Scope scope, long integers, long vectors, long stack, long integerReturn, long vectorReturn) {
 147             this.integers = new BoundedPointer<>(LONG_LAYOUT_TYPE, new BoundedMemoryRegion(integers, Constants.MAX_INTEGER_ARGUMENT_REGISTERS * Constants.INTEGER_REGISTER_SIZE, scope), 0, BoundedMemoryRegion.MODE_R);
 148             this.vectors = new BoundedPointer<>(LONG_LAYOUT_TYPE, new BoundedMemoryRegion(vectors, Constants.MAX_VECTOR_ARGUMENT_REGISTERS * Constants.VECTOR_REGISTER_SIZE, scope), 0, BoundedMemoryRegion.MODE_R);
 149             this.stack = new BoundedPointer<>(LONG_LAYOUT_TYPE, new BoundedMemoryRegion(stack, MAX_STACK_ARG_BYTES, scope), 0, BoundedMemoryRegion.MODE_R);
 150             this.integerReturns = new BoundedPointer<>(LONG_LAYOUT_TYPE, new BoundedMemoryRegion(integerReturn, Constants.MAX_INTEGER_RETURN_REGISTERS * Constants.INTEGER_REGISTER_SIZE, scope), 0, BoundedMemoryRegion.MODE_W);
 151             this.vectorReturns = new BoundedPointer<>(LONG_LAYOUT_TYPE, new BoundedMemoryRegion(vectorReturn, Constants.MAX_VECTOR_RETURN_REGISTERS * Constants.VECTOR_REGISTER_SIZE, scope), 0, BoundedMemoryRegion.MODE_W);
 152         }
 153 
 154         Pointer<Long> getPtr(Storage storage) {
 155             switch (storage.getStorageClass()) {
 156             case INTEGER_ARGUMENT_REGISTER:
 157                 return integers.offset(storage.getStorageIndex());
 158             case VECTOR_ARGUMENT_REGISTER:
 159                 return vectors.offset(storage.getStorageIndex() * Constants.VECTOR_REGISTER_SIZE / 8);
 160             case STACK_ARGUMENT_SLOT:
 161                 return stack.offset(storage.getStorageIndex());
 162 
 163             case INTEGER_RETURN_REGISTER:
 164                 return integerReturns.offset(storage.getStorageIndex());
 165             case VECTOR_RETURN_REGISTER:
 166                 return vectorReturns.offset(storage.getStorageIndex() * Constants.VECTOR_REGISTER_SIZE / 8);
 167             default:
 168                 throw new Error("Unhandled storage: " + storage);
 169             }
 170         }
 171     }
 172 
 173     private Object boxArgument(Scope scope, UpcallContext context, Struct<?>[] structs, ArgumentBinding binding) throws IllegalAccessException {
 174         Class<?> carrierType = binding.getMember().getCarrierType(mh.type());
 175 
 176         Pointer<Long> src = context.getPtr(binding.getStorage());
 177 
 178         if (DEBUG) {
 179             System.err.println("boxArgument carrier type: " + carrierType);
 180         }
 181 
 182 
 183         if (Util.isCStruct(carrierType)) {
 184             int index = binding.getMember().getArgumentIndex();
 185             Struct<?> r = structs[index];
 186             if (r == null) {
 187                 /*
 188                  * FIXME (STRUCT-LIFECYCLE):
 189                  *
 190                  * Leak memory for now
 191                  */
 192                 scope = Scope.newNativeScope();
 193 
 194                 @SuppressWarnings({"rawtypes", "unchecked"})
 195                 Struct<?> rtmp = scope.allocateStruct((Class)carrierType);
 196 
 197                 structs[index] = r = rtmp;
 198             }
 199 
 200             if (DEBUG) {
 201                 System.out.println("Populating struct at arg index " + index + " at offset 0x" + Long.toHexString(binding.getOffset()));
 202             }
 203 
 204             if ((binding.getOffset() % LONG_LAYOUT_TYPE.bytesSize()) != 0) {
 205                 throw new Error("Invalid offset: " + binding.getOffset());
 206             }
 207             Pointer<Long> dst = r.ptr().cast(LONG_LAYOUT_TYPE).offset(binding.getOffset() / LONG_LAYOUT_TYPE.bytesSize());
 208 
 209             if (DEBUG) {
 210                 System.err.println("Copying struct data, value: 0x" + Long.toHexString(src.get()));
 211             }
 212 
 213             Util.copy(src, dst, binding.getStorage().getSize());
 214 
 215             return r;
 216         } else {
 217             return src.cast(Util.makeType(carrierType, src.type().layout())).get();
 218         }
 219     }
 220 
 221     private Object[] boxArguments(Scope scope, UpcallContext context, CallingSequence callingSequence) {
 222         Object[] args = new Object[mh.type().parameterCount()];
 223 
 224         Struct<?>[] structs = new Struct<?>[mh.type().parameterCount()];
 225 
 226         if (DEBUG) {
 227             System.out.println("boxArguments " + callingSequence.asString());
 228         }
 229 
 230         for (StorageClass c : Constants.ARGUMENT_STORAGE_CLASSES) {
 231             int skip = (c == StorageClass.INTEGER_ARGUMENT_REGISTER && callingSequence.returnsInMemory()) ? 1 : 0;
 232             callingSequence
 233                 .getBindings(c)
 234                 .stream()
 235                 .skip(skip)
 236                 .filter(binding -> binding != null)
 237                 .forEach(binding -> {
 238                     try {
 239                         args[binding.getMember().getArgumentIndex()] = boxArgument(scope, context, structs, binding);
 240                     } catch (IllegalAccessException e) {
 241                         throw new IllegalArgumentException("Failed to box argument", e);
 242                     }
 243                 });
 244         }
 245 
 246         return args;
 247     }
 248 
 249     private void unboxReturn(Class<?> c, UpcallContext context, ArgumentBinding binding, Object o) throws IllegalAccessException {
 250         if (DEBUG) {
 251             System.out.println("unboxReturn " + c.getName());
 252             System.out.println(binding.toString());
 253         }
 254 
 255         Pointer<Long> dst = context.getPtr(binding.getStorage());
 256 
 257         if (Util.isCStruct(c)) {
 258             Function ft = Function.of(Util.layoutof(c), false, new Layout[0]);
 259             boolean returnsInMemory = SystemABI.getInstance().arrangeCall(ft).returnsInMemory();
 260 
 261             Struct<?> struct = (Struct<?>) o;
 262 
 263             Pointer<Long> src = struct.ptr().cast(LONG_LAYOUT_TYPE);
 264 
 265             if (returnsInMemory) {
 266                 // the first integer argument register contains a pointer to caller allocated struct
 267                 long structAddr = context.getPtr(new Storage(StorageClass.INTEGER_ARGUMENT_REGISTER, 0, Constants.INTEGER_REGISTER_SIZE)).get();
 268                 long size = Util.alignUp(ftype.returnLayout().get().bitsSize() / 8, 8);
 269                 Pointer<?> dstStructPtr = new BoundedPointer<>(Util.makeType(c, ftype.returnLayout().get()), new BoundedMemoryRegion(structAddr, size));
 270                 try {
 271                     ((BoundedPointer<?>) dstStructPtr).type.setter().invoke(dstStructPtr, o);
 272                 } catch (Throwable ex) {
 273                     throw new IllegalStateException(ex);
 274                 }
 275             } else {
 276                 if ((binding.getOffset() % LONG_LAYOUT_TYPE.bytesSize()) != 0) {
 277                     throw new Error("Invalid offset: " + binding.getOffset());
 278                 }
 279                 Pointer<Long> srcPtr = src.offset(binding.getOffset() / LONG_LAYOUT_TYPE.bytesSize());
 280                 Util.copy(srcPtr, dst, binding.getStorage().getSize());
 281             }
 282         } else {
 283             try {
 284                 dst.cast(Util.makeType(c, ftype.returnLayout().get())).type().setter().invoke(dst, o);
 285             } catch (Throwable ex) {
 286                 throw new IllegalStateException(ex);
 287             }
 288         }
 289     }
 290 
 291     private void invoke(UpcallContext context) {
 292         try (Scope scope = Scope.newNativeScope()) {
 293             // FIXME: Handle varargs upcalls here
 294             CallingSequence callingSequence = SystemABI.getInstance().arrangeCall(ftype);
 295 
 296             if (DEBUG) {
 297                 System.err.println("=== UpcallHandler.invoke ===");
 298                 System.err.println(callingSequence.asString());
 299             }
 300 
 301             Object[] args = boxArguments(scope, context, callingSequence);
 302 
 303             Object o = mh.asSpreader(Object[].class, args.length).invoke(args);
 304 
 305             if (mh.type().returnType() != void.class) {
 306                 for (StorageClass c : Constants.RETURN_STORAGE_CLASSES) {
 307                     for (ArgumentBinding binding : callingSequence.getBindings(c)) {
 308                         unboxReturn(mh.type().returnType(), context, binding, o);
 309                     }
 310                 }
 311             }
 312         } catch (Throwable t) {
 313             throw new RuntimeException(t);
 314         }
 315     }
 316 }