1 /* 2 * Copyright (c) 2015, 2017, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 
*/

package jdk.incubator.http.internal.websocket;

import jdk.incubator.http.internal.common.MinimalFuture;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import jdk.incubator.http.HttpClient;
import jdk.incubator.http.HttpClient.Version;
import jdk.incubator.http.HttpHeaders;
import jdk.incubator.http.HttpRequest;
import jdk.incubator.http.HttpResponse;
import jdk.incubator.http.HttpResponse.BodyHandler;
import jdk.incubator.http.WebSocketHandshakeException;
import jdk.incubator.http.internal.common.Pair;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static jdk.incubator.http.internal.common.Utils.isValidName;
import static jdk.incubator.http.internal.common.Utils.stringOf;

/*
 * Client side of the WebSocket Opening Handshake.
 *
 * Builds an HTTP/1.1 GET Upgrade request from a BuilderImpl, sends it with
 * the builder's HttpClient and validates the server's response.
 *
 * https://tools.ietf.org/html/rfc6455#section-4.1
 */
final class OpeningHandshake {

    private static final String HEADER_CONNECTION = "Connection";
    private static final String HEADER_UPGRADE    = "Upgrade";
    private static final String HEADER_ACCEPT     = "Sec-WebSocket-Accept";
    private static final String HEADER_EXTENSIONS = "Sec-WebSocket-Extensions";
    private static final String HEADER_KEY        = "Sec-WebSocket-Key";
    private static final String HEADER_PROTOCOL   = "Sec-WebSocket-Protocol";
    private static final String HEADER_VERSION    = "Sec-WebSocket-Version";

    /*
     * Handshake headers a user must not supply through the builder.
     * Lookups are case-insensitive, because the set is backed by a
     * TreeSet ordered with String.CASE_INSENSITIVE_ORDER.
     */
    private static final Set<String> FORBIDDEN_HEADERS;

    static {
        FORBIDDEN_HEADERS = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        FORBIDDEN_HEADERS.addAll(List.of(HEADER_ACCEPT,
                                         HEADER_EXTENSIONS,
                                         HEADER_KEY,
                                         HEADER_PROTOCOL,
                                         HEADER_VERSION));
    }

    private static final SecureRandom srandom = new SecureRandom();

    private final MessageDigest sha1;
    private final HttpClient client;

    {
        try {
            sha1 = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            // Shouldn't happen: SHA-1 must be available in every Java platform
            // implementation
            throw new InternalError("Minimum requirements", e);
        }
    }

    private final HttpRequest request;
    private final Collection<String> subprotocols;
    private final String nonce;

    /*
     * Prepares (but does not send) the handshake request described by the
     * given builder.
     *
     * Throws IllegalArgumentException if the URI is not a ws/wss URI without
     * a fragment, if a forbidden header is supplied, or if a subprotocol is
     * malformed or duplicated.
     */
    OpeningHandshake(BuilderImpl b) {
        this.client = b.getClient();
        URI httpURI = createRequestURI(b.getUri());
        HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(httpURI);
        Duration connectTimeout = b.getConnectTimeout();
        if (connectTimeout != null) {
            requestBuilder.timeout(connectTimeout);
        }
        for (Pair<String, String> p : b.getHeaders()) {
            if (FORBIDDEN_HEADERS.contains(p.first)) {
                throw illegal("Illegal header: " + p.first);
            }
            requestBuilder.header(p.first, p.second);
        }
        this.subprotocols = createRequestSubprotocols(b.getSubprotocols());
        if (!this.subprotocols.isEmpty()) {
            String p = this.subprotocols.stream().collect(Collectors.joining(", "));
            requestBuilder.header(HEADER_PROTOCOL, p);
        }
        requestBuilder.header(HEADER_VERSION, "13"); // WebSocket's lucky number
        this.nonce = createNonce();
        requestBuilder.header(HEADER_KEY, this.nonce);
        // Setting request version to HTTP/1.1 forcibly, since it's not possible
        // to upgrade from HTTP/2 to WebSocket (as of August 2016):
        //
        //     https://tools.ietf.org/html/draft-hirano-httpbis-websocket-over-http2-00
        this.request = requestBuilder.version(Version.HTTP_1_1).GET().build();
        WebSocketRequest r = (WebSocketRequest) this.request;
        r.isWebSocket(true);
        r.setSystemHeader(HEADER_UPGRADE, "websocket");
        r.setSystemHeader(HEADER_CONNECTION, "Upgrade");
    }

