1 /*
   2  * Copyright (c) 2012, 2016, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 package org.graalvm.compiler.hotspot.replacements;
  24 
  25 import static org.graalvm.compiler.hotspot.HotSpotBackend.DECRYPT;
  26 import static org.graalvm.compiler.hotspot.HotSpotBackend.DECRYPT_WITH_ORIGINAL_KEY;
  27 import static org.graalvm.compiler.hotspot.HotSpotBackend.ENCRYPT;
  28 import static org.graalvm.compiler.hotspot.replacements.UnsafeAccess.UNSAFE;
  29 import static jdk.vm.ci.hotspot.HotSpotJVMCIRuntimeProvider.getArrayBaseOffset;
  30 
  31 import org.graalvm.compiler.api.replacements.ClassSubstitution;
  32 import org.graalvm.compiler.api.replacements.Fold;
  33 import org.graalvm.compiler.api.replacements.MethodSubstitution;
  34 import org.graalvm.compiler.core.common.LocationIdentity;
  35 import org.graalvm.compiler.core.common.spi.ForeignCallDescriptor;
  36 import org.graalvm.compiler.debug.GraalError;
  37 import org.graalvm.compiler.graph.Node.ConstantNodeParameter;
  38 import org.graalvm.compiler.graph.Node.NodeIntrinsic;
  39 import org.graalvm.compiler.hotspot.nodes.ComputeObjectAddressNode;
  40 import org.graalvm.compiler.nodes.PiNode;
  41 import org.graalvm.compiler.nodes.extended.ForeignCallNode;
  42 import org.graalvm.compiler.nodes.extended.UnsafeLoadNode;
  43 import org.graalvm.compiler.word.Pointer;
  44 import org.graalvm.compiler.word.Word;
  45 
  46 import jdk.vm.ci.meta.JavaKind;
  47 
  48 // JaCoCo Exclude
  49 
  50 /**
  51  * Substitutions for {@code com.sun.crypto.provider.CipherBlockChaining} methods.
  52  */
  53 @ClassSubstitution(className = "com.sun.crypto.provider.CipherBlockChaining", optional = true)
  54 public class CipherBlockChainingSubstitutions {
  55 
  56     private static final long embeddedCipherOffset;
  57     private static final long rOffset;
  58     private static final Class<?> cipherBlockChainingClass;
  59     private static final Class<?> feedbackCipherClass;
  60     static {
  61         try {
  62             // Need to use the system class loader as com.sun.crypto.provider.FeedbackCipher
  63             // is normally loaded by the extension class loader which is not delegated
  64             // to by the JVMCI class loader.
  65             ClassLoader cl = ClassLoader.getSystemClassLoader();
  66 
  67             feedbackCipherClass = Class.forName("com.sun.crypto.provider.FeedbackCipher", true, cl);
  68             embeddedCipherOffset = UNSAFE.objectFieldOffset(feedbackCipherClass.getDeclaredField("embeddedCipher"));
  69 
  70             cipherBlockChainingClass = Class.forName("com.sun.crypto.provider.CipherBlockChaining", true, cl);
  71             rOffset = UNSAFE.objectFieldOffset(cipherBlockChainingClass.getDeclaredField("r"));
  72         } catch (Exception ex) {
  73             throw new GraalError(ex);
  74         }
  75     }
  76 
  77     @Fold
  78     static Class<?> getAESCryptClass() {
  79         return AESCryptSubstitutions.AESCryptClass;
  80     }
  81 
  82     @MethodSubstitution(isStatic = false)
  83     static int encrypt(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
  84         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
  85         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
  86         if (getAESCryptClass().isInstance(embeddedCipher)) {
  87             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
  88             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, true, false);
  89             return inLength;
  90         } else {
  91             return encrypt(realReceiver, in, inOffset, inLength, out, outOffset);
  92         }
  93     }
  94 
  95     @MethodSubstitution(isStatic = false, value = "implEncrypt")
  96     static int implEncrypt(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
  97         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
  98         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
  99         if (getAESCryptClass().isInstance(embeddedCipher)) {
 100             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 101             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, true, false);
 102             return inLength;
 103         } else {
 104             return implEncrypt(realReceiver, in, inOffset, inLength, out, outOffset);
 105         }
 106     }
 107 
 108     @MethodSubstitution(isStatic = false)
 109     static int decrypt(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
 110         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
 111         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
 112         if (in != out && getAESCryptClass().isInstance(embeddedCipher)) {
 113             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 114             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, false, false);
 115             return inLength;
 116         } else {
 117             return decrypt(realReceiver, in, inOffset, inLength, out, outOffset);
 118         }
 119     }
 120 
 121     @MethodSubstitution(isStatic = false)
 122     static int implDecrypt(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
 123         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
 124         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
 125         if (in != out && getAESCryptClass().isInstance(embeddedCipher)) {
 126             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 127             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, false, false);
 128             return inLength;
 129         } else {
 130             return implDecrypt(realReceiver, in, inOffset, inLength, out, outOffset);
 131         }
 132     }
 133 
 134     /**
 135      * Variation for platforms (e.g. SPARC) that need do key expansion in stubs due to compatibility
 136      * issues between Java key expansion and hardware crypto instructions.
 137      */
 138     @MethodSubstitution(isStatic = false, value = "decrypt")
 139     static int decryptWithOriginalKey(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
 140         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
 141         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
 142         if (in != out && getAESCryptClass().isInstance(embeddedCipher)) {
 143             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 144             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, false, true);
 145             return inLength;
 146         } else {
 147             return decryptWithOriginalKey(realReceiver, in, inOffset, inLength, out, outOffset);
 148         }
 149     }
 150 
 151     /**
 152      * @see #decryptWithOriginalKey(Object, byte[], int, int, byte[], int)
 153      */
 154     @MethodSubstitution(isStatic = false, value = "implDecrypt")
 155     static int implDecryptWithOriginalKey(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset) {
 156         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
 157         Object embeddedCipher = UnsafeLoadNode.load(realReceiver, embeddedCipherOffset, JavaKind.Object, LocationIdentity.any());
 158         if (in != out && getAESCryptClass().isInstance(embeddedCipher)) {
 159             Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 160             crypt(realReceiver, in, inOffset, inLength, out, outOffset, aesCipher, false, true);
 161             return inLength;
 162         } else {
 163             return implDecryptWithOriginalKey(realReceiver, in, inOffset, inLength, out, outOffset);
 164         }
 165     }
 166 
 167     private static void crypt(Object rcvr, byte[] in, int inOffset, int inLength, byte[] out, int outOffset, Object embeddedCipher, boolean encrypt, boolean withOriginalKey) {
 168         AESCryptSubstitutions.checkArgs(in, inOffset, out, outOffset);
 169         Object realReceiver = PiNode.piCastNonNull(rcvr, cipherBlockChainingClass);
 170         Object aesCipher = getAESCryptClass().cast(embeddedCipher);
 171         Object kObject = UnsafeLoadNode.load(aesCipher, AESCryptSubstitutions.kOffset, JavaKind.Object, LocationIdentity.any());
 172         Object rObject = UnsafeLoadNode.load(realReceiver, rOffset, JavaKind.Object, LocationIdentity.any());
 173         Pointer kAddr = Word.objectToTrackedPointer(kObject).add(getArrayBaseOffset(JavaKind.Int));
 174         Pointer rAddr = Word.objectToTrackedPointer(rObject).add(getArrayBaseOffset(JavaKind.Byte));
 175         Word inAddr = Word.unsigned(ComputeObjectAddressNode.get(in, getArrayBaseOffset(JavaKind.Byte) + inOffset));
 176         Word outAddr = Word.unsigned(ComputeObjectAddressNode.get(out, getArrayBaseOffset(JavaKind.Byte) + outOffset));
 177         if (encrypt) {
 178             encryptAESCryptStub(ENCRYPT, inAddr, outAddr, kAddr, rAddr, inLength);
 179         } else {
 180             if (withOriginalKey) {
 181                 Object lastKeyObject = UnsafeLoadNode.load(aesCipher, AESCryptSubstitutions.lastKeyOffset, JavaKind.Object, LocationIdentity.any());
 182                 Pointer lastKeyAddr = Word.objectToTrackedPointer(lastKeyObject).add(getArrayBaseOffset(JavaKind.Byte));
 183                 decryptAESCryptWithOriginalKeyStub(DECRYPT_WITH_ORIGINAL_KEY, inAddr, outAddr, kAddr, rAddr, inLength, lastKeyAddr);
 184             } else {
 185                 decryptAESCryptStub(DECRYPT, inAddr, outAddr, kAddr, rAddr, inLength);
 186             }
 187         }
 188     }
 189 
 190     @NodeIntrinsic(ForeignCallNode.class)
 191     public static native void encryptAESCryptStub(@ConstantNodeParameter ForeignCallDescriptor descriptor, Word in, Word out, Pointer key, Pointer r, int inLength);
 192 
 193     @NodeIntrinsic(ForeignCallNode.class)
 194     public static native void decryptAESCryptStub(@ConstantNodeParameter ForeignCallDescriptor descriptor, Word in, Word out, Pointer key, Pointer r, int inLength);
 195 
 196     @NodeIntrinsic(ForeignCallNode.class)
 197     public static native void decryptAESCryptWithOriginalKeyStub(@ConstantNodeParameter ForeignCallDescriptor descriptor, Word in, Word out, Pointer key, Pointer r, int inLength, Pointer originalKey);
 198 }