1 /*
   2  * Copyright (c) 2014, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  *
  23  */
  24 
  25 import java.lang.instrument.ClassDefinition;
  26 import java.lang.instrument.Instrumentation;
  27 import java.lang.instrument.UnmodifiableClassException;
  28 import java.net.URL;
  29 import java.net.URLClassLoader;
  30 import java.io.File;
  31 import java.io.FileWriter;
  32 import java.security.CodeSigner;
  33 import java.security.CodeSource;
  34 import java.security.ProtectionDomain;
  35 import sun.hotspot.WhiteBox;
  36 
  37 public class InstrumentationApp {
  38     static WhiteBox wb = WhiteBox.getWhiteBox();
  39 
  40     public static final String COO_CLASS_NAME = "InstrumentationApp$Coo";
  41 
  42     public static interface Intf {            // Loaded from Boot class loader (-Xbootclasspath/a).
  43         public String get();
  44     }
  45     public static class Bar implements Intf { // Loaded from Boot class loader.
  46         public String get() {
  47             // The initial transform:
  48             //      change "buzz" -> "fuzz"
  49             // The re-transform:
  50             //      change "buzz" -> "guzz"
  51             return "buzz";
  52         }
  53     }
  54     public static class Foo implements Intf { // Loaded from AppClassLoader, or from a custom loader
  55         public String get() {
  56             // The initial transform:
  57             //      change "buzz" -> "fuzz"
  58             // The re-transform:
  59             //      change "buzz" -> "guzz"
  60             return "buzz";
  61         }
  62     }
  63     public static class Coo implements Intf { // Loaded from custom class loader.
  64         public String get() {
  65             // The initial transform:
  66             //      change "buzz" -> "fuzz"
  67             // The re-transform:
  68             //      change "buzz" -> "guzz"
  69             return "buzz";
  70         }
  71     }
  72 
  73     // This class file should be archived if AppCDSv2 is enabled on this platform. See
  74     // the comments around the call to TestCommon.dump in InstrumentationTest.java.
  75     public static class ArchivedIfAppCDSv2Enabled {}
  76 
  77     public static boolean isAppCDSV2Enabled() {
  78         return wb.isSharedClass(ArchivedIfAppCDSv2Enabled.class);
  79     }
  80 
  81     public static class MyLoader extends URLClassLoader {
  82         public MyLoader(URL[] urls, ClassLoader parent, File jar) {
  83             super(urls, parent);
  84             this.jar = jar;
  85         }
  86         File jar;
  87 
  88         @Override
  89         protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
  90             synchronized (getClassLoadingLock(name)) {
  91                 // First, check if the class has already been loaded
  92                 Class<?> clz = findLoadedClass(name);
  93                 if (clz != null) {
  94                     return clz;
  95                 }
  96 
  97                 if (name.equals(COO_CLASS_NAME)) {
  98                     try {
  99                         byte[] buff = Util.getClassFileFromJar(jar, name);
 100                         return defineClass(name, buff, 0, buff.length);
 101                     } catch (Throwable t) {
 102                         t.printStackTrace();
 103                         throw new RuntimeException("Unexpected", t);
 104                     }
 105                 }
 106             }
 107             return super.loadClass(name, resolve);
 108         }
 109     }
 110 
 111     static int numTests = 0;
 112     static int failed = 0;
 113     static boolean isAttachingAgent = false;
 114     static Instrumentation instrumentation;
 115 
 116     public static void main(String args[]) throws Throwable {
 117         System.out.println("INFO: AppCDSv1 " + (wb.isSharedClass(InstrumentationApp.class) ? "enabled" :"disabled"));
 118         System.out.println("INFO: AppCDSv2 " + (isAppCDSV2Enabled()                        ? "enabled" : "disabled"));
 119 
 120         String flagFile = args[0];
 121         File bootJar = new File(args[1]);
 122         File appJar  = new File(args[2]);
 123         File custJar = new File(args[3]);
 124         waitAttach(flagFile);
 125 
 126         instrumentation = InstrumentationRegisterClassFileTransformer.getInstrumentation();
 127         System.out.println("INFO: instrumentation = " + instrumentation);
 128 
 129         testBootstrapCDS("Bootstrap Loader", bootJar);
 130         testAppCDSv1("Application Loader", appJar);
 131 
 132         if (isAppCDSV2Enabled()) {
 133           testAppCDSv2("Custom Loader (unregistered)", custJar);
 134         }
 135 
 136         if (failed > 0) {
 137             throw new RuntimeException("FINAL RESULT: " + failed + " out of " + numTests + " test case(s) have failed");
 138         } else {
 139             System.out.println("FINAL RESULT: All " + numTests + " test case(s) have passed!");
 140         }
 141     }
 142 
 143     static void waitAttach(String flagFile) throws Throwable {
 144         // See InstrumentationTest.java for the hand-shake protocol.
 145         if (!flagFile.equals("noattach")) {
 146             File f = new File(flagFile);
 147             try (FileWriter fw = new FileWriter(f)) {
 148                 long pid = ProcessHandle.current().pid();
 149                 System.out.println("my pid = " + pid);
 150                 fw.write(Long.toString(pid));
 151                 fw.write("\n");
 152                 for (int i=0; i<10; i++) {
 153                   // Parent process waits until we have written more than 100 bytes, so it won't
 154                   // read a partial pid
 155                   fw.write("==========");
 156                 }
 157                 fw.close();
 158             }
 159 
 160             long start = System.currentTimeMillis();
 161             while (f.exists()) {
 162                 long elapsed = System.currentTimeMillis() - start;
 163                 System.out.println(".... (" + elapsed + ") waiting for deletion of " + f);
 164                 Thread.sleep(1000);
 165             }
 166             System.out.println("Attach succeeded (child)");
 167             isAttachingAgent = true;
 168         }
 169     }
 170 
 171     static void testBootstrapCDS(String group, File jar) throws Throwable {
 172         doTest(group, new Bar(), jar);
 173     }
 174 
 175     static void testAppCDSv1(String group, File jar) throws Throwable {
 176         doTest(group, new Foo(), jar);
 177     }
 178 
 179     static void testAppCDSv2(String group, File jar) throws Throwable {
 180         URL[] urls = new URL[] {jar.toURI().toURL()};
 181         MyLoader loader = new MyLoader(urls, InstrumentationApp.class.getClassLoader(), jar);
 182         Class klass = loader.loadClass(COO_CLASS_NAME);
 183         doTest(group, (Intf)klass.newInstance(), jar);
 184     }
 185 
 186     static void doTest(String group, Intf object, File jar) throws Throwable {
 187         Class klass = object.getClass();
 188         System.out.println();
 189         System.out.println("++++++++++++++++++++++++++");
 190         System.out.println("Test group: " + group);
 191         System.out.println("Testing with classloader = " + klass.getClassLoader());
 192         System.out.println("Testing with class       = " + klass);
 193         System.out.println("++++++++++++++++++++++++++");
 194 
 195         // Initial transform
 196         String f = object.get();
 197         assertTrue(f.equals("fuzz"), "object.get(): Initial transform should give 'fuzz'", f);
 198 
 199         // Retransform
 200         f = "(failed)";
 201         try {
 202             instrumentation.retransformClasses(klass);
 203             f = object.get();
 204         } catch (UnmodifiableClassException|UnsupportedOperationException e) {
 205             e.printStackTrace();
 206         }
 207         assertTrue(f.equals("guzz"), "object.get(): retransformation should give 'guzz'", f);
 208 
 209         // Redefine
 210         byte[] buff = Util.getClassFileFromJar(jar, klass.getName());
 211         Util.replace(buff, "buzz", "huzz");
 212         f = "(failed)";
 213         try {
 214             instrumentation.redefineClasses(new ClassDefinition(klass, buff));
 215             f = object.get();
 216         } catch (UnmodifiableClassException|UnsupportedOperationException e) {
 217             e.printStackTrace();
 218         }
 219         assertTrue(f.equals("quzz"), "object.get(): redefinition should give 'quzz'", f);
 220 
 221         System.out.println("++++++++++++++++++++++++++++++++++++++++++++++++ (done)\n\n");
 222     }
 223 
 224     private static void assertTrue(boolean expr, String msg, String string) {
 225         numTests ++;
 226         System.out.printf("Test case %2d ", numTests);
 227 
 228         if (expr) {
 229             System.out.println("PASSED: " + msg + " and we got '" + string + "'");
 230         } else {
 231             failed ++;
 232             System.out.println("FAILED: " + msg + " but we got '" + string + "'");
 233         }
 234     }
 235 }