1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package jdk.internal.foreign;
  24 
import java.foreign.NativeMethodType;
import java.foreign.NativeTypes;
import java.foreign.Scope;
import java.foreign.annotations.NativeCallback;
import java.foreign.annotations.NativeGetter;
import java.foreign.annotations.NativeHeader;
import java.foreign.annotations.NativeStruct;
import java.foreign.layout.Address;
import java.foreign.layout.Function;
import java.foreign.layout.Layout;
import java.foreign.layout.Sequence;
import java.foreign.layout.Value;
import java.foreign.memory.Array;
import java.foreign.memory.Callback;
import java.foreign.memory.LayoutType;
import java.foreign.memory.Pointer;
import java.foreign.memory.Struct;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.lang.reflect.WildcardType;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.LongFunction;
import java.util.stream.Stream;
import jdk.internal.foreign.memory.DescriptorParser;
import jdk.internal.foreign.memory.Types;
import jdk.internal.misc.Unsafe;
  62 
  63 public final class Util {
  64 
  65     private static final Unsafe UNSAFE = Unsafe.getUnsafe();
  66 
  67     public static final long BYTE_BUFFER_BASE;
  68     public static final long BUFFER_ADDRESS;
  69 
  70     static {
  71         try {
  72             BYTE_BUFFER_BASE = UNSAFE.objectFieldOffset(ByteBuffer.class.getDeclaredField("hb"));
  73             BUFFER_ADDRESS = UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address"));
  74         }
  75         catch (Exception e) {
  76             throw new InternalError(e);
  77         }
  78     }
  79 
  80     private Util() {
  81     }
  82 
  83     public static long addUnsignedExact(long a, long b) {
  84         long result = a + b;
  85         if(Long.compareUnsigned(result, a) < 0) {
  86             throw new ArithmeticException(
  87                 "Unsigned overflow: "
  88                     + Long.toUnsignedString(a) + " + "
  89                     + Long.toUnsignedString(b));
  90         }
  91 
  92         return result;
  93     }
  94 
  95     public static Object getBufferBase(ByteBuffer bb) {
  96         return UNSAFE.getReference(bb, BYTE_BUFFER_BASE);
  97     }
  98 
  99     public static long getBufferAddress(ByteBuffer bb) {
 100         return UNSAFE.getLong(bb, BUFFER_ADDRESS);
 101     }
 102 
 103     public static long alignUp(long n, long alignment) {
 104         return (n + alignment - 1) & ~(alignment - 1);
 105     }
 106 
 107     public static boolean isCStruct(Class<?> clz) {
 108         return Struct.class.isAssignableFrom(clz) &&
 109                 clz.isAnnotationPresent(NativeStruct.class);
 110     }
 111 
 112     public static boolean isCLibrary(Class<?> clz) {
 113         return clz.isAnnotationPresent(NativeHeader.class);
 114     }
 115 
 116     public static Class<?>[] resolutionContextFor(Class<?> clz) {
 117         if (isCallback(clz)) {
 118             return clz.getAnnotation(NativeCallback.class).resolutionContext();
 119         } else if (isCStruct(clz)) {
 120             return clz.getAnnotation(NativeStruct.class).resolutionContext();
 121         } else if (isCLibrary(clz)) {
 122             return clz.getAnnotation(NativeHeader.class).resolutionContext();
 123         } else {
 124             return null;
 125         }
 126     }
 127 
 128     public static Layout layoutof(Class<?> c) {
 129         String layout;
 130         if (c.isAnnotationPresent(NativeStruct.class)) {
 131             layout = c.getAnnotation(NativeStruct.class).value();
 132         } else {
 133             throw new IllegalArgumentException("@NativeStruct or @NativeType expected: " + c);
 134         }
 135         return new DescriptorParser(layout).parseLayout();
 136     }
 137 
 138     public static Function functionof(Class<?> c) {
 139         if (! c.isAnnotationPresent(NativeCallback.class)) {
 140             throw new IllegalArgumentException("@NativeCallback expected: " + c);
 141         }
 142         NativeCallback nc = c.getAnnotation(NativeCallback.class);
 143         return new DescriptorParser(nc.value()).parseFunction();
 144     }
 145 
 146     static MethodType methodTypeFor(Method method) {
 147         return MethodType.methodType(method.getReturnType(), method.getParameterTypes());
 148     }
 149 
 150     static boolean isCompatible(Method method, Function function) {
 151         // same return kind (void or non-void)
 152         boolean isNonVoidMethod = method.getReturnType() != void.class;
 153         if (isNonVoidMethod != function.returnLayout().isPresent())  {
 154             return false;
 155         }
 156 
 157         //same vararg-ness
 158         if(method.isVarArgs() != function.isVariadic()) {
 159             return false;
 160         }
 161 
 162         //same arity (take Java varargs array into account)
 163         int expectedArity = function.argumentLayouts().size() + (function.isVariadic() ? 1 : 0);
 164         return method.getParameterCount() == expectedArity;
 165     }
 166 
 167     public static void checkCompatible(Method method, Function function) {
 168         if (!isCompatible(method, function)) {
 169             throw new IllegalArgumentException(
 170                 "Java method signature and native layout not compatible: " + method + " : " + function);
 171         }
 172     }
 173 
 174     public static boolean isCallback(Class<?> c) {
 175         return c.isAnnotationPresent(NativeCallback.class);
 176     }
 177 
 178     public static Method findFunctionalInterfaceMethod(Class<?> c) {
 179         Optional<Method> methodOpt = Optional.empty();
 180         if (c.isAnnotationPresent(NativeCallback.class)) {
 181             methodOpt = Stream.of(c.getDeclaredMethods())
 182                 .filter(m -> (m.getModifiers() & (Modifier.ABSTRACT | Modifier.PUBLIC)) != 0)
 183                 .findFirst();
 184         } else {
 185             throw new IllegalArgumentException("Class is not a @NativeCallback: " + c.getName());
 186         }
 187         return methodOpt.orElseThrow(IllegalStateException::new);
 188     }
 189 
 190     public static Class<?> findUniqueCallback(Class<?> cls) {
 191         Set<Class<?>> candidates = new HashSet<>();
 192         findUniqueCallbackInternal(cls, candidates);
 193         return (candidates.size() == 1) ?
 194                 candidates.iterator().next() :
 195                 null;
 196     }
 197 
 198     private static void findUniqueCallbackInternal(Class<?> cls, Set<Class<?>> candidates) {
 199         if (isCallback(cls)) {
 200             candidates.add(cls);
 201         }
 202         Class<?> sup = cls.getSuperclass();
 203         if (sup != null) {
 204             findUniqueCallbackInternal(sup, candidates);
 205         }
 206         for (Class<?> i : cls.getInterfaces()) {
 207             findUniqueCallbackInternal(i, candidates);
 208         }
 209     }
 210 
 211     @SuppressWarnings({"unchecked", "rawtypes"})
 212     public static LayoutType<?> makeType(Type carrier, Layout layout) {
 213         carrier = unboxIfNeeded(carrier);
 214         if (carrier == byte.class) {
 215             return LayoutType.ofByte(layout);
 216         } else if (carrier == void.class) {
 217             return LayoutType.ofVoid(layout);
 218         } else if (carrier == boolean.class) {
 219             return LayoutType.ofBoolean(layout);
 220         } else if (carrier == short.class) {
 221             return LayoutType.ofShort(layout);
 222         } else if (carrier == int.class) {
 223             return LayoutType.ofInt(layout);
 224         } else if (carrier == char.class) {
 225             return LayoutType.ofChar(layout);
 226         } else if (carrier == long.class) {
 227             return LayoutType.ofLong(layout);
 228         } else if (carrier == float.class) {
 229             return LayoutType.ofFloat(layout);
 230         } else if (carrier == double.class) {
 231             return LayoutType.ofDouble(layout);
 232         } else if (carrier == Pointer.class) {
 233             return NativeTypes.VOID.pointer((Address)layout);
 234         } else if (Pointer.class.isAssignableFrom(erasure(carrier))) {
 235             Type targ = extractTypeArgument(carrier);
 236             Address addr = (Address)layout;
 237             if (targ == null || addr.layout().isEmpty()) {
 238                 return NativeTypes.VOID.pointer(addr);
 239             } else {
 240                 return makeType(targ, addr.layout().get()).pointer(addr);
 241             }
 242         } else if (Array.class.isAssignableFrom(erasure(carrier))) {
 243             Type targ = extractTypeArgument(carrier);
 244             Sequence seq = (Sequence)layout;
 245             if (targ == null) {
 246                 return NativeTypes.VOID.array();
 247             } else {
 248                 return makeType(targ, seq.element()).array(seq.elementsSize());
 249             }
 250         } else if (isCStruct(erasure(carrier))) {
 251             return LayoutType.ofStruct((Class) carrier);
 252         } else if (Callback.class.isAssignableFrom(erasure(carrier))) {
 253             Type targ = extractTypeArgument(carrier);
 254             if (targ == null) {
 255                 throw new IllegalStateException("Invalid callback carrier: " + carrier.getTypeName());
 256             }
 257             return LayoutType.ofFunction((Address) layout, erasure(targ));
 258         } else {
 259             throw new IllegalStateException("Unknown carrier: " + carrier.getTypeName());
 260         }
 261     }
 262 
 263     static Type extractTypeArgument(Type t) {
 264         if (t instanceof ParameterizedType) {
 265             Type arg = ((ParameterizedType)t).getActualTypeArguments()[0];
 266             if (arg == Void.class) {
 267                 return null;
 268             } else if (arg instanceof WildcardType) {
 269                 Type[] lo =  ((WildcardType) arg).getLowerBounds();
 270                 Type[] hi =  ((WildcardType) arg).getUpperBounds();
 271                 if (lo.length == 1) {
 272                     // Lower bound is zero-elem array if this is '?' or '? extends' wildcards.
 273                     // Otherwise it's one element array.
 274                     return lo[0];
 275                 } else if (hi.length == 2 && hi[0] == Object.class) {
 276                     // Upper bound is always guaranteed to have at least one element, but
 277                     // the first bound can be j.l.Object if an interface bound is used. In that case,
 278                     // skip Object, and return the interface bound.
 279                     return hi[1];
 280                 } else if (hi.length == 1) {
 281                     // Return either the non-Object class bound, or null.
 282                     return hi[0] == Object.class ?
 283                             null : hi[0];
 284                 } else {
 285                     //unsupported combination of upper/lower bounds.
 286                     throw new IllegalStateException("Unsupported wildcard type-argument: " + arg.getTypeName());
 287                 }
 288             } else {
 289                 return arg;
 290             }
 291         } else {
 292             return null;
 293         }
 294     }
 295 
 296     public static Class<?> erasure(Type type) {
 297         if (type instanceof ParameterizedType) {
 298             return (Class<?>)((ParameterizedType)type).getRawType();
 299         } else if (type instanceof GenericArrayType) {
 300             return java.lang.reflect.Array.newInstance(erasure(((GenericArrayType)type).getGenericComponentType()), 0).getClass();
 301         } else if (type instanceof TypeVariable<?>) {
 302             return erasure(((TypeVariable<?>)type).getBounds()[0]);
 303         } else {
 304             return (Class<?>)type;
 305         }
 306     }
 307 
 308     public static Type unboxIfNeeded(Type clazz) {
 309         if (clazz == Boolean.class) {
 310             return boolean.class;
 311         } else if (clazz == Void.class) {
 312             return void.class;
 313         } else if (clazz == Byte.class) {
 314             return byte.class;
 315         } else if (clazz == Character.class) {
 316             return char.class;
 317         } else if (clazz == Short.class) {
 318             return short.class;
 319         } else if (clazz == Integer.class) {
 320             return int.class;
 321         } else if (clazz == Long.class) {
 322             return long.class;
 323         } else if (clazz == Float.class) {
 324             return float.class;
 325         } else if (clazz == Double.class) {
 326             return double.class;
 327         } else {
 328             return clazz;
 329         }
 330     }
 331 
 332     public static <Z> Pointer<Z> unsafeCast(Pointer<?> ptr, LayoutType<Z> layoutType) {
 333         return ptr.cast(NativeTypes.VOID).cast(layoutType);
 334     }
 335 
 336     public static MethodType checkNoArrays(MethodHandles.Lookup lookup, Class<?> fi) {
 337         try {
 338             return checkNoArrays(lookup.unreflect(findFunctionalInterfaceMethod(fi)).type());
 339         } catch (ReflectiveOperationException ex) {
 340             throw new IllegalStateException(ex);
 341         }
 342     }
 343     public static MethodType checkNoArrays(MethodType mt) {
 344         if (Stream.concat(Stream.of(mt.returnType()), mt.parameterList().stream())
 345                 .anyMatch(c -> Array.class.isAssignableFrom(c))) {
 346             //arrays in functions not supported!
 347             throw new UnsupportedOperationException("Array carriers not supported in functions");
 348         }
 349         return mt;
 350     }
 351 
 352     public static Pointer<?> getSyntheticCallbackAddress(Object o) {
 353         // First field
 354         return (Pointer<?>) UNSAFE.getReference(o, 0L);
 355     }
 356 
 357     public static <Z> Z withOffHeapAddress(Pointer<?> p, LongFunction<Z> longFunction) {
 358         try {
 359             try {
 360                 //address
 361                 return longFunction.apply(p.addr());
 362             } catch (UnsupportedOperationException ex) {
 363                 //heap pointer
 364                 try (Scope sc = Scope.globalScope().fork()) {
 365                     Pointer<?> offheapPtr = sc.allocate(p.type());
 366                     Pointer.copy(p, offheapPtr, p.type().bytesSize());
 367                     Z z = longFunction.apply(offheapPtr.addr());
 368                     Pointer.copy(offheapPtr, p, p.type().bytesSize());
 369                     return z;
 370                 }
 371             }
 372         } catch (Throwable ex) {
 373             throw new IllegalStateException(ex);
 374         }
 375     }
 376 
 377     public static MethodHandle getCallbackMH(Method m) {
 378         try {
 379             //Note: we need the call to setAccessible because the method to unreflect might belong to a module
 380             //that java.base does not read. However, this poses no security threat, given that the reflective method
 381             //we update here is not leaked outside (and is created insider the binder itself).
 382             m.setAccessible(true);
 383             MethodHandle mh = MethodHandles.lookup().unreflect(m);
 384             Util.checkNoArrays(mh.type());
 385             return mh;
 386         } catch (Throwable ex) {
 387             throw new IllegalStateException(ex);
 388         }
 389     }
 390 
 391     public static Function getResolvedFunction(Class<?> nativeCallback, Method m) {
 392         LayoutResolver resolver = LayoutResolver.get(nativeCallback);
 393         return resolver.resolve(Util.functionof(nativeCallback));
 394     }
 395 
 396     public static NativeMethodType nativeMethodType(Function function, Method method) {
 397         checkCompatible(method, function);
 398 
 399         LayoutType<?> ret = function.returnLayout()
 400                 .<LayoutType<?>>map(l -> makeType(method.getGenericReturnType(), l))
 401                 .orElse(NativeTypes.VOID);
 402 
 403         // Use function argument size and ignore last argument from method for vararg function
 404         LayoutType<?>[] args = new LayoutType<?>[function.argumentLayouts().size()];
 405         for (int i = 0; i < args.length; i++) {
 406             args[i] = makeType(method.getGenericParameterTypes()[i], function.argumentLayouts().get(i));
 407         }
 408 
 409         return NativeMethodType.of(function.isVariadic(), ret, args);
 410     }
 411 
 412     @SuppressWarnings("unchecked")
 413     public static Class<?> findStructInterface(Struct<?> struct) {
 414         for (Class<?> intf : struct.getClass().getInterfaces()) {
 415             if (intf.isAnnotationPresent(NativeStruct.class)) {
 416                 return intf;
 417             }
 418         }
 419         throw new IllegalStateException("Can not find struct interface");
 420     }
 421 
 422 
 423     public static Method getterByName(Class<?> cls, String name) {
 424         for (Method m : cls.getDeclaredMethods()) {
 425             NativeGetter ng = m.getAnnotation(NativeGetter.class);
 426             if (ng != null && ng.value().equals(name)) {
 427                 return m;
 428             }
 429         }
 430         return null;
 431     }
 432 
 433     public static Layout requireNoEndianLayout(Layout layout) {
 434         if (layout instanceof Value) {
 435             if (!((Value)layout).isNativeByteOrder()) {
 436                 throw new IllegalArgumentException("Non-platform endianess not allowed in method argument/return value");
 437             }
 438         }
 439         return layout;
 440     }
 441 
 442     public static long unsafeArrayBase(Class<?> arrayClass) {
 443         return UNSAFE.arrayBaseOffset(arrayClass);
 444     }
 445 
 446     public static long unsafeArrayScale(Class<?> arrayClass) {
 447         return UNSAFE.arrayIndexScale(arrayClass);
 448     }
 449 
 450     public static int sizeof(Class<?> carrier) {
 451         if (carrier == byte.class || carrier == boolean.class) {
 452             return 8;
 453         } else if (carrier == char.class || carrier == short.class) {
 454             return 16;
 455         } else if (carrier == int.class || carrier == float.class) {
 456             return 32;
 457         } else if (carrier == long.class || carrier == double.class) {
 458             return 64;
 459         } else {
 460             throw new IllegalStateException("Unexpected carrier: " + carrier.getName());
 461         }
 462     }
 463 }