< prev index next >

src/java.base/share/classes/java/lang/invoke/GenerateJLIClassesHelper.java

Print this page

        

@@ -27,22 +27,369 @@
 
 import jdk.internal.org.objectweb.asm.ClassWriter;
 import jdk.internal.org.objectweb.asm.Opcodes;
 import sun.invoke.util.Wrapper;
 
+import java.util.Arrays;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.HashMap;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.stream.Stream;
 
 import static java.lang.invoke.MethodTypeForm.LF_INVINTERFACE;
 import static java.lang.invoke.MethodTypeForm.LF_INVVIRTUAL;
 
 /**
  * Helper class that gives the GenerateJLIClassesPlugin access to the
  * java.lang.invoke facilities it needs to generate classes ahead of time.
  */
 class GenerateJLIClassesHelper {
+    private static int DMH_INVOKE_VIRTUAL_TYPE = 0;
+    private static int DMH_INVOKE_INTERFACE_TYPE = 4;
+
+    private static final String DIRECT_HOLDER = "java/lang/invoke/DirectMethodHandle$Holder";
+    private static final String DMH_INVOKE_VIRTUAL = "invokeVirtual";
+    private static final String DMH_INVOKE_STATIC = "invokeStatic";
+    private static final String DMH_INVOKE_SPECIAL = "invokeSpecial";
+    private static final String DMH_NEW_INVOKE_SPECIAL = "newInvokeSpecial";
+    private static final String DMH_INVOKE_INTERFACE = "invokeInterface";
+    private static final String DMH_INVOKE_STATIC_INIT = "invokeStaticInit";
+    private static final String DMH_INVOKE_SPECIAL_IFC = "invokeSpecialIFC";
+
+    private static final String DELEGATING_HOLDER = "java/lang/invoke/DelegatingMethodHandle$Holder";
+    private static final String BASIC_FORMS_HOLDER = "java/lang/invoke/LambdaForm$Holder";
+
+    private static final String INVOKERS_HOLDER_NAME = "java.lang.invoke.Invokers$Holder";
+    private static final String INVOKERS_HOLDER_INTERNAL_NAME = INVOKERS_HOLDER_NAME.replace('.', '/');
+
+    // Map from DirectMethodHandle method type to internal ID, matching values
+    // of the corresponding constants in java.lang.invoke.MethodTypeForm
+    private static final Map<String, Integer> DMH_METHOD_TYPE_MAP =
+            Map.of(
+                DMH_INVOKE_VIRTUAL,     DMH_INVOKE_VIRTUAL_TYPE,
+                DMH_INVOKE_STATIC,      1,
+                DMH_INVOKE_SPECIAL,     2,
+                DMH_NEW_INVOKE_SPECIAL, 3,
+                DMH_INVOKE_INTERFACE,   DMH_INVOKE_INTERFACE_TYPE,
+                DMH_INVOKE_STATIC_INIT, 5,
+                DMH_INVOKE_SPECIAL_IFC, 20
+            );
+
+    /**
+     * Output to DumpLoadedClassList, format is simimar to LF_RESOLVE
+     * @see InvokerBytecodeGenerator
+     * @param line the line to output.
+     */
+    static native void cdsTraceResolve(String line);
+
+    private static TreeSet<String> getSpeciesTyes() { return speciesTypes; }
+    private static TreeSet<String> speciesTypes = new TreeSet<>();
+    private static TreeSet<String> invokerTypes = new TreeSet<>();
+    private static TreeSet<String> callSiteTypes = new TreeSet<>();
+    private static Map<String, Set<String>> dmhMethods = new TreeMap<>();
+
+    private static void clear() {
+        speciesTypes.clear();
+        invokerTypes.clear();
+        callSiteTypes.clear();
+        dmhMethods.clear();
+    }
+
+    private static void addSpeciesType(String type) {
+        speciesTypes.add(expandSignature(type));
+    }
+
+    private static void addInvokerType(String methodType) {
+        validateMethodType(methodType);
+        invokerTypes.add(methodType);
+    }
+
+    private static void addCallSiteType(String csType) {
+        validateMethodType(csType);
+        callSiteTypes.add(csType);
+    }
+
+    private static void readTraceConfig(Stream<String> lines) {
+        lines.map(line -> line.split(" "))
+            .forEach(parts -> {
+                switch (parts[0]) {
+                    case "[SPECIES_RESOLVE]":
+                        // Allow for new types of species data classes being resolved here
+                        if (parts.length == 3 && parts[1].startsWith("java.lang.invoke.BoundMethodHandle$Species_")) {
+                            String species = parts[1].substring("java.lang.invoke.BoundMethodHandle$Species_".length());
+                            if (!"L".equals(species)) {
+                                addSpeciesType(species);
+                            }
+                        }
+                        break;
+                    case "[LF_RESOLVE]":
+                        String methodType = parts[3];
+                        if (parts[1].equals(INVOKERS_HOLDER_NAME)) {
+                            if ("linkToTargetMethod".equals(parts[2]) ||
+                                    "linkToCallSite".equals(parts[2])) {
+                                addCallSiteType(methodType);
+                            } else {
+                                addInvokerType(methodType);
+                            }
+                        } else if (parts[1].contains("DirectMethodHandle")) {
+                            String dmh = parts[2];
+                            // ignore getObject etc for now (generated
+                            // by default)
+                            if (DMH_METHOD_TYPE_MAP.containsKey(dmh)) {
+                                addDMHMethodType(dmh, methodType);
+                            }
+                        }
+                        break;
+                    default: break; // ignore
+                }
+            });
+    }
+
+    /**
+     * called from vm to generate MethodHandle holder classes
+     * @return @code { Object[] } if holder classes can be generated.
+     * @param lines the output lines from @code { cdsTraceResolve }
+     */
+    static Object[] generateMethodHandleHolderClasses(String[] lines) {
+        try {
+            Map<String, byte[]> result = generateMHHolderClasses(lines);
+            clear();
+            if (result == null) {
+                return null;
+            }
+            int size = result.size();
+            Object[] ret_array = new Object[size * 2];
+            int index = 0;
+            for (Map.Entry<String, byte[]> entry : result.entrySet()) {
+                ret_array[index++] = entry.getKey();
+                ret_array[index++] = entry.getValue();
+            };
+            return ret_array;
+        } catch (Exception e) {
+            return null;
+        }
+    }
+
+    /* return a map of <class with module pkg, class bytes> */
+    static Map<String, byte[]> generateMHHolderClasses(String[] lines) {
+        if (lines == null || lines.length == 0) {
+            return null;
+        }
+        readTraceConfig(Arrays.stream(lines));
+        int count = 0;
+        for (Set<String> entry : dmhMethods.values()) {
+            count += entry.size();
+        }
+        MethodType[] directMethodTypes = new MethodType[count];
+        int[] dmhTypes = new int[count];
+        int index = 0;
+        for (Map.Entry<String, Set<String>> entry : dmhMethods.entrySet()) {
+            String dmhType = entry.getKey();
+            for (String type : entry.getValue()) {
+                // The DMH type to actually ask for is retrieved by removing
+                // the first argument, which needs to be of Object.class
+                MethodType mt = asMethodType(type);
+                if (mt.parameterCount() < 1 ||
+                    mt.parameterType(0) != Object.class) {
+                    throw new RuntimeException(
+                              "DMH type parameter must start with L: " + dmhType + " " + type);
+                }
+
+                // Adapt the method type of the LF to retrieve
+                directMethodTypes[index] = mt.dropParameterTypes(0, 1);
+
+                // invokeVirtual and invokeInterface must have a leading Object
+                // parameter, i.e., the receiver
+                dmhTypes[index] = DMH_METHOD_TYPE_MAP.get(dmhType);
+                if (dmhTypes[index] == DMH_INVOKE_INTERFACE_TYPE ||
+                    dmhTypes[index] == DMH_INVOKE_VIRTUAL_TYPE) {
+                    if (mt.parameterCount() < 2 ||
+                        mt.parameterType(1) != Object.class) {
+                        throw new RuntimeException(
+                                  "DMH type parameter must start with LL: " + dmhType + " " + type);
+                    }
+                }
+                index++;
+            }
+        }
+
+        // The invoker type to ask for is retrieved by removing the first
+        // and the last argument, which needs to be of Object.class
+        MethodType[] invokerMethodTypes = new MethodType[invokerTypes.size()];
+        index = 0;
+        for (String invokerType : invokerTypes) {
+            MethodType mt = asMethodType(invokerType);
+            final int lastParam = mt.parameterCount() - 1;
+            if (mt.parameterCount() < 2 ||
+                    mt.parameterType(0) != Object.class ||
+                    mt.parameterType(lastParam) != Object.class) {
+                throw new RuntimeException(
+                        "Invoker type parameter must start and end with Object: " + invokerType);
+            }
+            mt = mt.dropParameterTypes(lastParam, lastParam + 1);
+            invokerMethodTypes[index] = mt.dropParameterTypes(0, 1);
+            index++;
+        }
+
+        // The callSite type to ask for is retrieved by removing the last
+        // argument, which needs to be of Object.class
+        MethodType[] callSiteMethodTypes = new MethodType[callSiteTypes.size()];
+        index = 0;
+        for (String callSiteType : callSiteTypes) {
+            MethodType mt = asMethodType(callSiteType);
+            final int lastParam = mt.parameterCount() - 1;
+            if (mt.parameterCount() < 1 ||
+                    mt.parameterType(lastParam) != Object.class) {
+                throw new RuntimeException(
+                        "CallSite type parameter must end with Object: " + callSiteType);
+            }
+            callSiteMethodTypes[index] = mt.dropParameterTypes(lastParam, lastParam + 1);
+            index++;
+        }
+        Map<String, byte[]> result = new HashMap<String, byte[]>();
+
+        byte[] res = generateDirectMethodHandleHolderClassBytes(
+                         DIRECT_HOLDER, directMethodTypes, dmhTypes);
+        result.put(DIRECT_METHOD_HOLDER_ENTRY, res);
+
+        res = generateDelegatingMethodHandleHolderClassBytes(
+                  DELEGATING_HOLDER, directMethodTypes);
+        result.put(DELEGATING_METHOD_HOLDER_ENTRY, res);
+
+        res = generateInvokersHolderClassBytes(INVOKERS_HOLDER_INTERNAL_NAME,
+                  invokerMethodTypes, callSiteMethodTypes);
+        result.put(INVOKERS_HOLDER_ENTRY, res);
+
+        res  = generateBasicFormsClassBytes(BASIC_FORMS_HOLDER);
+        result.put(BASIC_FORMS_HOLDER_ENTRY, res);
+
+        speciesTypes.forEach(types -> {
+            Map.Entry<String, byte[]> entry = generateConcreteBMHClassBytes(types);
+            String className = entry.getKey();
+            String key = "/java.base/" + className + ".class";
+            byte[] value = entry.getValue();
+            result.put(key, value);
+        });
+
+        return result;
+    }
+
+    private static final String DIRECT_METHOD_HOLDER_ENTRY =
+            "/java.base/" + DIRECT_HOLDER + ".class";
+    private static final String DELEGATING_METHOD_HOLDER_ENTRY =
+            "/java.base/" + DELEGATING_HOLDER + ".class";
+    private static final String BASIC_FORMS_HOLDER_ENTRY =
+            "/java.base/" + BASIC_FORMS_HOLDER + ".class";
+    private static final String INVOKERS_HOLDER_ENTRY =
+            "/java.base/" + INVOKERS_HOLDER_INTERNAL_NAME + ".class";
+
+    private static MethodType asMethodType(String basicSignatureString) {
+        String[] parts = basicSignatureString.split("_");
+        assert(parts.length == 2);
+        assert(parts[1].length() == 1);
+        String parameters = expandSignature(parts[0]);
+        Class<?> rtype = simpleType(parts[1].charAt(0));
+        if (parameters.isEmpty()) {
+            return MethodType.methodType(rtype);
+        } else {
+            Class<?>[] ptypes = new Class<?>[parameters.length()];
+            for (int i = 0; i < ptypes.length; i++) {
+                ptypes[i] = simpleType(parameters.charAt(i));
+            }
+            return MethodType.methodType(rtype, ptypes);
+        }
+    }
+
+    private static void addDMHMethodType(String dmh, String methodType) {
+        validateMethodType(methodType);
+        Set<String> methodTypes = dmhMethods.get(dmh);
+        if (methodTypes == null) {
+            methodTypes = new TreeSet<>();
+            dmhMethods.put(dmh, methodTypes);
+        }
+        methodTypes.add(methodType);
+    }
+
+    private static void validateMethodType(String type) {
+        String[] typeParts = type.split("_");
+        // check return type (second part)
+        if (typeParts.length != 2 || typeParts[1].length() != 1
+                || "LJIFDV".indexOf(typeParts[1].charAt(0)) == -1) {
+            throw new RuntimeException(
+                    "Method type signature must be of form [LJIFD]*_[LJIFDV]");
+        }
+        // expand and check arguments (first part)
+        expandSignature(typeParts[0]);
+    }
+
+    // Convert LL -> LL, L3 -> LLL
+    private static String expandSignature(String signature) {
+        StringBuilder sb = new StringBuilder();
+        char last = 'X';
+        int count = 0;
+        for (int i = 0; i < signature.length(); i++) {
+            char c = signature.charAt(i);
+            if (c >= '0' && c <= '9') {
+                count *= 10;
+                count += (c - '0');
+            } else {
+                requireBasicType(c);
+                for (int j = 1; j < count; j++) {
+                    sb.append(last);
+                }
+                sb.append(c);
+                last = c;
+                count = 0;
+            }
+        }
+
+        // ended with a number, e.g., "L2": append last char count - 1 times
+        if (count > 1) {
+            requireBasicType(last);
+            for (int j = 1; j < count; j++) {
+                sb.append(last);
+            }
+        }
+        return sb.toString();
+    }
+
+    private static void requireBasicType(char c) {
+        if ("LIJFD".indexOf(c) < 0) {
+            throw new RuntimeException(
+                    "Character " + c + " must correspond to a basic field type: LIJFD");
+        }
+    }
+
+    private static Class<?> simpleType(char c) {
+        switch (c) {
+            case 'F':
+                return float.class;
+            case 'D':
+                return double.class;
+            case 'I':
+                return int.class;
+            case 'L':
+                return Object.class;
+            case 'J':
+                return long.class;
+            case 'V':
+                return void.class;
+            case 'Z':
+            case 'B':
+            case 'S':
+            case 'C':
+                throw new IllegalArgumentException("Not a valid primitive: " + c +
+                        " (use I instead)");
+            default:
+                throw new IllegalArgumentException("Not a primitive: " + c);
+        }
+    }
+
 
     static byte[] generateBasicFormsClassBytes(String className) {
         ArrayList<LambdaForm> forms = new ArrayList<>();
         ArrayList<String> names = new ArrayList<>();
         HashSet<String> dedupSet = new HashSet<>();
< prev index next >