/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
22 */ 23 24 /* 25 * @test 26 * @modules jdk.jextract 27 * @build TestUpcall 28 * 29 * @run testng/othervm -Djdk.internal.foreign.UpcallHandler.FASTPATH=none TestUpcall 30 * @run testng/othervm TestUpcall 31 */ 32 33 import org.testng.annotations.*; 34 35 import java.foreign.Libraries; 36 import java.foreign.NativeTypes; 37 import java.foreign.Scope; 38 import java.foreign.layout.Group; 39 import java.foreign.layout.Layout; 40 import java.foreign.layout.Padding; 41 import java.foreign.memory.Pointer; 42 import java.foreign.memory.Struct; 43 import java.lang.invoke.MethodHandles; 44 import java.lang.reflect.Method; 45 import java.lang.reflect.ParameterizedType; 46 import java.lang.reflect.Proxy; 47 import java.nio.file.Path; 48 import java.util.ArrayList; 49 import java.util.List; 50 import java.util.function.Consumer; 51 import java.util.stream.Stream; 52 53 import static org.testng.Assert.*; 54 55 public class TestUpcall extends JextractToolRunner { 56 57 final static int MAX_CODE = 20; 58 59 public static class UpcallTest { 60 61 private final Class<?> headerCls; 62 private final Object lib; 63 64 public UpcallTest(Class<?> headerCls, Object lib) { 65 this.headerCls = headerCls; 66 this.lib = lib; 67 } 68 69 @Test(dataProvider = "getArgs") 70 public void testUpCall(String mName, @NoInjection Method m) throws ReflectiveOperationException { 71 System.err.print("Calling " + mName + "..."); 72 try(Scope scope = Scope.newNativeScope()) { 73 List<Consumer<Object>> checks = new ArrayList<>(); 74 Object res = m.invoke(lib, makeArgs(scope, m, checks)); 75 if (m.getReturnType() != void.class) { 76 checks.forEach(c -> c.accept(res)); 77 } 78 } 79 System.err.println("...done"); 80 } 81 82 @DataProvider 83 public Object[][] getArgs() { 84 return Stream.of(headerCls.getDeclaredMethods()) 85 .map(m -> new Object[]{ m.getName(), m }) 86 .toArray(Object[][]::new); 87 } 88 89 } 90 91 @Factory 92 public Object[] getTests() throws ReflectiveOperationException { 93 List<UpcallTest> 
res = new ArrayList<>(); 94 for (int i = 0 ; i < MAX_CODE ; i++) { 95 Path clzPath = getOutputFilePath("libTestUpcall.jar"); 96 checkSuccess(null,"-o", clzPath.toString(), 97 "--exclude-symbols", filterFor(i), 98 getInputFilePath("libTestUpcall.h").toString()); 99 Class<?> headerCls = loadClass("libTestUpcall", clzPath); 100 Object lib = Libraries.bind(headerCls, Libraries.loadLibrary(MethodHandles.lookup(), "TestUpcall")); 101 res.add(new UpcallTest(headerCls, lib)); 102 } 103 if(res.isEmpty()) 104 throw new RuntimeException("Could not generate any tests"); 105 return res.toArray(); 106 } 107 108 static Object[] makeArgs(Scope sc, Method m, List<Consumer<Object>> checks) throws ReflectiveOperationException { 109 Class<?>[] params = m.getParameterTypes(); 110 Object[] args = new Object[params.length]; 111 for (int i = 0 ; i < params.length - 1 ; i++) { 112 args[i] = makeArg(sc, params[i], checks, i == 0); 113 } 114 args[params.length - 1] = makeCallback(sc, m); 115 return args; 116 } 117 118 @SuppressWarnings("unchecked") 119 static Object makeArg(Scope sc, Class<?> carrier, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException { 120 if (Struct.class.isAssignableFrom(carrier)) { 121 Struct<?> str = sc.allocateStruct((Class)carrier); 122 initStruct(sc, str, checks, check); 123 return str; 124 } else if (carrier == int.class) { 125 if (check) { 126 checks.add(o -> assertEquals(o, 42)); 127 } 128 return 42; 129 } else if (carrier == float.class) { 130 if (check) { 131 checks.add(o -> assertEquals(o, 12f)); 132 } 133 return 12f; 134 } else if (carrier == double.class) { 135 if (check) { 136 checks.add(o -> assertEquals(o, 24d)); 137 } 138 return 24d; 139 } else if (carrier == Pointer.class) { 140 Pointer<?> p = sc.allocate(NativeTypes.INT32); 141 if (check) { 142 checks.add(o -> { 143 try { 144 assertEquals(((Pointer<?>)o).addr(), p.addr()); 145 } catch (Throwable ex) { 146 throw new IllegalStateException(ex); 147 } 148 }); 149 } 150 return 
p; 151 } else { 152 throw new IllegalStateException("Unexpected carrier: " + carrier); 153 } 154 } 155 156 static void initStruct(Scope sc, Struct<?> str, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException { 157 Group g = (Group)str.ptr().type().layout(); 158 for (Layout l : g.elements()) { 159 if (l instanceof Padding) continue; 160 Method getter = str.getClass().getDeclaredMethod(l.annotations().get("get")); 161 Class<?> carrier = getter.getReturnType(); 162 Method setter = str.getClass().getDeclaredMethod(l.annotations().get("set"), carrier); 163 List<Consumer<Object>> fieldsCheck = new ArrayList<>(); 164 Object value = makeArg(sc, carrier, fieldsCheck, check); 165 //set value 166 setter.invoke(str, value); 167 //add check 168 if (check) { 169 assertTrue(fieldsCheck.size() == 1); 170 checks.add(o -> { 171 try { 172 fieldsCheck.get(0).accept(getter.invoke(o)); 173 } catch (Throwable ex) { 174 throw new IllegalStateException(ex); 175 } 176 }); 177 } 178 } 179 } 180 181 @SuppressWarnings("unchecked") 182 static Object makeCallback(Scope sc, Method m) { 183 ParameterizedType callbackParam = ((ParameterizedType)m.getGenericParameterTypes()[m.getParameterCount() - 1]); 184 Class<?> callbackType = (Class<?>)callbackParam.getActualTypeArguments()[0]; 185 Object cb = sc.allocateCallback((Class)callbackType, allocateCallbackInstance(callbackType)); 186 return cb; 187 //throw new UnsupportedOperationException("Hello " + callbackType); 188 } 189 190 static Object allocateCallbackInstance(Class<?> carrier) { 191 return Proxy.newProxyInstance(carrier.getClassLoader(), new Class<?>[] { carrier }, 192 (proxy, method, args) -> args.length > 0 ? args[0] : null); 193 } 194 195 static String filterFor(int k) { 196 List<String> patterns = new ArrayList<>(); 197 for (int i = 0 ; i < MAX_CODE ; i++) { 198 if (i != k) { 199 patterns.add("f" + i + "_"); 200 } 201 } 202 return String.format("(%s).*", String.join("|", patterns)); 203 } 204 }