    /*
     * Validates the requested subprotocols and returns them as an
     * unmodifiable collection that preserves the caller's order.
     *
     * Throws IllegalArgumentException for a blank or syntactically invalid
     * name, or for a name that occurs more than once.
     */
    private static Collection<String> createRequestSubprotocols(
            Collection<String> subprotocols)
    {
        LinkedHashSet<String> sp = new LinkedHashSet<>(subprotocols.size(), 1);
        for (String s : subprotocols) {
            if (s.trim().isEmpty() || !isValidName(s)) {
                throw illegal("Bad subprotocol syntax: " + s);
            }
            if (!sp.add(s)) {
                throw illegal("Duplicating subprotocol: " + s);
            }
        }
        return Collections.unmodifiableCollection(sp);
    }

    /*
     * Checks the given URI for being a WebSocket URI and translates it into a
     * target HTTP URI for the Opening Handshake.
     *
     *     https://tools.ietf.org/html/rfc6455#section-3
     */
    private static URI createRequestURI(URI uri) {
        // TODO: check permission for WebSocket URI and translate it into
        // http/https permission
        String s = uri.getScheme(); // The scheme might be null (i.e. undefined)
        if (!("ws".equalsIgnoreCase(s) || "wss".equalsIgnoreCase(s))
                || uri.getFragment() != null)
        {
            throw illegal("Bad URI: " + uri);
        }
        String scheme = "ws".equalsIgnoreCase(s) ? "http" : "https";
        try {
            // Reuse the raw (still percent-encoded) scheme-specific part.
            // Rebuilding the URI from the decoded getPath()/getQuery() values
            // with the multi-argument constructor re-quotes only characters
            // that are illegal. Escapes such as %26 in the query or %2F in
            // the path would be lost, changing the request target.
            return new URI(scheme + ":" + uri.getRawSchemeSpecificPart());
        } catch (URISyntaxException e) {
            // Shouldn't happen: the scheme-specific part was taken from an
            // already parsed URI
            throw new InternalError(e);
        }
    }

    /*
     * Sends the handshake request asynchronously, discarding any response
     * body, and completes with the validated handshake Result.
     */
    CompletableFuture<Result> send() {
        return client.sendAsync(this.request, BodyHandler.<Void>discard(null))
                .thenCompose(this::resultFrom);
    }

    /*
     * The result of the opening handshake.
     */
    static final class Result {

        // Subprotocol selected by the server, or "" if none was selected
        final String subprotocol;
        final RawChannel channel;

        private Result(String subprotocol, RawChannel channel) {
            this.subprotocol = subprotocol;
            this.channel = channel;
        }
    }

    /*
     * Turns the handshake response into a completed or failed future.
     *
     * An IOException raised while validating is propagated as is. Any other
     * exception is wrapped in a WebSocketHandshakeException carrying the
     * response. On failure the response's raw channel is closed, and an
     * IOException from that close is attached as a suppressed exception.
     */
    private CompletableFuture<Result> resultFrom(HttpResponse<?> response) {
        // Do we need a special treatment for SSLHandshakeException?
        // Namely, invoking
        //
        //     Listener.onClose(StatusCodes.TLS_HANDSHAKE_FAILURE, "")
        //
        // See https://tools.ietf.org/html/rfc6455#section-7.4.1
        Result result = null;
        Exception exception = null;
        try {
            result = handleResponse(response);
        } catch (IOException e) {
            exception = e;
        } catch (Exception e) {
            exception = new WebSocketHandshakeException(response).initCause(e);
        }
        if (exception == null) {
            return MinimalFuture.completedFuture(result);
        }
        try {
            ((RawChannel.Provider) response).rawChannel().close();
        } catch (IOException e) {
            exception.addSuppressed(e);
        }
        return MinimalFuture.failedFuture(exception);
    }

