src/java.base/share/classes/sun/security/ssl/HandshakeHash.java

Print this page

        

@@ -27,10 +27,11 @@
 package sun.security.ssl;
 
 import java.io.ByteArrayOutputStream;
 import java.security.*;
 import java.util.Locale;
+import java.nio.ByteBuffer;
 
 /**
  * Abstraction for the SSL/TLS hash of all handshake messages that is
  * maintained to verify the integrity of the negotiation. Internally,
  * it consists of an MD5 and an SHA1 digest. They are used in the client

@@ -97,19 +98,121 @@
     private final int clonesNeeded;    // needs to be saved for later use
 
     // For TLS 1.2
     private MessageDigest finMD;
 
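+    // Input-record handshake bytes set aside by reserve() without being
+    // hashed yet; reload(), which every update() calls first, feeds them
+    // into the handshake hash.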
+    private ByteArrayOutputStream reserve = new ByteArrayOutputStream();
+
     /**
      * Creates a new HandshakeHash.
      *
      * @param needCertificateVerify whether a hash for the certificate
      *        verify message is required; when true, three digest clones
      *        are requested instead of two
      */
     HandshakeHash(boolean needCertificateVerify) {
         if (needCertificateVerify) {
             clonesNeeded = 3;
         } else {
             clonesNeeded = 2;
         }
     }
 
+    /**
+     * Sets aside the remaining bytes of {@code input} in the reserve
+     * buffer without hashing them yet. The buffer's position is left
+     * unchanged on return.
+     *
+     * @param input the bytes to set aside
+     */
+    void reserve(ByteBuffer input) {
+        if (!input.hasArray()) {
+            // No accessible backing array: copy through a temporary array,
+            // then rewind so the caller still sees the original position.
+            int mark = input.position();
+            byte[] copy = new byte[input.remaining()];
+            input.get(copy);
+            input.position(mark);
+            reserve.write(copy, 0, copy.length);
+            return;
+        }
+
+        int start = input.arrayOffset() + input.position();
+        reserve.write(input.array(), start, input.remaining());
+    }
+
+    /**
+     * Sets aside {@code len} bytes of {@code b}, starting at
+     * {@code offset}, in the reserve buffer without hashing them yet.
+     */
+    void reserve(byte[] b, int offset, int len) {
+        this.reserve.write(b, offset, len);
+    }
+
+    /**
+     * Feeds any reserved bytes into the handshake hash. The reserve
+     * buffer is cleared before the bytes are handed to update(), so the
+     * reload() call that update() makes finds nothing left to flush.
+     */
+    void reload() {
+        if (reserve.size() == 0) {
+            return;
+        }
+
+        byte[] pending = reserve.toByteArray();
+        reserve.reset();
+        update(pending, 0, pending.length);
+    }
+
+    void update(ByteBuffer input) {
+
+        // reload if there are reserved messages.
+        reload();
+
+        int inPos = input.position();
+        switch (version) {
+            case 1:
+                md5.update(input);
+                input.position(inPos);
+
+                sha.update(input);
+                input.position(inPos);
+
+                break;
+            default:
+                if (finMD != null) {
+                    finMD.update(input);
+                    input.position(inPos);
+                }
+                if (input.hasArray()) {
+                    data.write(input.array(),
+                            inPos + input.arrayOffset(), input.remaining());
+                } else {
+                    byte[] holder = new byte[input.remaining()];
+                    input.get(holder);
+                    input.position(inPos);
+                    data.write(holder, 0, holder.length);
+                }
+                break;
+        }
+    }
+
+    void update(byte handshakeType, byte[] handshakeBody) {
+
+        // reload if there are reserved messages.
+        reload();
+
+        switch (version) {
+            case 1:
+                md5.update(handshakeType);
+                sha.update(handshakeType);
+
+                md5.update((byte)((handshakeBody.length >> 16) & 0xFF));
+                sha.update((byte)((handshakeBody.length >> 16) & 0xFF));
+                md5.update((byte)((handshakeBody.length >> 8) & 0xFF));
+                sha.update((byte)((handshakeBody.length >> 8) & 0xFF));
+                md5.update((byte)(handshakeBody.length & 0xFF));
+                sha.update((byte)(handshakeBody.length & 0xFF));
+
+                md5.update(handshakeBody);
+                sha.update(handshakeBody);
+                break;
+            default:
+                if (finMD != null) {
+                    finMD.update(handshakeType);
+                    finMD.update((byte)((handshakeBody.length >> 16) & 0xFF));
+                    finMD.update((byte)((handshakeBody.length >> 8) & 0xFF));
+                    finMD.update((byte)(handshakeBody.length & 0xFF));
+                    finMD.update(handshakeBody);
+                }
+                data.write(handshakeType);
+                data.write((byte)((handshakeBody.length >> 16) & 0xFF));
+                data.write((byte)((handshakeBody.length >> 8) & 0xFF));
+                data.write((byte)(handshakeBody.length & 0xFF));
+                data.write(handshakeBody, 0, handshakeBody.length);
+                break;
+        }
+    }
+
     void update(byte[] b, int offset, int len) {
+
+        // reload if there are reserved messages.
+        reload();
+
         switch (version) {
             case 1:
                 md5.update(b, offset, len);
                 sha.update(b, offset, len);
                 break;

@@ -137,13 +240,19 @@
 
 
     void protocolDetermined(ProtocolVersion pv) {
 
         // Do not set again, will ignore
-        if (version != -1) return;
+        if (version != -1) {
+            return;
+        }
 
+        if (pv.maybeDTLSProtocol()) {
+            version = pv.compareTo(ProtocolVersion.DTLS12) >= 0 ? 2 : 1;
+        } else {
         version = pv.compareTo(ProtocolVersion.TLS12) >= 0 ? 2 : 1;
+        }
         switch (version) {
             case 1:
                 // initiate md5, sha and call update on saved array
                 try {
                     md5 = CloneableDigest.getDigest("MD5", clonesNeeded);