/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLServerSocketFactory;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.SocketAddress;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server, which supports TLS.
 * By default the dummy webserver uses a plain TCP connection,
 * but it uses a TLS connection if {@link #secure()} is called before
 * {@link #open()}. In that case the default SSL context is used.
 *
 * <p>Performs a simplified version of the WebSocket Opening Handshake over
 * HTTP (i.e. no proxying, cookies, etc.). Supports sequential connections,
 * one at a time, i.e. in order for a client to connect to the server the
 * previous client must disconnect first.
 *
 * <p>The default mapping rejects requests that ask for subprotocols
 * ({@code Sec-WebSocket-Protocol}) or extensions
 * ({@code Sec-WebSocket-Extensions}). When the server is created with a
 * username, the request must also carry matching Basic credentials,
 * otherwise a {@code 401 Unauthorized} response is sent and the server
 * waits for another request on the same connection.
 *
 * <p>Expected client request:
 * <pre>
 *   GET /chat HTTP/1.1
 *   Host: server.example.com
 *   Upgrade: websocket
 *   Connection: Upgrade
 *   Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *   Origin: http://example.com
 *   Sec-WebSocket-Version: 13
 * </pre>
 *
 * <p>This server response:
 * <pre>
 *   HTTP/1.1 101 Switching Protocols
 *   Connection: Upgrade
 *   Upgrade: websocket
 *   Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * </pre>
 */
public class DummySecureWebSocketServer implements Closeable {

    /**
     * Emulates some of the SocketChannel APIs over a Socket
     * instance. Every operation is delegated to one of the
     * functional hooks supplied at construction time.
     */
    public static class WebSocketChannel implements AutoCloseable {
        interface Reader {
            int read(ByteBuffer buf) throws IOException;
        }
        interface Writer {
            void write(ByteBuffer buf) throws IOException;
        }
        interface Config {
            <T> void setOption(SocketOption<T> option, T value) throws IOException;
        }
        interface Closer {
            void close() throws IOException;
        }

        final AutoCloseable channel;
        final Reader reader;
        final Writer writer;
        final Config config;
        final Closer closer;

        WebSocketChannel(AutoCloseable channel, Reader reader, Writer writer,
                         Config config, Closer closer) {
            this.channel = channel;
            this.reader = reader;
            this.writer = writer;
            this.config = config;
            this.closer = closer;
        }

        @Override
        public void close() throws IOException {
            closer.close();
        }

        @Override
        public String toString() {
            return channel.toString();
        }

        public int read(ByteBuffer bb) throws IOException {
            return reader.read(bb);
        }

        public void write(ByteBuffer bb) throws IOException {
            writer.write(bb);
        }

        public <T> void setOption(SocketOption<T> option, T value) throws IOException {
            config.setOption(option, value);
        }

        /** Wraps a connected {@link Socket} (plain or SSL). */
        public static WebSocketChannel of(Socket s) {
            Reader reader = (bb) -> DummySecureWebSocketServer.read(s.getInputStream(), bb);
            Writer writer = (bb) -> DummySecureWebSocketServer.write(s.getOutputStream(), bb);
            return new WebSocketChannel(s, reader, writer, s::setOption, s::close);
        }
    }

    /**
     * Emulates some of the ServerSocketChannel APIs over a ServerSocket
     * instance. Every operation is delegated to one of the functional
     * hooks supplied at construction time.
     */
    public static class WebServerSocketChannel implements AutoCloseable {
        interface Accepter {
            WebSocketChannel accept() throws IOException;
        }
        interface Binder {
            void bind(SocketAddress address) throws IOException;
        }
        interface Config {
            <T> void setOption(SocketOption<T> option, T value) throws IOException;
        }
        interface Closer {
            void close() throws IOException;
        }
        interface Addressable {
            SocketAddress getLocalAddress() throws IOException;
        }

        final AutoCloseable server;
        final Accepter accepter;
        final Binder binder;
        final Addressable address;
        final Config config;
        final Closer closer;

        WebServerSocketChannel(AutoCloseable server,
                               Accepter accepter,
                               Binder binder,
                               Addressable address,
                               Config config,
                               Closer closer) {
            this.server = server;
            this.accepter = accepter;
            this.binder = binder;
            this.address = address;
            this.config = config;
            this.closer = closer;
        }

        @Override
        public void close() throws IOException {
            closer.close();
        }

        @Override
        public String toString() {
            return server.toString();
        }

        public WebSocketChannel accept() throws IOException {
            return accepter.accept();
        }

        public void bind(SocketAddress address) throws IOException {
            binder.bind(address);
        }

        public <T> void setOption(SocketOption<T> option, T value) throws IOException {
            config.setOption(option, value);
        }

        public SocketAddress getLocalAddress() throws IOException {
            return address.getLocalAddress();
        }

        /** Wraps an (unbound) {@link ServerSocket} (plain or SSL). */
        public static WebServerSocketChannel of(ServerSocket ss) {
            Accepter a = () -> WebSocketChannel.of(ss.accept());
            return new WebServerSocketChannel(ss, a, ss::bind, ss::getLocalSocketAddress,
                                              ss::setOption, ss::close);
        }
    }

    // Creates a secure WebServerSocketChannel
    static WebServerSocketChannel openWSS() throws IOException {
        return WebServerSocketChannel.of(SSLServerSocketFactory.getDefault().createServerSocket());
    }

    // Creates a plain WebServerSocketChannel
    static WebServerSocketChannel openWS() throws IOException {
        return WebServerSocketChannel.of(ServerSocketFactory.getDefault().createServerSocket());
    }

    /**
     * Performs a single read of at most 1024 bytes from {@code str} into
     * {@code buffer}.
     *
     * @param str    the stream to read from
     * @param buffer the destination; its position is advanced by the
     *               number of bytes transferred
     * @return the number of bytes transferred; {@code 0} if {@code buffer}
     *         has no space remaining (the stream is not touched); otherwise
     *         the non-positive value returned by the stream, e.g. {@code -1}
     *         at end of stream
     * @throws IOException if reading from the stream fails
     */
    static int read(InputStream str, ByteBuffer buffer) throws IOException {
        int len = Math.min(buffer.remaining(), 1024);
        if (len <= 0) return 0;
        byte[] bytes = new byte[len];
        int n = str.read(bytes, 0, len);
        if (n > 0) {
            buffer.put(bytes, 0, n);
        }
        return n;
    }

    /**
     * Writes all remaining bytes of {@code buffer} to {@code str}, in chunks
     * of at most 1024 bytes. On return {@code buffer} has nothing remaining.
     *
     * @param str    the stream to write to
     * @param buffer the source bytes (from position to limit)
     * @throws IOException if writing to the stream fails
     */
    static void write(OutputStream str, ByteBuffer buffer) throws IOException {
        int chunk = Math.min(buffer.remaining(), 1024);
        if (chunk <= 0) return;
        byte[] bytes = new byte[chunk];
        while (buffer.hasRemaining()) {
            int len = Math.min(chunk, buffer.remaining());
            buffer.get(bytes, 0, len);
            str.write(bytes, 0, len);
        }
    }

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile WebServerSocketChannel ss;
    // only assigned by open(), after 'started' has been set
    private volatile InetSocketAddress address;
    private volatile boolean secure;
    // accumulates the bytes read by read(WebSocketChannel); grown on demand
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // counted down when a served connection is closed; count of 1, so
    // only the first closed connection releases read()
    private final CountDownLatch readReady = new CountDownLatch(1);
    private volatile boolean done;

    /** Username/password pair expected in a Basic Authorization header. */
    private static class Credentials {
        private final String name;
        private final String password;
        private Credentials(String name, String password) {
            this.name = name;
            this.password = password;
        }
        public String name() { return name; }
        public String password() { return password; }
    }

    /** Creates a server using the default handshake, without authentication. */
    public DummySecureWebSocketServer() {
        this(defaultMapping(), null, null);
    }

    /**
     * Creates a server using the default handshake, which requires Basic
     * authentication with the given credentials.
     */
    public DummySecureWebSocketServer(String username, String password) {
        this(defaultMapping(), username, password);
    }

    /**
     * Creates a server with a custom handshake mapping.
     *
     * @param mapping  maps the request lines (status line first) and the
     *                 expected credentials ({@code null} when
     *                 {@code username} is {@code null}) to the response lines.
     *                 A response whose first line starts with
     *                 {@code "HTTP/1.1 401"} makes the server read another
     *                 request on the same connection; any other response
     *                 starts serving the connection.
     * @param username the expected username, or {@code null} for none
     * @param password the expected password
     * @throws NullPointerException if {@code mapping} is {@code null}
     */
    public DummySecureWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
                                      String username,
                                      String password) {
        requireNonNull(mapping);
        Credentials credentials = username != null ?
                new Credentials(username, password) : null;

        thread = new Thread(() -> acceptLoop(mapping, credentials));
        thread.setName("DummySecureWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Body of the server thread: accepts connections one at a time,
     * performs the handshake and serves each connection until it ends or
     * the server is closed.
     */
    private void acceptLoop(BiFunction<List<String>,Credentials,List<String>> mapping,
                            Credentials credentials) {
        try {
            while (!Thread.currentThread().isInterrupted() && !done) {
                err.println("Accepting next connection at: " + ss);
                WebSocketChannel channel = ss.accept();
                err.println("Accepted: " + channel);
                try {
                    channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                    while (!done) {
                        StringBuilder request = new StringBuilder();
                        if (!readRequest(channel, request)) {
                            throw new IOException("Bad request:[" + request + "]");
                        }
                        List<String> strings = asList(request.toString().split("\r\n"));
                        List<String> response = mapping.apply(strings, credentials);
                        writeResponse(channel, response);

                        if (response.get(0).startsWith("HTTP/1.1 401")) {
                            // the client is expected to retry with credentials
                            err.println("Sent 401 Authentication response " + channel);
                            continue;
                        }
                        serve(channel);
                        break;
                    }
                } catch (IOException e) {
                    if (!done) {
                        err.println("Error in connection: " + channel + ", " + e);
                    }
                } finally {
                    err.println("Closed: " + channel);
                    close(channel);
                    readReady.countDown();
                }
            }
        } catch (ClosedByInterruptException ignored) {
        } catch (Throwable e) {
            // failures after close() are expected and not reported
            if (!done) {
                e.printStackTrace(err);
            }
        } finally {
            done = true;
            close(ss);
            err.println("Stopped at: " + getURI());
        }
    }

    /**
     * Makes the server use TLS. Must be called before {@link #open()}.
     *
     * @return this server
     */
    public DummySecureWebSocketServer secure() {
        secure = true;
        return this;
    }

    /**
     * Reads from {@code ch} until end of stream (or until an
     * {@code IOException}, e.g. because the channel was closed), appending
     * everything to the buffer later returned by {@link #read()}.
     * Subclasses may override this to handle client input differently.
     *
     * @param ch the connection being served
     * @throws IOException if reading fails
     */
    protected void read(WebSocketChannel ch) throws IOException {
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // grow to the next power of two that holds the bytes stored
                // so far (capacity - remaining == position) plus the new ones
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /**
     * Hook for subclasses that want to send data to the client while the
     * connection is served. The default implementation does nothing.
     *
     * @param ch the connection being served
     * @throws IOException if writing fails
     */
    protected void write(WebSocketChannel ch) throws IOException { }

    /**
     * Serves an upgraded connection: runs {@link #read(WebSocketChannel)}
     * and {@link #write(WebSocketChannel)} on two new threads and waits
     * until both have finished or the server has been closed.
     *
     * @param channel the connection to serve
     * @throws InterruptedException declared for callers; interrupts received
     *         while waiting are only acted upon once the server is closing
     */
    protected final void serve(WebSocketChannel channel)
            throws InterruptedException
    {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) { }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) { }
        });
        reader.start();
        writer.start();
        try {
            // Poll 'done' so that close() can end the wait. Checking
            // isAlive() stops the loop once the reader has terminated;
            // joining a dead thread returns immediately, so without it
            // this loop would spin until the server is closed.
            while (!done && reader.isAlive()) {
                try {
                    reader.join(500);
                } catch (InterruptedException x) {
                    if (done) {
                        close(channel);
                        break;
                    }
                }
            }
        } finally {
            reader.interrupt();
            try {
                while (!done && writer.isAlive()) {
                    try {
                        writer.join(500);
                    } catch (InterruptedException x) {
                        if (done) break;
                    }
                }
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until a served connection has been closed, then returns a
     * read-only view of all the bytes read from clients so far.
     *
     * @return a read-only buffer positioned at 0 with the received bytes
     * @throws InterruptedException if interrupted while waiting
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /**
     * Binds the server to an ephemeral port on the loopback address and
     * starts the server thread.
     *
     * @throws IllegalStateException if the server was already started
     * @throws IOException if the server socket cannot be created or bound
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ss = secure ? openWSS() : openWS();
        try {
            ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ss.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            done = true;
            close(ss);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server. Safe to call even if {@link #open()} was never
     * called or failed; in that case there is no URI to report.
     */
    @Override
    public void close() {
        InetSocketAddress bound = address;
        err.println("Stopping: " + (bound == null ? "(not started)" : getURI()));
        done = true;
        thread.interrupt();
        close(ss);
    }

    /**
     * Returns the {@code ws://} or {@code wss://} URI of this server.
     *
     * @throws IllegalStateException if the server has not been started
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        if (!secure) {
            return URI.create("ws://localhost:" + address.getPort());
        } else {
            return URI.create("wss://localhost:" + address.getPort());
        }
    }

    /**
     * Reads the HTTP request headers (up to and including the empty line)
     * from {@code channel} into {@code request}. There should be no body.
     *
     * @return {@code true} if the end of the headers was seen,
     *         {@code false} if the stream ended first
     */
    private boolean readRequest(WebSocketChannel channel, StringBuilder request)
            throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            buffer.flip();
            CharBuffer decoded;
            try {
                decoded = ISO_8859_1.newDecoder().decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            request.append(decoded);
            if (request.indexOf("\r\n\r\n") >= 0) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes the response lines, CRLF-separated and terminated by an empty
     * line, encoded as ISO-8859-1.
     */
    private void writeResponse(WebSocketChannel channel, List<String> response)
            throws IOException
    {
        String s = String.join("\r\n", response) + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * Returns the default handshake mapping. It validates the opening
     * handshake request and answers with {@code 101 Switching Protocols},
     * or with {@code 401 Unauthorized} when credentials are required and
     * the request does not carry matching ones.
     *
     * @throws IllegalStateException (from the returned function) on a
     *         malformed or unsupported request
     */
    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
        return (request, credentials) -> {
            List<String> response = new ArrayList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, List<String>> requestHeaders = new HashMap<>();
            while (iterator.hasNext()) {
                String header = iterator.next();
                // limit 2: a header value may itself contain ": "
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                    + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            // Sec-WebSocket-Accept = base64(SHA-1(key + RFC 6455 GUID))
            String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);

            // check authorization credentials, if required by the server
            if (credentials != null && !authorized(credentials, requestHeaders)) {
                response.clear();
                response.add("HTTP/1.1 401 Unauthorized");
                response.add("Content-Length: 0");
                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
            }

            return response;
        };
    }

    /**
     * Checks the Basic credentials in the request against those allowed by
     * the server.
     *
     * @return {@code false} if there is no Authorization header or the
     *         credentials do not match, {@code true} if they match
     * @throws IllegalStateException if the Authorization header is repeated,
     *         is not Basic, or has no {@code name:password} separator
     */
    private static boolean authorized(Credentials credentials,
                                      Map<String,List<String>> requestHeaders) {
        List<String> authorization = requestHeaders.get("Authorization");
        if (authorization == null) {
            return false;
        }
        if (authorization.size() != 1) {
            throw new IllegalStateException("Authorization unexpected count:" + authorization);
        }
        String header = authorization.get(0);
        if (!header.startsWith("Basic ")) {
            throw new IllegalStateException("Authorization not Basic: " + header);
        }
        header = header.substring("Basic ".length());
        String values = new String(Base64.getDecoder().decode(header), UTF_8);
        int sep = values.indexOf(':');
        if (sep < 1) {
            throw new IllegalStateException("Authorization not colon: " + values);
        }
        String name = values.substring(0, sep);
        String password = values.substring(sep + 1);
        return name.equals(credentials.name()) && password.equals(credentials.password());
    }

    /**
     * Asserts that header {@code name} is present and that one of its values
     * equals {@code value}.
     *
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null) {
            throw new IllegalStateException(
                    format("Expected '%s' header, not present in %s",
                           name, headers));
        }
        if (!v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Closes each non-null argument, ignoring any exception (best effort). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue;
            }
            try {
                ac.close();
            } catch (Exception ignored) { }
        }
    }
}