< prev index next >

src/java.base/share/classes/sun/security/ssl/SSLExtensions.java

Print this page

        

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2006, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -24,140 +24,337 @@
  */
 
 package sun.security.ssl;
 
 import java.io.IOException;
-import java.io.PrintStream;
+import java.nio.ByteBuffer;
+import java.text.MessageFormat;
 import java.util.*;
-import javax.net.ssl.*;
+
+import sun.security.ssl.SSLHandshake.HandshakeMessage;
+import sun.security.util.HexDumpEncoder;
 
 /**
- * This file contains all the classes relevant to TLS Extensions for the
- * ClientHello and ServerHello messages. The extension mechanism and
- * several extensions are defined in RFC 6066. Additional extensions are
- * defined in the ECC RFC 4492 and the ALPN extension is defined in RFC 7301.
- *
- * Currently, only the two ECC extensions are fully supported.
- *
- * The classes contained in this file are:
- *  . HelloExtensions: a List of extensions as used in the client hello
- *      and server hello messages.
- *  . ExtensionType: an enum style class for the extension type
- *  . HelloExtension: abstract base class for all extensions. All subclasses
- *      must be immutable.
- *
- *  . UnknownExtension: used to represent all parsed extensions that we do not
- *      explicitly support.
- *  . ServerNameExtension: the server_name extension.
- *  . SignatureAlgorithmsExtension: the signature_algorithms extension.
- *  . SupportedGroupsExtension: the supported groups extension.
- *  . EllipticPointFormatsExtension: the ECC supported point formats
- *      (compressed/uncompressed) extension.
- *  . ALPNExtension: the application_layer_protocol_negotiation extension.
- *
- * @since   1.6
- * @author  Andreas Sterbenz
+ * SSL/(D)TLS extensions in a handshake message.
  */
