/* * Copyright (c) 2014, 2014, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ package org.graalvm.compiler.jtt.except; import org.junit.BeforeClass; import org.junit.Test; import org.graalvm.compiler.jtt.JTTTest; import org.graalvm.compiler.test.ExportingClassLoader; import jdk.internal.org.objectweb.asm.ClassWriter; import jdk.internal.org.objectweb.asm.MethodVisitor; import jdk.internal.org.objectweb.asm.Opcodes; import jdk.internal.org.objectweb.asm.Type; public class UntrustedInterfaces extends JTTTest { public interface CallBack { int callBack(TestInterface ti); } private interface TestInterface { int method(); } /** * What a GoodPill would look like. * *
     * private static final class GoodPill extends Pill {
     *     public void setField() {
     *         field = new TestConstant();
     *     }
     *
     *     public void setStaticField() {
     *         staticField = new TestConstant();
     *     }
     *
     *     public int callMe(CallBack callback) {
     *         return callback.callBack(new TestConstant());
     *     }
     *
     *     public TestInterface get() {
     *         return new TestConstant();
     *     }
     * }
     *
     * private static final class TestConstant implements TestInterface {
     *     public int method() {
     *         return 42;
     *     }
     * }
     * 
*/ public abstract static class Pill { public static TestInterface staticField; public TestInterface field; public abstract void setField(); public abstract void setStaticField(); public abstract int callMe(CallBack callback); public abstract TestInterface get(); } public int callBack(TestInterface list) { return list.method(); } public int staticFieldInvoke(Pill pill) { pill.setStaticField(); return Pill.staticField.method(); } public int fieldInvoke(Pill pill) { pill.setField(); return pill.field.method(); } public int argumentInvoke(Pill pill) { return pill.callMe(ti -> ti.method()); } public int returnInvoke(Pill pill) { return pill.get().method(); } @SuppressWarnings("cast") public boolean staticFieldInstanceof(Pill pill) { pill.setStaticField(); return Pill.staticField instanceof TestInterface; } @SuppressWarnings("cast") public boolean fieldInstanceof(Pill pill) { pill.setField(); return pill.field instanceof TestInterface; } @SuppressWarnings("cast") public int argumentInstanceof(Pill pill) { return pill.callMe(ti -> ti instanceof TestInterface ? 42 : 24); } @SuppressWarnings("cast") public boolean returnInstanceof(Pill pill) { return pill.get() instanceof TestInterface; } public TestInterface staticFieldCheckcast(Pill pill) { pill.setStaticField(); return TestInterface.class.cast(Pill.staticField); } public TestInterface fieldCheckcast(Pill pill) { pill.setField(); return TestInterface.class.cast(pill.field); } public int argumentCheckcast(Pill pill) { return pill.callMe(ti -> TestInterface.class.cast(ti).method()); } public TestInterface returnCheckcast(Pill pill) { return TestInterface.class.cast(pill.get()); } private static Pill poisonPill; // Checkstyle: stop @BeforeClass public static void setUp() throws InstantiationException, IllegalAccessException, ClassNotFoundException { poisonPill = (Pill) new PoisonLoader().findClass(PoisonLoader.POISON_IMPL_NAME).newInstance(); } // Checkstyle: resume @Test public void testStaticField0() { runTest("staticFieldInvoke", poisonPill); } @Test public void testStaticField1() { runTest("staticFieldInstanceof", poisonPill); } @Test public void testStaticField2() { runTest("staticFieldCheckcast", poisonPill); } @Test public void testField0() { runTest("fieldInvoke", poisonPill); } @Test public void testField1() { runTest("fieldInstanceof", poisonPill); } @Test public void testField2() { runTest("fieldCheckcast", poisonPill); } @Test public void testArgument0() { runTest("argumentInvoke", poisonPill); } @Test public void testArgument1() { runTest("argumentInstanceof", poisonPill); } @Test public void testArgument2() { runTest("argumentCheckcast", poisonPill); } @Test public void testReturn0() { runTest("returnInvoke", poisonPill); } @Test public void testReturn1() { runTest("returnInstanceof", poisonPill); } @Test public void testReturn2() { runTest("returnCheckcast", poisonPill); } private static class PoisonLoader extends ExportingClassLoader { public static final String POISON_IMPL_NAME = "org.graalvm.compiler.jtt.except.PoisonPill"; @Override protected Class findClass(String name) throws ClassNotFoundException { if (name.equals(POISON_IMPL_NAME)) { ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES); cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, POISON_IMPL_NAME.replace('.', '/'), null, Type.getInternalName(Pill.class), null); // constructor MethodVisitor constructor = cw.visitMethod(Opcodes.ACC_PUBLIC, "", "()V", null, null); constructor.visitCode(); constructor.visitVarInsn(Opcodes.ALOAD, 0); constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Pill.class), "", "()V", false); constructor.visitInsn(Opcodes.RETURN); constructor.visitMaxs(0, 0); constructor.visitEnd(); MethodVisitor setList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setField", "()V", null, null); setList.visitCode(); setList.visitVarInsn(Opcodes.ALOAD, 0); setList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class)); setList.visitInsn(Opcodes.DUP); setList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "", "()V", false); setList.visitFieldInsn(Opcodes.PUTFIELD, Type.getInternalName(Pill.class), "field", Type.getDescriptor(TestInterface.class)); setList.visitInsn(Opcodes.RETURN); setList.visitMaxs(0, 0); setList.visitEnd(); MethodVisitor setStaticList = cw.visitMethod(Opcodes.ACC_PUBLIC, "setStaticField", "()V", null, null); setStaticList.visitCode(); setStaticList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class)); setStaticList.visitInsn(Opcodes.DUP); setStaticList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "", "()V", false); setStaticList.visitFieldInsn(Opcodes.PUTSTATIC, Type.getInternalName(Pill.class), "staticField", Type.getDescriptor(TestInterface.class)); setStaticList.visitInsn(Opcodes.RETURN); setStaticList.visitMaxs(0, 0); setStaticList.visitEnd(); MethodVisitor callMe = cw.visitMethod(Opcodes.ACC_PUBLIC, "callMe", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(CallBack.class)), null, null); callMe.visitCode(); callMe.visitVarInsn(Opcodes.ALOAD, 1); callMe.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class)); callMe.visitInsn(Opcodes.DUP); callMe.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "", "()V", false); callMe.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(CallBack.class), "callBack", Type.getMethodDescriptor(Type.INT_TYPE, Type.getType(TestInterface.class)), true); callMe.visitInsn(Opcodes.IRETURN); callMe.visitMaxs(0, 0); callMe.visitEnd(); MethodVisitor getList = cw.visitMethod(Opcodes.ACC_PUBLIC, "get", Type.getMethodDescriptor(Type.getType(TestInterface.class)), null, null); getList.visitCode(); getList.visitTypeInsn(Opcodes.NEW, Type.getInternalName(Object.class)); getList.visitInsn(Opcodes.DUP); getList.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(Object.class), "", "()V", false); getList.visitInsn(Opcodes.ARETURN); getList.visitMaxs(0, 0); getList.visitEnd(); cw.visitEnd(); byte[] bytes = cw.toByteArray(); return defineClass(name, bytes, 0, bytes.length); } return super.findClass(name); } } }