1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @modules jdk.jextract
  27  * @build TestDowncall
  28  *
  29  * @run testng/othervm -Djdk.internal.foreign.NativeInvoker.FASTPATH=none TestDowncall
  30  * @run testng/othervm TestDowncall
  31  */
  32 
  33 import org.testng.annotations.*;
  34 import static org.testng.Assert.*;
  35 
  36 import java.foreign.Libraries;
  37 import java.foreign.NativeTypes;
  38 import java.foreign.Scope;
  39 import java.foreign.layout.Group;
  40 import java.foreign.layout.Layout;
  41 import java.foreign.layout.Padding;
  42 import java.foreign.memory.Pointer;
  43 import java.foreign.memory.Struct;
  44 import java.lang.invoke.MethodHandles;
  45 import java.lang.reflect.Method;
  46 import java.nio.file.Path;
  47 import java.util.ArrayList;
  48 import java.util.List;
  49 import java.util.function.Consumer;
  50 import java.util.stream.Stream;
  51 
  52 public class TestDowncall extends JextractToolRunner {
  53 
  54     final static int MAX_CODE = 20;
  55 
  56     public static class DowncallTest {
  57 
  58         private final Class<?> headerCls;
  59         private final Object lib;
  60 
  61         public DowncallTest(Class<?> headerCls, Object lib) {
  62             this.headerCls = headerCls;
  63             this.lib = lib;
  64         }
  65 
  66         @Test(dataProvider = "getArgs")
  67         public void testDownCall(String mName, @NoInjection Method m)  throws ReflectiveOperationException {
  68             System.err.print("Calling " + mName + "...");
  69             try(Scope scope = Scope.newNativeScope()) {
  70                 List<Consumer<Object>> checks = new ArrayList<>();
  71                 Object res = m.invoke(lib, makeArgs(scope, m, checks));
  72                 if (m.getReturnType() != void.class) {
  73                     checks.forEach(c -> c.accept(res));
  74                 }
  75             }
  76             System.err.println("...done");
  77         }
  78 
  79         @DataProvider
  80         public Object[][] getArgs() {
  81             return Stream.of(headerCls.getDeclaredMethods())
  82                     .map(m -> new Object[]{ m.getName(), m })
  83                     .toArray(Object[][]::new);
  84         }
  85 
  86     }
  87 
  88     @Factory
  89     public Object[] getTests() throws ReflectiveOperationException {
  90         List<DowncallTest> res = new ArrayList<>();
  91         for (int i = 0 ; i < MAX_CODE ; i++) {
  92             Path clzPath = getOutputFilePath("libTestDowncall.jar");
  93             run("-o", clzPath.toString(),
  94                     "--exclude-symbols", filterFor(i),
  95                     getInputFilePath("libTestDowncall.h").toString()).checkSuccess();
  96             Class<?> headerCls = loadClass("libTestDowncall", clzPath);
  97             Object lib = Libraries.bind(headerCls, Libraries.loadLibrary(MethodHandles.lookup(), "TestDowncall"));
  98             res.add(new DowncallTest(headerCls, lib));
  99         }
 100         if(res.isEmpty())
 101             throw new RuntimeException("Could not generate any tests");
 102         return res.toArray();
 103     }
 104 
 105     static Object[] makeArgs(Scope sc, Method m, List<Consumer<Object>> checks) throws ReflectiveOperationException {
 106         Class<?>[] params = m.getParameterTypes();
 107         Object[] args = new Object[params.length];
 108         for (int i = 0 ; i < params.length ; i++) {
 109             args[i] = makeArg(sc, params[i], checks, i == 0);
 110         }
 111         return args;
 112     }
 113 
 114     @SuppressWarnings("unchecked")
 115     static Object makeArg(Scope sc, Class<?> carrier, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
 116         if (Struct.class.isAssignableFrom(carrier)) {
 117             Struct<?> str = sc.allocateStruct((Class)carrier);
 118             initStruct(sc, str, checks, check);
 119             return str;
 120         } else if (carrier == int.class) {
 121             if (check) {
 122                 checks.add(o -> assertEquals(o, 42));
 123             }
 124             return 42;
 125         } else if (carrier == float.class) {
 126             if (check) {
 127                 checks.add(o -> assertEquals(o, 12f));
 128             }
 129             return 12f;
 130         } else if (carrier == double.class) {
 131             if (check) {
 132                 checks.add(o -> assertEquals(o, 24d));
 133             }
 134             return 24d;
 135         } else if (carrier == Pointer.class) {
 136             Pointer<?> p = sc.allocate(NativeTypes.INT32);
 137             if (check) {
 138                 checks.add(o -> {
 139                     try {
 140                         assertEquals(((Pointer<?>)o).addr(), p.addr());
 141                     } catch (Throwable ex) {
 142                         throw new IllegalStateException(ex);
 143                     }
 144                 });
 145             }
 146             return p;
 147         } else {
 148             throw new IllegalStateException("Unexpected carrier: " + carrier);
 149         }
 150     }
 151 
 152     static void initStruct(Scope sc, Struct<?> str, List<Consumer<Object>> checks, boolean check) throws ReflectiveOperationException {
 153         Group g = (Group)str.ptr().type().layout();
 154         for (Layout l : g.elements()) {
 155             if (l instanceof Padding) continue;
 156             Method getter = str.getClass().getDeclaredMethod(l.annotations().get("get"));
 157             Class<?> carrier = getter.getReturnType();
 158             Method setter = str.getClass().getDeclaredMethod(l.annotations().get("set"), carrier);
 159             List<Consumer<Object>> fieldsCheck = new ArrayList<>();
 160             Object value = makeArg(sc, carrier, fieldsCheck, check);
 161             //set value
 162             setter.invoke(str, value);
 163             //add check
 164             if (check) {
 165                 assertTrue(fieldsCheck.size() == 1);
 166                 checks.add(o -> {
 167                     try {
 168                         fieldsCheck.get(0).accept(getter.invoke(o));
 169                     } catch (Throwable ex) {
 170                         throw new IllegalStateException(ex);
 171                     }
 172                 });
 173             }
 174         }
 175     }
 176 
 177     static String filterFor(int k) {
 178         List<String> patterns = new ArrayList<>();
 179         for (int i = 0 ; i < MAX_CODE ; i++) {
 180             if (i != k) {
 181                 patterns.add("f" + i + "_");
 182             }
 183         }
 184         return String.format("(%s).*", String.join("|", patterns));
 185     }
 186 }
--- EOF ---