-final class HelloExtensions {
-
-    private List<HelloExtension> extensions;
+final class SSLExtensions {
+    private final HandshakeMessage handshakeMessage;
+    // Raw extension_data of each present extension, in insertion order.
+    private Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>();
+    // Encoded size of the extensions block, including its 2-byte length.
     private int encodedLength;

-    HelloExtensions() {
-        extensions = Collections.emptyList();
+    // Extension map for debug logging, keyed by the raw extension id.
+    // Only allocated when SSL debug logging is on (SSLLogger.isOn); every
+    // use in this class checks it for null first, so leaving it null
+    // otherwise avoids copying extension data that is never logged.
+    private final Map<Integer, byte[]> logMap =
+            SSLLogger.isOn ? new LinkedHashMap<>() : null;
+
+    /**
+     * Creates an empty set of extensions for the given handshake message.
+     * Extension values are added afterwards with produce() or reproduce().
+     */
+    SSLExtensions(HandshakeMessage handshakeMessage) {
+        this.handshakeMessage = handshakeMessage;
+        this.encodedLength = 2;         // 2: the length of the extensions.
     }
 
-    HelloExtensions(HandshakeInStream s) throws IOException {
-        int len = s.getInt16();
-        extensions = new ArrayList<HelloExtension>();
-        encodedLength = len + 2;
+    /**
+     * Parses the extensions block of a received handshake message.
+     *
+     * The block is a 2-byte total length followed by entries made of a
+     * 2-byte extension id, a 2-byte data length and the data itself.
+     * Entries whose id matches one of {@code extensions} that has an
+     * on-load consumer are kept for later consumption; other entries are
+     * skipped, or cached for debug logging only when the logging map is
+     * allocated.  An id that is consumable but not valid for this handshake
+     * type, or that matches an extension registered for a different
+     * handshake type, raises an unsupported_extension alert.
+     *
+     * @param hm the handshake message that carries the extensions
+     * @param m the buffer, positioned at the extensions length field
+     * @param extensions the extensions that may be consumed for this message
+     * @throws IOException if the extensions block is malformed
+     */
+    SSLExtensions(HandshakeMessage hm,
+            ByteBuffer m, SSLExtension[] extensions) throws IOException {
+        this.handshakeMessage = hm;
+
+        int len = Record.getInt16(m);
+        // The whole declared extensions block must be present in the buffer.
+        if (len > m.remaining()) {
+            hm.handshakeContext.conContext.fatal(Alert.DECODE_ERROR,
+                    "Insufficient extensions data");
+        }
+        encodedLength = len + 2;        // 2: the length of the extensions.
         while (len > 0) {
-            int type = s.getInt16();
-            int extlen = s.getInt16();
-            ExtensionType extType = ExtensionType.get(type);
-            HelloExtension extension;
-            if (extType == ExtensionType.EXT_SERVER_NAME) {
-                extension = new ServerNameExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_SIGNATURE_ALGORITHMS) {
-                extension = new SignatureAlgorithmsExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_SUPPORTED_GROUPS) {
-                extension = new SupportedGroupsExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_EC_POINT_FORMATS) {
-                extension = new EllipticPointFormatsExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_RENEGOTIATION_INFO) {
-                extension = new RenegotiationInfoExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_ALPN) {
-                extension = new ALPNExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_MAX_FRAGMENT_LENGTH) {
-                extension = new MaxFragmentLengthExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_STATUS_REQUEST) {
-                extension = new CertStatusReqExtension(s, extlen);
-            } else if (extType == ExtensionType.EXT_STATUS_REQUEST_V2) {
-                extension = new CertStatusReqListV2Extension(s, extlen);
-            } else if (extType == ExtensionType.EXT_EXTENDED_MASTER_SECRET) {
-                extension = new ExtendedMasterSecretExtension(s, extlen);
-            } else {
-                extension = new UnknownExtension(s, extlen, extType);
-            }
-            extensions.add(extension);
-            len -= extlen + 4;
-        }
-        if (len != 0) {
-            throw new SSLProtocolException(
-                        "Error parsing extensions: extra data");
+            int extId = Record.getInt16(m);
+            int extLen = Record.getInt16(m);
+            if (extLen > m.remaining()) {
+                hm.handshakeContext.conContext.fatal(Alert.ILLEGAL_PARAMETER,
+                        "Error parsing extension (" + extId +
+                        "): no sufficient data");
+            }
+
+            SSLHandshake handshakeType = hm.handshakeType();
+            if (SSLExtension.isConsumable(extId) &&
+                    SSLExtension.valueOf(handshakeType, extId) == null) {
+                hm.handshakeContext.conContext.fatal(
+                        Alert.UNSUPPORTED_EXTENSION,
+                        "extension (" + extId +
+                        ") should not be presented in " + handshakeType.name);
+            }
+
+            boolean isSupported = false;
+            for (SSLExtension extension : extensions) {
+                if ((extension.id != extId) ||
+                        (extension.onLoadConcumer == null)) {
+                    continue;
+                }
+
+                if (extension.handshakeType != handshakeType) {
+                    hm.handshakeContext.conContext.fatal(
+                            Alert.UNSUPPORTED_EXTENSION,
+                            "extension (" + extId + ") should not be " +
+                            "presented in " + handshakeType.name);
+                }
+
+                byte[] extData = new byte[extLen];
+                m.get(extData);
+                extMap.put(extension, extData);
+                if (logMap != null) {
+                    logMap.put(extId, extData);
+                }
+
+                isSupported = true;
+                break;
+            }
+
+            if (!isSupported) {
+                if (logMap != null) {
+                    // cache the extension for debug logging
+                    byte[] extData = new byte[extLen];
+                    m.get(extData);
+                    logMap.put(extId, extData);
+
+                    if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                        SSLLogger.fine(
+                                "Ignore unknown or unsupported extension",
+                                toString(extId, extData));
+                    }
+                } else {
+                    // ignore the extension
+                    int pos = m.position() + extLen;
+                    m.position(pos);
         }
     }

-    // Return the List of extensions. Must not be modified by the caller.
-    List<HelloExtension> list() {
-        return extensions;
-    }
-
-    void add(HelloExtension ext) {
-        if (extensions.isEmpty()) {
-            extensions = new ArrayList<HelloExtension>();
-        }
-        extensions.add(ext);
-        encodedLength = -1;
-    }
-
-    HelloExtension get(ExtensionType type) {
-        for (HelloExtension ext : extensions) {
-            if (ext.type == type) {
-                return ext;
+            len -= extLen + 4;
             }
+
+        // Each entry consumes 4 header bytes plus its data; a non-zero
+        // remainder means the last entry overran the declared block length.
+        if (len != 0) {
+            hm.handshakeContext.conContext.fatal(Alert.DECODE_ERROR,
+                    "Error parsing extensions: " +
+                    "extension data exceeds the declared length");
+        }
         }
-        return null;
+
+    /**
+     * Returns the raw data of the specified extension, or {@code null} if
+     * the extension is not present.
+     */
+    byte[] get(SSLExtension ext) {
+        return extMap.get(ext);
     }
 
-    int length() {
-        if (encodedLength >= 0) {
-            return encodedLength;
+    /**
+     * Consume the specified extensions.
+     *
+     * Each extension is handled in turn.  It is skipped if it is not
+     * available for the negotiated protocol (when one is known).  If it was
+     * not received, its absence handler runs when one is defined.  If it
+     * has no on-load consumer, it is skipped with a warning.  Otherwise its
+     * raw data is wrapped in a buffer and passed to its on-load consumer.
+     *
+     * @param context the handshake context
+     * @param extensions the extensions to consume, in the order given
+     * @throws IOException if consuming an extension fails
+     */
+    void consumeOnLoad(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (context.negotiatedProtocol != null &&
+                    !extension.isAvailable(context.negotiatedProtocol)) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                        "Ignore unsupported extension: " + extension.name);
+                }
+                continue;
+                // NOTE(review): the commented-out alert below is not in
+                // effect (the extension is skipped instead); consider
+                // deleting it.
+                // context.conContext.fatal(Alert.UNSUPPORTED_EXTENSION,
+                //         context.negotiatedProtocol + " does not support " +
+                //         extension + " extension");
+            }
+
+            if (!extMap.containsKey(extension)) {
+                if (extension.onLoadAbsence != null) {
+                    extension.absent(context, handshakeMessage);
+                } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                        "Ignore unavailable extension: " + extension.name);
+                }
+                continue;
+            }
+
+
+            if (extension.onLoadConcumer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                        "Ignore unsupported extension: " + extension.name);
+                }
+                continue;
+            }
+
+            ByteBuffer m = ByteBuffer.wrap(extMap.get(extension));
+            extension.consumeOnLoad(context, handshakeMessage, m);
+
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                SSLLogger.fine("Consumed extension: " + extension.name);
+            }
+        }
+    }
+
+    /**
+     * Consider impact of the specified extensions.
+     *
+     * For each received extension that defines an on-trade consumer, that
+     * consumer is invoked with the handshake context and message.
+     * Extensions that were not received are skipped silently; received
+     * ones without an on-trade consumer are skipped with a warning.
+     *
+     * @param context the handshake context
+     * @param extensions the extensions to consider, in the order given
+     * @throws IOException if an extension's on-trade consumer fails
+     */
+    void consumeOnTrade(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (!extMap.containsKey(extension)) {
+                // No impact could be expected, so just ignore the absence.
+                continue;
+            }
+
+            if (extension.onTradeConsumer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore impact of unsupported extension: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            extension.consumeOnTrade(context, handshakeMessage);
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                SSLLogger.fine("Populated with extension: " + extension.name);
+            }
+        }
         }
