1 /*
   2  * Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8071474
  27  * @summary Better failure atomicity for default read object.
  28  * @modules jdk.compiler
  29  * @library /lib/testlibrary
  30  * @build jdk.testlibrary.FileUtils
  31  * @compile FailureAtomicity.java SerialRef.java
  32  * @run main failureAtomicity.FailureAtomicity
  33  */
  34 
  35 package failureAtomicity;
  36 
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.UncheckedIOException;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.StandardLocation;
import javax.tools.ToolProvider;
import jdk.testlibrary.FileUtils;
  63 
  64 @SuppressWarnings("unchecked")
  65 public class FailureAtomicity {
  66     static final Path TEST_SRC = Paths.get(System.getProperty("test.src", "."));
  67     static final Path TEST_CLASSES = Paths.get(System.getProperty("test.classes", "."));
  68     static final Path fooTemplate = TEST_SRC.resolve("Foo.template");
  69     static final Path barTemplate = TEST_SRC.resolve("Bar.template");
  70 
  71     static final String[] PKGS = { "a.b.c", "x.y.z" };
  72 
  73     public static void main(String[] args) throws Exception {
  74         test_Foo();
  75         test_BadFoo();  // 'Bad' => incompatible type; cannot be "fully" deserialized
  76         test_FooWithReadObject();
  77         test_BadFooWithReadObject();
  78 
  79         test_Foo_Bar();
  80         test_Foo_BadBar();
  81         test_BadFoo_Bar();
  82         test_BadFoo_BadBar();
  83         test_Foo_BarWithReadObject();
  84         test_Foo_BadBarWithReadObject();
  85         test_BadFoo_BarWithReadObject();
  86         test_BadFoo_BadBarWithReadObject();
  87         test_FooWithReadObject_Bar();
  88         test_FooWithReadObject_BadBar();
  89         test_BadFooWithReadObject_Bar();
  90         test_BadFooWithReadObject_BadBar();
  91     }
  92 
  93     static final BiConsumer<Object,Object> FOO_FIELDS_EQUAL = (a,b) -> {
  94         try {
  95             int aPrim = a.getClass().getField("fooPrim").getInt(a);
  96             int bPrim = b.getClass().getField("fooPrim").getInt(b);
  97             if (aPrim != bPrim)
  98                 throw new AssertionError("Not equal: (" + aPrim + "!=" + bPrim
  99                                          + "), in [" + a + "] [" + b + "]");
 100             Object aRef = a.getClass().getField("fooRef").get(a);
 101             Object bRef = b.getClass().getField("fooRef").get(b);
 102             if (!aRef.equals(bRef))
 103                 throw new RuntimeException("Not equal: (" + aRef + "!=" + bRef
 104                                            + "), in [" + a + "] [" + b + "]");
 105         } catch (NoSuchFieldException | IllegalAccessException x) {
 106             throw new InternalError(x);
 107         }
 108     };
 109     static final BiConsumer<Object,Object> FOO_FIELDS_DEFAULT = (ignore,b) -> {
 110         try {
 111             int aPrim = b.getClass().getField("fooPrim").getInt(b);
 112             if (aPrim != 0)
 113                 throw new AssertionError("Expected 0, got:" + aPrim
 114                                          + ", in [" + b + "]");
 115             Object aRef = b.getClass().getField("fooRef").get(b);
 116             if (aRef != null)
 117                 throw new RuntimeException("Expected null, got:" + aRef
 118                                            + ", in [" + b + "]");
 119         } catch (NoSuchFieldException | IllegalAccessException x) {
 120             throw new InternalError(x);
 121         }
 122     };
 123     static final BiConsumer<Object,Object> BAR_FIELDS_EQUAL = (a,b) -> {
 124         try {
 125             long aPrim = a.getClass().getField("barPrim").getLong(a);
 126             long bPrim = b.getClass().getField("barPrim").getLong(b);
 127             if (aPrim != bPrim)
 128                 throw new AssertionError("Not equal: (" + aPrim + "!=" + bPrim
 129                                          + "), in [" + a + "] [" + b + "]");
 130             Object aRef = a.getClass().getField("barRef").get(a);
 131             Object bRef = b.getClass().getField("barRef").get(b);
 132             if (!aRef.equals(bRef))
 133                 throw new RuntimeException("Not equal: (" + aRef + "!=" + bRef
 134                                            + "), in [" + a + "] [" + b + "]");
 135         } catch (NoSuchFieldException | IllegalAccessException x) {
 136             throw new InternalError(x);
 137         }
 138     };
 139     static final BiConsumer<Object,Object> BAR_FIELDS_DEFAULT = (ignore,b) -> {
 140         try {
 141             long aPrim = b.getClass().getField("barPrim").getLong(b);
 142             if (aPrim != 0L)
 143                 throw new AssertionError("Expected 0, got:" + aPrim
 144                                          + ", in [" + b + "]");
 145             Object aRef = b.getClass().getField("barRef").get(b);
 146             if (aRef != null)
 147                 throw new RuntimeException("Expected null, got:" + aRef
 148                                            + ", in [" + b + "]");
 149         } catch (NoSuchFieldException | IllegalAccessException x) {
 150             throw new InternalError(x);
 151         }
 152     };
 153 
 154     static void test_Foo() {
 155         testFoo("Foo", "String", false, false, FOO_FIELDS_EQUAL); }
 156     static void test_BadFoo() {
 157         testFoo("BadFoo", "byte[]", true, false, FOO_FIELDS_DEFAULT); }
 158     static void test_FooWithReadObject() {
 159         testFoo("FooWithReadObject", "String", false, true, FOO_FIELDS_EQUAL); }
 160     static void test_BadFooWithReadObject() {
 161         testFoo("BadFooWithReadObject", "byte[]", true, true, FOO_FIELDS_DEFAULT); }
 162 
 163     static void testFoo(String testName, String xyzZebraType,
 164                         boolean expectCCE, boolean withReadObject,
 165                         BiConsumer<Object,Object>... resultCheckers) {
 166         System.out.println("\nTesting " + testName);
 167         try {
 168             Path testRoot = testDir(testName);
 169             Path srcRoot = Files.createDirectory(testRoot.resolve("src"));
 170             List<Path> srcFiles = new ArrayList<>();
 171             srcFiles.add(createSrc(PKGS[0], fooTemplate, srcRoot, "String", withReadObject));
 172             srcFiles.add(createSrc(PKGS[1], fooTemplate, srcRoot, xyzZebraType, withReadObject));
 173 
 174             Path build = Files.createDirectory(testRoot.resolve("build"));
 175             javac(build, srcFiles);
 176 
 177             URLClassLoader loader = new URLClassLoader(new URL[]{ build.toUri().toURL() },
 178                                                        FailureAtomicity.class.getClassLoader());
 179             Class<?> fooClass = Class.forName(PKGS[0] + ".Foo", true, loader);
 180             Constructor<?> ctr = fooClass.getConstructor(
 181                     new Class<?>[]{int.class, String.class, String.class});
 182             Object abcFoo = ctr.newInstance(5, "chegar", "zebra");
 183 
 184             try {
 185                 toOtherPkgInstance(abcFoo, loader);
 186                 if (expectCCE)
 187                     throw new AssertionError("Expected CCE not thrown");
 188             } catch (ClassCastException e) {
 189                 if (!expectCCE)
 190                     throw new AssertionError("UnExpected CCE: " + e);
 191             }
 192 
 193             Object deserialInstance = failureAtomicity.SerialRef.obj;
 194 
 195             System.out.println("abcFoo:           " + abcFoo);
 196             System.out.println("deserialInstance: " + deserialInstance);
 197 
 198             for (BiConsumer<Object, Object> rc : resultCheckers)
 199                 rc.accept(abcFoo, deserialInstance);
 200         } catch (IOException x) {
 201             throw new UncheckedIOException(x);
 202         } catch (ReflectiveOperationException x) {
 203             throw new InternalError(x);
 204         }
 205     }
 206 
 207     static void test_Foo_Bar() {
 208         testFooBar("Foo_Bar", "String", "String", false, false, false,
 209                    FOO_FIELDS_EQUAL, BAR_FIELDS_EQUAL);
 210     }
 211     static void test_Foo_BadBar() {
 212         testFooBar("Foo_BadBar", "String", "byte[]", true, false, false,
 213                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 214     }
 215     static void test_BadFoo_Bar() {
 216         testFooBar("BadFoo_Bar", "byte[]", "String", true, false, false,
 217                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 218     }
 219     static void test_BadFoo_BadBar() {
 220         testFooBar("BadFoo_BadBar", "byte[]", "byte[]", true, false, false,
 221                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 222     }
 223     static void test_Foo_BarWithReadObject() {
 224         testFooBar("Foo_BarWithReadObject", "String", "String", false, false, true,
 225                    FOO_FIELDS_EQUAL, BAR_FIELDS_EQUAL);
 226     }
 227     static void test_Foo_BadBarWithReadObject() {
 228         testFooBar("Foo_BadBarWithReadObject", "String", "byte[]", true, false, true,
 229                    FOO_FIELDS_EQUAL, BAR_FIELDS_DEFAULT);
 230     }
 231     static void test_BadFoo_BarWithReadObject() {
 232         testFooBar("BadFoo_BarWithReadObject", "byte[]", "String", true, false, true,
 233                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 234     }
 235     static void test_BadFoo_BadBarWithReadObject() {
 236         testFooBar("BadFoo_BadBarWithReadObject", "byte[]", "byte[]", true, false, true,
 237                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 238     }
 239 
 240     static void test_FooWithReadObject_Bar() {
 241         testFooBar("FooWithReadObject_Bar", "String", "String", false, true, false,
 242                    FOO_FIELDS_EQUAL, BAR_FIELDS_EQUAL);
 243     }
 244     static void test_FooWithReadObject_BadBar() {
 245         testFooBar("FooWithReadObject_BadBar", "String", "byte[]", true, true, false,
 246                    FOO_FIELDS_EQUAL, BAR_FIELDS_DEFAULT);
 247     }
 248     static void test_BadFooWithReadObject_Bar() {
 249         testFooBar("BadFooWithReadObject_Bar", "byte[]", "String", true, true, false,
 250                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 251     }
 252     static void test_BadFooWithReadObject_BadBar() {
 253         testFooBar("BadFooWithReadObject_BadBar", "byte[]", "byte[]", true, true, false,
 254                    FOO_FIELDS_DEFAULT, BAR_FIELDS_DEFAULT);
 255     }
 256 
 257     static void testFooBar(String testName, String xyzFooZebraType,
 258                            String xyzBarZebraType, boolean expectCCE,
 259                            boolean fooWithReadObject, boolean barWithReadObject,
 260                            BiConsumer<Object,Object>... resultCheckers) {
 261         System.out.println("\nTesting " + testName);
 262         try {
 263             Path testRoot = testDir(testName);
 264             Path srcRoot = Files.createDirectory(testRoot.resolve("src"));
 265             List<Path> srcFiles = new ArrayList<>();
 266             srcFiles.add(createSrc(PKGS[0], fooTemplate, srcRoot, "String",
 267                                    fooWithReadObject, "String"));
 268             srcFiles.add(createSrc(PKGS[1], fooTemplate, srcRoot, xyzFooZebraType,
 269                                    fooWithReadObject, xyzFooZebraType));
 270             srcFiles.add(createSrc(PKGS[0], barTemplate, srcRoot, "String",
 271                                    barWithReadObject, "String"));
 272             srcFiles.add(createSrc(PKGS[1], barTemplate, srcRoot, xyzBarZebraType,
 273                                    barWithReadObject, xyzFooZebraType));
 274 
 275             Path build = Files.createDirectory(testRoot.resolve("build"));
 276             javac(build, srcFiles);
 277 
 278             URLClassLoader loader = new URLClassLoader(new URL[]{ build.toUri().toURL() },
 279                                                        FailureAtomicity.class.getClassLoader());
 280             Class<?> fooClass = Class.forName(PKGS[0] + ".Bar", true, loader);
 281             Constructor<?> ctr = fooClass.getConstructor(
 282                     new Class<?>[]{int.class, String.class, String.class,
 283                                    long.class, String.class, String.class});
 284             Object abcBar = ctr.newInstance( 5, "chegar", "zebraFoo", 111L, "aBar", "zebraBar");
 285 
 286             try {
 287                 toOtherPkgInstance(abcBar, loader);
 288                 if (expectCCE)
 289                     throw new AssertionError("Expected CCE not thrown");
 290             } catch (ClassCastException e) {
 291                 if (!expectCCE)
 292                     throw new AssertionError("UnExpected CCE: " + e);
 293             }
 294 
 295             Object deserialInstance = failureAtomicity.SerialRef.obj;
 296 
 297             System.out.println("abcBar:           " + abcBar);
 298             System.out.println("deserialInstance: " + deserialInstance);
 299 
 300             for (BiConsumer<Object, Object> rc : resultCheckers)
 301                 rc.accept(abcBar, deserialInstance);
 302         } catch (IOException x) {
 303             throw new UncheckedIOException(x);
 304         } catch (ReflectiveOperationException x) {
 305             throw new InternalError(x);
 306         }
 307     }
 308 
 309     static Path testDir(String name) throws IOException {
 310         Path testRoot = Paths.get("FailureAtomicity-" + name);
 311         if (Files.exists(testRoot))
 312             FileUtils.deleteFileTreeWithRetry(testRoot);
 313         Files.createDirectory(testRoot);
 314         return testRoot;
 315     }
 316 
 317     static String platformPath(String p) { return p.replace("/", File.separator); }
 318     static String binaryName(String name) { return name.replace(".", "/"); }
 319     static String condRemove(String line, String pattern, boolean hasReadObject) {
 320         if (hasReadObject) { return line.replaceAll(pattern, ""); }
 321         else { return line; }
 322     }
 323     static String condReplace(String line, String... zebraFooType) {
 324         if (zebraFooType.length == 1) {
 325             return line.replaceAll("\\$foo_zebra_type", zebraFooType[0]);
 326         } else { return line; }
 327     }
 328     static String nameFromTemplate(Path template) {
 329         return template.getFileName().toString().replaceAll(".template", "");
 330     }
 331 
 332     static Path createSrc(String pkg, Path srcTemplate, Path srcRoot,
 333                           String zebraType, boolean hasReadObject,
 334                           String... zebraFooType)
 335         throws IOException
 336     {
 337         Path srcDst = srcRoot.resolve(platformPath(binaryName(pkg)));
 338         Files.createDirectories(srcDst);
 339         Path srcFile = srcDst.resolve(nameFromTemplate(srcTemplate) + ".java");
 340 
 341         List<String> lines = Files.lines(srcTemplate)
 342                 .map(s -> s.replaceAll("\\$package", pkg))
 343                 .map(s -> s.replaceAll("\\$zebra_type", zebraType))
 344                 .map(s -> condReplace(s, zebraFooType))
 345                 .map(s -> condRemove(s, "//\\$has_readObject", hasReadObject))
 346                 .collect(Collectors.toList());
 347         Files.write(srcFile, lines);
 348         return srcFile;
 349     }
 350 
 351     static void javac(Path dest, List<Path> sourceFiles) throws IOException {
 352         JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
 353         try (StandardJavaFileManager fileManager =
 354                      compiler.getStandardFileManager(null, null, null)) {
 355             List<File> files = sourceFiles.stream()
 356                                           .map(p -> p.toFile())
 357                                           .collect(Collectors.toList());
 358             Iterable<? extends JavaFileObject> compilationUnits =
 359                     fileManager.getJavaFileObjectsFromFiles(files);
 360             fileManager.setLocation(StandardLocation.CLASS_OUTPUT,
 361                                     Arrays.asList(dest.toFile()));
 362             fileManager.setLocation(StandardLocation.CLASS_PATH,
 363                                     Arrays.asList(TEST_CLASSES.toFile()));
 364             JavaCompiler.CompilationTask task = compiler
 365                     .getTask(null, fileManager, null, null, null, compilationUnits);
 366             boolean passed = task.call();
 367             if (!passed)
 368                 throw new RuntimeException("Error compiling " + files);
 369         }
 370     }
 371 
 372     static Object toOtherPkgInstance(Object obj, ClassLoader loader)
 373         throws IOException, ClassNotFoundException
 374     {
 375         byte[] bytes = serialize(obj);
 376         bytes = replacePkg(bytes);
 377         return deserialize(bytes, loader);
 378     }
 379 
 380     @SuppressWarnings("deprecation")
 381     static byte[] replacePkg(byte[] bytes) {
 382         String str = new String(bytes, 0);
 383         str = str.replaceAll(PKGS[0], PKGS[1]);
 384         str.getBytes(0, bytes.length, bytes, 0);
 385         return bytes;
 386     }
 387 
 388     static byte[] serialize(Object obj) throws IOException {
 389         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
 390              ObjectOutputStream out = new ObjectOutputStream(baos);) {
 391             out.writeObject(obj);
 392             return baos.toByteArray();
 393         }
 394     }
 395 
 396     static Object deserialize(byte[] data, ClassLoader l)
 397         throws IOException, ClassNotFoundException
 398     {
 399         return new WithLoaderObjectInputStream(new ByteArrayInputStream(data), l)
 400                 .readObject();
 401     }
 402 
 403     static class WithLoaderObjectInputStream extends ObjectInputStream {
 404         final ClassLoader loader;
 405         WithLoaderObjectInputStream(InputStream is, ClassLoader loader)
 406             throws IOException
 407         {
 408             super(is);
 409             this.loader = loader;
 410         }
 411         @Override
 412         protected Class<?> resolveClass(ObjectStreamClass desc)
 413             throws IOException, ClassNotFoundException {
 414             try {
 415                 return super.resolveClass(desc);
 416             } catch (ClassNotFoundException x) {
 417                 String name = desc.getName();
 418                 return Class.forName(name, false, loader);
 419             }
 420         }
 421     }
 422 }