    /*
     * Validates the server's handshake response and returns the negotiated
     * subprotocol together with the response's raw channel.
     *
     * The checks follow the client requirements of
     * https://tools.ietf.org/html/rfc6455#section-4.1. A failed check is
     * reported as a CheckFailedException.
     */
    private Result handleResponse(HttpResponse<?> response) throws IOException {
        // By this point all redirects, authentications, etc. (if any) MUST have
        // been done by the HttpClient used by the WebSocket; so only 101 is
        // expected
        int c = response.statusCode();
        if (c != 101) {
            throw checkFailed("Unexpected HTTP response status code " + c);
        }
        HttpHeaders headers = response.headers();
        String upgrade = requireSingle(headers, HEADER_UPGRADE);
        if (!upgrade.equalsIgnoreCase("websocket")) {
            throw checkFailed("Bad response field: " + HEADER_UPGRADE);
        }
        // |Connection| is a token list. It only has to *contain* "Upgrade",
        // so values like "keep-alive, Upgrade" (possibly spread across
        // several fields) are acceptable.
        List<String> connection = headers.allValues(HEADER_CONNECTION);
        if (connection.isEmpty()) {
            throw checkFailed("Response field missing: " + HEADER_CONNECTION);
        }
        if (!containsToken(connection, "Upgrade")) {
            throw checkFailed("Bad response field: " + HEADER_CONNECTION);
        }
        requireAbsent(headers, HEADER_VERSION);
        requireAbsent(headers, HEADER_EXTENSIONS);
        // Expected Sec-WebSocket-Accept: base64(SHA-1(key + fixed GUID))
        String x = this.nonce + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        this.sha1.update(x.getBytes(StandardCharsets.ISO_8859_1));
        String expected = Base64.getEncoder().encodeToString(this.sha1.digest());
        String actual = requireSingle(headers, HEADER_ACCEPT);
        if (!actual.trim().equals(expected)) {
            throw checkFailed("Bad " + HEADER_ACCEPT);
        }
        String subprotocol = checkAndReturnSubprotocol(headers);
        RawChannel channel = ((RawChannel.Provider) response).rawChannel();
        return new Result(subprotocol, channel);
    }

    /*
     * Returns the subprotocol selected by the server, or "" if the response
     * has no Sec-WebSocket-Protocol field. The selected value must be a
     * single value and one of the subprotocols that were requested.
     */
    private String checkAndReturnSubprotocol(HttpHeaders responseHeaders)
            throws CheckFailedException
    {
        Optional<String> opt = responseHeaders.firstValue(HEADER_PROTOCOL);
        if (!opt.isPresent()) {
            // If there is no such header in the response, then the server
            // doesn't want to use any subprotocol
            return "";
        }
        String s = requireSingle(responseHeaders, HEADER_PROTOCOL);
        // An empty string as a subprotocol's name is not allowed by the spec
        // and the check below will detect such responses too
        if (this.subprotocols.contains(s)) {
            return s;
        } else {
            throw checkFailed("Unexpected subprotocol: " + s);
        }
    }

    /*
     * Fails the check if the response contains any value for the given field.
     */
    private static void requireAbsent(HttpHeaders responseHeaders,
                                      String headerName)
    {
        List<String> values = responseHeaders.allValues(headerName);
        if (!values.isEmpty()) {
            throw checkFailed(format("Response field '%s' present: %s",
                                     headerName,
                                     stringOf(values)));
        }
    }

    /*
     * Returns the only value of the given response field. Fails the check if
     * the field is missing or has more than one value.
     */
    private static String requireSingle(HttpHeaders responseHeaders,
                                        String headerName)
    {
        List<String> values = responseHeaders.allValues(headerName);
        if (values.isEmpty()) {
            throw checkFailed("Response field missing: " + headerName);
        } else if (values.size() > 1) {
            throw checkFailed(format("Response field '%s' multivalued: %s",
                                     headerName,
                                     stringOf(values)));
        }
        return values.get(0);
    }

    /*
     * Returns true if any of the given comma-separated field values contains
     * a token (surrounding whitespace ignored) that case-insensitively equals
     * the given token.
     */
    private static boolean containsToken(List<String> values, String token) {
        for (String value : values) {
            for (String t : value.split(",")) {
                if (t.trim().equalsIgnoreCase(token)) {
                    return true;
                }
            }
        }
        return false;
    }

    /*
     * Creates the Sec-WebSocket-Key value: 16 random bytes, base64-encoded.
     */
    private static String createNonce() {
        byte[] bytes = new byte[16];
        OpeningHandshake.srandom.nextBytes(bytes);
        return Base64.getEncoder().encodeToString(bytes);
    }

    private static IllegalArgumentException illegal(String message) {
        return new IllegalArgumentException(message);
    }

    /*
     * Creates (but does not throw) the exception reporting a failed check.
     * Callers throw it themselves, mirroring illegal(String).
     */
    private static CheckFailedException checkFailed(String message) {
        return new CheckFailedException(message);
    }
}