-        if (extensions.isEmpty()) {
-            encodedLength = 0;
+
+    /**
+     * Produce extension values for the specified extensions.
+     *
+     * Extensions that already have a value, or that define no producer,
+     * are skipped.  A null result from a producer means the extension does
+     * not apply in this context.  Each produced value grows the encoded
+     * length by its size plus 4 bytes (2-byte type, 2-byte data length).
+     *
+     * @param context the handshake context
+     * @param extensions the extensions to produce, in the order given
+     * @throws IOException if an extension producer fails
+     */
+    void produce(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (extMap.containsKey(extension)) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.fine(
+                            "Ignore, duplicated extension: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            if (extension.networkProducer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore, no extension producer defined: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            byte[] encoded = extension.produce(context, handshakeMessage);
+            if (encoded != null) {
+                extMap.put(extension, encoded);
+                encodedLength += encoded.length + 4; // extension_type (2)
+                                                     // extension_data length(2)
+            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                // The extension is not available in the context.
+                SSLLogger.fine(
+                        "Ignore, context unavailable extension: " +
+                        extension.name);
+            }
+        }
+    }
+
+    /**
+     * Produce extension values for the specified extensions, replacing if
+     * there is an existing extension value for a specified extension.
+     *
+     * When a value is replaced, the encoded length drops by the old value's
+     * size plus 4 and grows by the new value's size plus 4, so it stays in
+     * step with the values held in the extension map.
+     *
+     * @param context the handshake context
+     * @param extensions the extensions to (re)produce, in the order given
+     * @throws IOException if an extension producer fails
+     */
+    void reproduce(HandshakeContext context,
+            SSLExtension[] extensions) throws IOException {
+        for (SSLExtension extension : extensions) {
+            if (extension.networkProducer == null) {
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                    SSLLogger.warning(
+                            "Ignore, no extension producer defined: " +
+                            extension.name);
+                }
+                continue;
+            }
+
+            byte[] encoded = extension.produce(context, handshakeMessage);
+            if (encoded != null) {
+                if (extMap.containsKey(extension)) {
+                    byte[] old = extMap.replace(extension, encoded);
+                    if (old != null) {
+                        encodedLength -= old.length + 4;
+                    }
+                    encodedLength += encoded.length + 4;
         } else {
-            encodedLength = 2;
-            for (HelloExtension ext : extensions) {
-                encodedLength += ext.length();
+                    extMap.put(extension, encoded);
+                    encodedLength += encoded.length + 4;
+                                                    // extension_type (2)
+                                                    // extension_data length(2)
             }
+            } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
+                // The extension is not available in the context.
+                SSLLogger.fine(
+                        "Ignore, context unavailable extension: " +
+                        extension.name);
         }
+        }
+    }
+
+    /**
+     * Returns the encoded size of the extensions block, including the
+     * 2-byte extensions length field, or 0 when no extension is held.
+     *
+     * NOTE(review): for a parsed instance whose received extensions were
+     * all skipped, this returns 0 even though encodedLength counts the
+     * received block; confirm callers expect that.
+     */
+    // Note that TLS 1.3 may use empty extensions.  Please consider it while
+    // using this method.
+    int length() {
+        if (extMap.isEmpty()) {
+            return 0;
+        } else {
         return encodedLength;
     }
+    }
 
