1 /*
   2  * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package sun.security.ssl;
  27 
  28 import java.io.IOException;
  29 import java.nio.ByteBuffer;
  30 import java.text.MessageFormat;
  31 import java.util.*;
  32 
  33 import sun.security.ssl.SSLHandshake.HandshakeMessage;
  34 import sun.security.util.HexDumpEncoder;
  35 
  36 /**
  37  * SSL/(D)TLS extensions in a handshake message.
  38  */
  39 final class SSLExtensions {
  40     private final HandshakeMessage handshakeMessage;
  41     private Map<SSLExtension, byte[]> extMap = new LinkedHashMap<>();
  42     private int encodedLength;
  43 
  44     // Extension map for debug logging
  45     private final Map<Integer, byte[]> logMap =
  46             SSLLogger.isOn ? new LinkedHashMap<>() : null;
  47 
  48     SSLExtensions(HandshakeMessage handshakeMessage) {
  49         this.handshakeMessage = handshakeMessage;
  50         this.encodedLength = 2;         // 2: the length of the extensions.
  51     }
  52 
  53     SSLExtensions(HandshakeMessage hm,
  54             ByteBuffer m, SSLExtension[] extensions) throws IOException {
  55         this.handshakeMessage = hm;
  56 
  57         int len = Record.getInt16(m);
  58         encodedLength = len + 2;        // 2: the length of the extensions.
  59         while (len > 0) {
  60             int extId = Record.getInt16(m);
  61             int extLen = Record.getInt16(m);
  62             if (extLen > m.remaining()) {
  63                 throw hm.handshakeContext.conContext.fatal(
  64                         Alert.ILLEGAL_PARAMETER,
  65                         "Error parsing extension (" + extId +
  66                         "): no sufficient data");
  67             }
  68 
  69             boolean isSupported = true;
  70             SSLHandshake handshakeType = hm.handshakeType();
  71             if (SSLExtension.isConsumable(extId) &&
  72                     SSLExtension.valueOf(handshakeType, extId) == null) {
  73                 if (extId == SSLExtension.CH_SUPPORTED_GROUPS.id &&
  74                         handshakeType == SSLHandshake.SERVER_HELLO) {
  75                     // Note: It does not comply to the specification.  However,
  76                     // there are servers that send the supported_groups
  77                     // extension in ServerHello handshake message.
  78                     //
  79                     // TLS 1.3 should not send this extension.   We may want to
  80                     // limit the workaround for TLS 1.2 and prior version only.
  81                     // However, the implementation of the limit is complicated
  82                     // and inefficient, and may not worthy the maintenance.
  83                     isSupported = false;
  84                     if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
  85                         SSLLogger.warning(
  86                                 "Received buggy supported_groups extension " +
  87                                 "in the ServerHello handshake message");
  88                     }
  89                 } else if (handshakeType == SSLHandshake.SERVER_HELLO) {
  90                     throw hm.handshakeContext.conContext.fatal(
  91                             Alert.UNSUPPORTED_EXTENSION, "extension (" +
  92                                     extId + ") should not be presented in " +
  93                                     handshakeType.name);
  94                 } else {
  95                     isSupported = false;
  96                     // debug log to ignore unknown extension for handshakeType
  97                 }
  98             }
  99 
 100             if (isSupported) {
 101                 isSupported = false;
 102                 for (SSLExtension extension : extensions) {
 103                     if ((extension.id != extId) ||
 104                             (extension.onLoadConsumer == null)) {
 105                         continue;
 106                     }
 107 
 108                     if (extension.handshakeType != handshakeType) {
 109                         throw hm.handshakeContext.conContext.fatal(
 110                                 Alert.UNSUPPORTED_EXTENSION,
 111                                 "extension (" + extId + ") should not be " +
 112                                 "presented in " + handshakeType.name);
 113                     }
 114 
 115                     byte[] extData = new byte[extLen];
 116                     m.get(extData);
 117                     extMap.put(extension, extData);
 118                     if (logMap != null) {
 119                         logMap.put(extId, extData);
 120                     }
 121 
 122                     isSupported = true;
 123                     break;
 124                 }
 125             }
 126 
 127             if (!isSupported) {
 128                 if (logMap != null) {
 129                     // cache the extension for debug logging
 130                     byte[] extData = new byte[extLen];
 131                     m.get(extData);
 132                     logMap.put(extId, extData);
 133 
 134                     if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 135                         SSLLogger.fine(
 136                                 "Ignore unknown or unsupported extension",
 137                                 toString(extId, extData));
 138                     }
 139                 } else {
 140                     // ignore the extension
 141                     int pos = m.position() + extLen;
 142                     m.position(pos);
 143                 }
 144             }
 145 
 146             len -= extLen + 4;
 147         }
 148     }
 149 
 150     byte[] get(SSLExtension ext) {
 151         return extMap.get(ext);
 152     }
 153 
 154     /**
 155      * Consume the specified extensions.
 156      */
 157     void consumeOnLoad(HandshakeContext context,
 158             SSLExtension[] extensions) throws IOException {
 159         for (SSLExtension extension : extensions) {
 160             if (context.negotiatedProtocol != null &&
 161                     !extension.isAvailable(context.negotiatedProtocol)) {
 162                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 163                     SSLLogger.fine(
 164                         "Ignore unsupported extension: " + extension.name);
 165                 }
 166                 continue;
 167             }
 168 
 169             if (!extMap.containsKey(extension)) {
 170                 if (extension.onLoadAbsence != null) {
 171                     extension.absentOnLoad(context, handshakeMessage);
 172                 } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 173                     SSLLogger.fine(
 174                         "Ignore unavailable extension: " + extension.name);
 175                 }
 176                 continue;
 177             }
 178 
 179 
 180             if (extension.onLoadConsumer == null) {
 181                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 182                     SSLLogger.warning(
 183                         "Ignore unsupported extension: " + extension.name);
 184                 }
 185                 continue;
 186             }
 187 
 188             ByteBuffer m = ByteBuffer.wrap(extMap.get(extension));
 189             extension.consumeOnLoad(context, handshakeMessage, m);
 190 
 191             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 192                 SSLLogger.fine("Consumed extension: " + extension.name);
 193             }
 194         }
 195     }
 196 
 197     /**
 198      * Consider impact of the specified extensions.
 199      */
 200     void consumeOnTrade(HandshakeContext context,
 201             SSLExtension[] extensions) throws IOException {
 202         for (SSLExtension extension : extensions) {
 203             if (!extMap.containsKey(extension)) {
 204                 if (extension.onTradeAbsence != null) {
 205                     extension.absentOnTrade(context, handshakeMessage);
 206                 } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 207                     SSLLogger.fine(
 208                         "Ignore unavailable extension: " + extension.name);
 209                 }
 210                 continue;
 211             }
 212 
 213             if (extension.onTradeConsumer == null) {
 214                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 215                     SSLLogger.warning(
 216                             "Ignore impact of unsupported extension: " +
 217                             extension.name);
 218                 }
 219                 continue;
 220             }
 221 
 222             extension.consumeOnTrade(context, handshakeMessage);
 223             if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 224                 SSLLogger.fine("Populated with extension: " + extension.name);
 225             }
 226         }
 227     }
 228 
 229     /**
 230      * Produce extension values for the specified extensions.
 231      */
 232     void produce(HandshakeContext context,
 233             SSLExtension[] extensions) throws IOException {
 234         for (SSLExtension extension : extensions) {
 235             if (extMap.containsKey(extension)) {
 236                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 237                     SSLLogger.fine(
 238                             "Ignore, duplicated extension: " +
 239                             extension.name);
 240                 }
 241                 continue;
 242             }
 243 
 244             if (extension.networkProducer == null) {
 245                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 246                     SSLLogger.warning(
 247                             "Ignore, no extension producer defined: " +
 248                             extension.name);
 249                 }
 250                 continue;
 251             }
 252 
 253             byte[] encoded = extension.produce(context, handshakeMessage);
 254             if (encoded != null) {
 255                 extMap.put(extension, encoded);
 256                 encodedLength += encoded.length + 4; // extension_type (2)
 257                                                      // extension_data length(2)
 258             } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 259                 // The extension is not available in the context.
 260                 SSLLogger.fine(
 261                         "Ignore, context unavailable extension: " +
 262                         extension.name);
 263             }
 264         }
 265     }
 266 
 267     /**
 268      * Produce extension values for the specified extensions, replacing if
 269      * there is an existing extension value for a specified extension.
 270      */
 271     void reproduce(HandshakeContext context,
 272             SSLExtension[] extensions) throws IOException {
 273         for (SSLExtension extension : extensions) {
 274             if (extension.networkProducer == null) {
 275                 if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 276                     SSLLogger.warning(
 277                             "Ignore, no extension producer defined: " +
 278                             extension.name);
 279                 }
 280                 continue;
 281             }
 282 
 283             byte[] encoded = extension.produce(context, handshakeMessage);
 284             if (encoded != null) {
 285                 if (extMap.containsKey(extension)) {
 286                     byte[] old = extMap.replace(extension, encoded);
 287                     if (old != null) {
 288                         encodedLength -= old.length + 4;
 289                     }
 290                     encodedLength += encoded.length + 4;
 291                 } else {
 292                     extMap.put(extension, encoded);
 293                     encodedLength += encoded.length + 4;
 294                                                     // extension_type (2)
 295                                                     // extension_data length(2)
 296                 }
 297             } else if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
 298                 // The extension is not available in the context.
 299                 SSLLogger.fine(
 300                         "Ignore, context unavailable extension: " +
 301                         extension.name);
 302             }
 303         }
 304     }
 305 
 306     // Note that TLS 1.3 may use empty extensions.  Please consider it while
 307     // using this method.
 308     int length() {
 309         if (extMap.isEmpty()) {
 310             return 0;
 311         } else {
 312             return encodedLength;
 313         }
 314     }
 315 
 316     // Note that TLS 1.3 may use empty extensions.  Please consider it while
 317     // using this method.
 318     void send(HandshakeOutStream hos) throws IOException {
 319         int extsLen = length();
 320         if (extsLen == 0) {
 321             return;
 322         }
 323         hos.putInt16(extsLen - 2);
 324         // extensions must be sent in the order they appear in the enum
 325         for (SSLExtension ext : SSLExtension.values()) {
 326             byte[] extData = extMap.get(ext);
 327             if (extData != null) {
 328                 hos.putInt16(ext.id);
 329                 hos.putBytes16(extData);
 330             }
 331         }
 332     }
 333 
 334     @Override
 335     public String toString() {
 336         if (extMap.isEmpty() && (logMap == null || logMap.isEmpty())) {
 337             return "<no extension>";
 338         } else {
 339             StringBuilder builder = new StringBuilder(512);
 340             if (logMap != null && !logMap.isEmpty()) {
 341                 for (Map.Entry<Integer, byte[]> en : logMap.entrySet()) {
 342                     SSLExtension ext = SSLExtension.valueOf(
 343                             handshakeMessage.handshakeType(), en.getKey());
 344                     if (builder.length() != 0) {
 345                         builder.append(",\n");
 346                     }
 347                     if (ext != null) {
 348                         builder.append(
 349                                 ext.toString(ByteBuffer.wrap(en.getValue())));
 350                     } else {
 351                         builder.append(toString(en.getKey(), en.getValue()));
 352                     }
 353                 }
 354 
 355                 return builder.toString();
 356             } else {
 357                 for (Map.Entry<SSLExtension, byte[]> en : extMap.entrySet()) {
 358                     if (builder.length() != 0) {
 359                         builder.append(",\n");
 360                     }
 361                     builder.append(
 362                         en.getKey().toString(ByteBuffer.wrap(en.getValue())));
 363                 }
 364 
 365                 return builder.toString();
 366             }
 367         }
 368     }
 369 
 370     private static String toString(int extId, byte[] extData) {
 371         String extName = SSLExtension.nameOf(extId);
 372         MessageFormat messageFormat = new MessageFormat(
 373             "\"{0} ({1})\": '{'\n" +
 374             "{2}\n" +
 375             "'}'",
 376             Locale.ENGLISH);
 377 
 378         HexDumpEncoder hexEncoder = new HexDumpEncoder();
 379         String encoded = hexEncoder.encodeBuffer(extData);
 380 
 381         Object[] messageFields = {
 382             extName,
 383             extId,
 384             Utilities.indent(encoded)
 385         };
 386 
 387         return messageFormat.format(messageFields);
 388     }
 389 }