1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @modules jdk.jextract
  27  * @build TestUpcall
  28  *
  29  * @run testng/othervm -Djdk.internal.foreign.UpcallHandler.FASTPATH=none TestUpcall
  30  * @run testng/othervm TestUpcall
  31  */
  32 
  33 import org.testng.annotations.*;
  34 
  35 import java.foreign.Libraries;
  36 import java.foreign.NativeTypes;
  37 import java.foreign.Scope;
  38 import java.foreign.layout.Group;
  39 import java.foreign.layout.Layout;
  40 import java.foreign.layout.Padding;
  41 import java.foreign.memory.Pointer;
  42 import java.foreign.memory.Struct;
  43 import java.lang.invoke.MethodHandles;
  44 import java.lang.reflect.Method;
  45 import java.lang.reflect.ParameterizedType;
  46 import java.lang.reflect.Proxy;
  47 import java.nio.file.Path;
  48 import java.util.ArrayList;
  49 import java.util.List;
  50 import java.util.function.Consumer;
  51 import java.util.stream.Stream;
  52 
  53 import static org.testng.Assert.*;
  54 
  55 public class TestUpcall extends JextractToolRunner {
  56 
  57     final static int MAX_CODE = 20;
  58 
  59     public static class UpcallTest {
  60 
  61         private final Class<?> headerCls;
  62         private final Object lib;
  63     
  64         public UpcallTest(Class<?> headerCls, Object lib) {
  65             this.headerCls = headerCls;
  66             this.lib = lib;
  67         }
  68     
  69         @Test(dataProvider = "getArgs")
  70         public void testUpCall(String mName, @NoInjection Method m)  throws ReflectiveOperationException {
  71             System.err.print("Calling " + mName + "...");
  72             try(Scope scope = Scope.newNativeScope()) {
  73                 List<Consumer<Object>> checks = new ArrayList<>();
  74                 Object res = m.invoke(lib, makeArgs(scope, m, checks));
  75                 if (m.getReturnType() != void.class) {
  76                     checks.forEach(c -> c.accept(res));
  77                 }
  78             }
  79             System.err.println("...done");
  80         }
  81     
  82         @DataProvider
  83         public Object[][] getArgs() {
  84             return Stream.of(headerCls.getDeclaredMethods())
  85                     .map(m -> new Object[]{ m.getName(), m })
  86                     .toArray(Object[][]::new);
  87         }
  88     
  89     }
  90 
  91     @Factory
  92     public Object[] getTests() throws ReflectiveOperationException {
  93         List<UpcallTest> res = new ArrayList<>();
  94         for (int i = 0 ; i < MAX_CODE ; i++) {
  95             Path clzPath = getOutputFilePath("libTestUpcall.jar");
  96             run("-o", clzPath.toString(),
  97                     "--exclude-symbols", filterFor(i),
  98                     getInputFilePath("libTestUpcall.h").toString()).checkSuccess();
  99             Class<?> headerCls = loadClass("libTestUpcall", clzPath);
 100             Object lib = Libraries.bind(headerCls, Libraries.loadLibrary(MethodHandles.lookup(), "TestUpcall"));
 101             res.add(new UpcallTest(headerCls, lib));
 102         }
 103         if(res.isEmpty())
 104             throw new RuntimeException("Could not generate any tests");
 105         return res.toArray();
 106     }
 107 
 108     static Object[] makeArgs(Scope sc, Method m, List<Consumer<Object>> checks) throws ReflectiveOperationException {
 109         Class<?>[] params = m.getParameterTypes();
 110         Object[] args = new Object[params.length];
 111         for (int i = 0 ; i < params.length - 1 ; i++) {
 112             args[i] = makeArg(sc, params[i], checks, i == 0);
 113         }
 114         args[params.length - 1] = makeCallback(sc, m);
 115         return args;
 116     }
 117 
 118     @SuppressWarnings("unchecked")
 119     static Object makeArg(Scope sc, Class<?> carrier, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
 120         if (Struct.class.isAssignableFrom(carrier)) {
 121             Struct<?> str = sc.allocateStruct((Class)carrier);
 122             initStruct(sc, str, checks, check);
 123             return str;
 124         } else if (carrier == int.class) {
 125             if (check) {
 126                 checks.add(o -> assertEquals(o, 42));
 127             }
 128             return 42;
 129         } else if (carrier == float.class) {
 130             if (check) {
 131                 checks.add(o -> assertEquals(o, 12f));
 132             }
 133             return 12f;
 134         } else if (carrier == double.class) {
 135             if (check) {
 136                 checks.add(o -> assertEquals(o, 24d));
 137             }
 138             return 24d;
 139         } else if (carrier == Pointer.class) {
 140             Pointer<?> p = sc.allocate(NativeTypes.INT32);
 141             if (check) {
 142                 checks.add(o -> {
 143                     try {
 144                         assertEquals(((Pointer<?>)o).addr(), p.addr());
 145                     } catch (Throwable ex) {
 146                         throw new IllegalStateException(ex);
 147                     }
 148                 });
 149             }
 150             return p;
 151         } else {
 152             throw new IllegalStateException("Unexpected carrier: " + carrier);
 153         }
 154     }
 155 
 156     static void initStruct(Scope sc, Struct<?> str, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
 157         Group g = (Group)str.ptr().type().layout();
 158         for (Layout l : g.elements()) {
 159             if (l instanceof Padding) continue;
 160             Method getter = str.getClass().getDeclaredMethod(l.annotations().get("get"));
 161             Class<?> carrier = getter.getReturnType();
 162             Method setter = str.getClass().getDeclaredMethod(l.annotations().get("set"), carrier);
 163             List<Consumer<Object>> fieldsCheck = new ArrayList<>();
 164             Object value = makeArg(sc, carrier, fieldsCheck, check);
 165             //set value
 166             setter.invoke(str, value);
 167             //add check
 168             if (check) {
 169                 assertTrue(fieldsCheck.size() == 1);
 170                 checks.add(o -> {
 171                     try {
 172                         fieldsCheck.get(0).accept(getter.invoke(o));
 173                     } catch (Throwable ex) {
 174                         throw new IllegalStateException(ex);
 175                     }
 176                 });
 177             }
 178         }
 179     }
 180 
 181     @SuppressWarnings("unchecked")
 182     static Object makeCallback(Scope sc, Method m) {
 183         ParameterizedType callbackParam = ((ParameterizedType)m.getGenericParameterTypes()[m.getParameterCount() - 1]);
 184         Class<?> callbackType = (Class<?>)callbackParam.getActualTypeArguments()[0];
 185         Object cb = sc.allocateCallback((Class)callbackType, allocateCallbackInstance(callbackType));
 186         return cb;
 187         //throw new UnsupportedOperationException("Hello " + callbackType);
 188     }
 189 
 190     static Object allocateCallbackInstance(Class<?> carrier) {
 191         return Proxy.newProxyInstance(carrier.getClassLoader(), new Class<?>[] { carrier },
 192                 (proxy, method, args) -> args.length > 0 ? args[0] : null);
 193     }
 194 
 195     static String filterFor(int k) {
 196         List<String> patterns = new ArrayList<>();
 197         for (int i = 0 ; i < MAX_CODE ; i++) {
 198             if (i != k) {
 199                 patterns.add("f" + i + "_");
 200             }
 201         }
 202         return String.format("(%s).*", String.join("|", patterns));
 203     }
 204 }