-    void send(HandshakeOutStream s) throws IOException {
-        int length = length();
-        if (length == 0) {
+    /**
+     * Writes the extensions block to the handshake output stream: a 2-byte
+     * length of the extension entries, then each present extension as its
+     * 2-byte id followed by its data written with putBytes16().
+     *
+     * Nothing is written when no extension is present.  Note that TLS 1.3
+     * may use empty extensions.  Please consider it while using this method.
+     *
+     * @param hos the handshake output stream to write to
+     * @throws IOException if writing to the stream fails
+     */
+    void send(HandshakeOutStream hos) throws IOException {
+        int extsLen = length();
+        if (extsLen == 0) {
             return;
         }
-        s.putInt16(length - 2);
-        for (HelloExtension ext : extensions) {
-            ext.send(s);
+        hos.putInt16(extsLen - 2);      // exclude the length field itself
+        // extensions must be sent in the order they appear in the enum
+        for (SSLExtension ext : SSLExtension.values()) {
+            byte[] extData = extMap.get(ext);
+            if (extData != null) {
+                hos.putInt16(ext.id);
+                hos.putBytes16(extData);
+            }
+        }
+    }
+
+    /**
+     * Returns a debug representation of the extensions.
+     *
+     * The logging map, when allocated and non-empty, holds every extension
+     * seen while parsing a received message (including unknown ones), so it
+     * is preferred.  Otherwise the extension map is rendered; this covers
+     * extensions added by produce() or reproduce(), which never populate
+     * the logging map.
+     */
+    @Override
+    public String toString() {
+        if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) {
+            return "<no extension>";
+        } else {
+            StringBuilder builder = new StringBuilder(512);
+            // An allocated but empty logging map must not hide extMap.
+            if (logMap != null && !logMap.isEmpty()) {
+                for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) {
+                    SSLExtension ext = SSLExtension.valueOf(
+                            handshakeMessage.handshakeType(), en.getKey());
+                    if (builder.length() != 0) {
+                        builder.append(",\n");
+                    }
+                    if (ext != null) {
+                        builder.append(
+                                ext.toString(ByteBuffer.wrap(en.getValue())));
+                    } else {
+                        builder.append(toString(en.getKey(), en.getValue()));
+                    }
         }
+
+                return builder.toString();
+            } else {
+                for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) {
+                    if (builder.length() != 0) {
+                        builder.append(",\n");
+                    }
+                    builder.append(
+                        en.getKey().toString(ByteBuffer.wrap(en.getValue())));
     }

-    void print(PrintStream s) throws IOException {
-        for (HelloExtension ext : extensions) {
-            s.println(ext.toString());
+                return builder.toString();
+            }
+        }
         }
+
+    /**
+     * Formats an extension that has no SSLExtension definition for this
+     * handshake type as a quoted "unknown extension (id)" entry whose body
+     * is an indented hex dump of the raw extension data.
+     */
+    private static String toString(int extId, byte[] extData) {
+        MessageFormat messageFormat = new MessageFormat(
+            "\"unknown extension ({0})\": '{'\n" +
+            "{1}\n" +
+            "'}'",
+            Locale.ENGLISH);
+
+        HexDumpEncoder hexEncoder = new HexDumpEncoder();
+        String encoded = hexEncoder.encodeBuffer(extData);
+
+        Object[] messageFields = {
+            extId,
+            Utilities.indent(encoded)
+        };
+
+        return messageFormat.format(messageFields);
     }
 }
< prev index next >