1 /*
   2  * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8004970
  27  * @summary Lambda serialization in the presence of class loaders
  28  * @library /lib/testlibrary
  29  * @build jdk.testlibrary.IOUtils
  30  * @run main LambdaClassLoaderSerialization
  31  * @author Peter Levart
  32  */
  33 
  34 import java.io.ByteArrayInputStream;
  35 import java.io.ByteArrayOutputStream;
  36 import java.io.IOException;
  37 import java.io.InputStream;
  38 import java.io.ObjectInputStream;
  39 import java.io.ObjectOutputStream;
  40 import java.io.Serializable;
  41 
  42 public class LambdaClassLoaderSerialization {
  43 
  44     public interface SerializableRunnable extends Runnable, Serializable {}
  45 
  46     public static class MyCode implements SerializableRunnable {
  47 
  48         private byte[] serialize(Object o) {
  49             ByteArrayOutputStream baos;
  50             try (
  51                 ObjectOutputStream oos =
  52                     new ObjectOutputStream(baos = new ByteArrayOutputStream())
  53             ) {
  54                 oos.writeObject(o);
  55             }
  56             catch (IOException e) {
  57                 throw new RuntimeException(e);
  58             }
  59             return baos.toByteArray();
  60         }
  61 
  62         private <T> T deserialize(byte[] bytes) {
  63             try (
  64                 ObjectInputStream ois =
  65                     new ObjectInputStream(new ByteArrayInputStream(bytes))
  66             ) {
  67                 return (T) ois.readObject();
  68             }
  69             catch (IOException | ClassNotFoundException e) {
  70                 throw new RuntimeException(e);
  71             }
  72         }
  73 
  74         @Override
  75         public void run() {
  76             System.out.println("                this: " + this);
  77 
  78             SerializableRunnable deSerializedThis = deserialize(serialize(this));
  79             System.out.println("    deSerializedThis: " + deSerializedThis);
  80 
  81             SerializableRunnable runnable = () -> {System.out.println("HELLO");};
  82             System.out.println("            runnable: " + runnable);
  83 
  84             SerializableRunnable deSerializedRunnable = deserialize(serialize(runnable));
  85             System.out.println("deSerializedRunnable: " + deSerializedRunnable);
  86         }
  87     }
  88 
  89     public static void main(String[] args) throws Exception {
  90         ClassLoader myCl = new MyClassLoader(
  91             LambdaClassLoaderSerialization.class.getClassLoader()
  92         );
  93         Class<?> myCodeClass = Class.forName(
  94             LambdaClassLoaderSerialization.class.getName() + "$MyCode",
  95             true,
  96             myCl
  97         );
  98         Runnable myCode = (Runnable) myCodeClass.newInstance();
  99         myCode.run();
 100     }
 101 
 102     static class MyClassLoader extends ClassLoader {
 103         MyClassLoader(ClassLoader parent) {
 104             super(parent);
 105         }
 106 
 107         @Override
 108         protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
 109             if (name.indexOf('.') < 0) {
 110                 synchronized (getClassLoadingLock(name)) {
 111                     Class<?> c = findLoadedClass(name);
 112                     if (c == null) {
 113                         c = findClass(name);
 114                     }
 115                     if (resolve) {
 116                         resolveClass(c);
 117                     }
 118                     return c;
 119                 }
 120             } else {
 121                 return super.loadClass(name, resolve);
 122             }
 123         }
 124 
 125         @Override
 126         protected Class<?> findClass(String name) throws ClassNotFoundException {
 127             String path = name.replace('.', '/').concat(".class");
 128             try (InputStream is = getResourceAsStream(path)) {
 129                 if (is != null) {
 130                     byte[] bytes = is.readAllBytes();
 131                     return defineClass(name, bytes, 0, bytes.length);
 132                 } else {
 133                     throw new ClassNotFoundException(name);
 134                 }
 135             }
 136             catch (IOException e) {
 137                 throw new ClassNotFoundException(name, e);
 138             }
 139         }
 140     }
 141 }