1 /*
   2  * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
  59 
/**
 * Dummy WebSocket Server.
 *
 * Performs a simplified version of the WebSocket Opening Handshake over HTTP
 * (i.e. no proxying, cookies, etc.). Supports sequential connections, one at a
 * time, i.e. in order for a client to connect to the server the previous
 * client must disconnect first.
 *
 * Expected client request:
 *
 *     GET /chat HTTP/1.1
 *     Host: server.example.com
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *     Origin: http://example.com
 *     Sec-WebSocket-Protocol: chat, superchat
 *     Sec-WebSocket-Version: 13
 *
 * This server's response:
 *
 *     HTTP/1.1 101 Switching Protocols
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 *     Sec-WebSocket-Protocol: chat
 *
 * Note: the exchange above is the sample handshake from RFC 6455. The default
 * request mapping is stricter: it rejects requests that carry
 * Sec-WebSocket-Protocol or Sec-WebSocket-Extensions, and it never sends
 * Sec-WebSocket-Protocol back. When the server is created with a username and
 * password, a request that lacks an Authorization header or carries
 * non-matching Basic credentials is answered with "401 Unauthorized", and the
 * client may retry on the same connection.
 */
  87 public class DummyWebSocketServer implements Closeable {
  88 
  89     private final AtomicBoolean started = new AtomicBoolean();
  90     private final Thread thread;
  91     private volatile ServerSocketChannel ssc;
  92     private volatile InetSocketAddress address;
  93     private ByteBuffer read = ByteBuffer.allocate(16384);
  94     private final CountDownLatch readReady = new CountDownLatch(1);
  95 
  96     private static class Credentials {
  97         private final String name;
  98         private final String password;
  99         private Credentials(String name, String password) {
 100             this.name = name;
 101             this.password = password;
 102         }
 103         public String name() { return name; }
 104         public String password() { return password; }
 105     }
 106 
 107     public DummyWebSocketServer() {
 108         this(defaultMapping(), null, null);
 109     }
 110 
 111     public DummyWebSocketServer(String username, String password) {
 112         this(defaultMapping(), username, password);
 113     }
 114 
 115     public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
 116                                 String username,
 117                                 String password) {
 118         requireNonNull(mapping);
 119         Credentials credentials = username != null ?
 120                 new Credentials(username, password) : null;
 121 
 122         thread = new Thread(() -> {
 123             try {
 124                 while (!Thread.currentThread().isInterrupted()) {
 125                     err.println("Accepting next connection at: " + ssc);
 126                     SocketChannel channel = ssc.accept();
 127                     err.println("Accepted: " + channel);
 128                     try {
 129                         channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
 130                         channel.configureBlocking(true);
 131                         while (true) {
 132                             StringBuilder request = new StringBuilder();
 133                             if (!readRequest(channel, request)) {
 134                                 throw new IOException("Bad request:[" + request + "]");
 135                             }
 136                             List<String> strings = asList(request.toString().split("\r\n"));
 137                             List<String> response = mapping.apply(strings, credentials);
 138                             writeResponse(channel, response);
 139 
 140                             if (response.get(0).startsWith("HTTP/1.1 401")) {
 141                                 err.println("Sent 401 Authentication response " + channel);
 142                                 continue;
 143                             } else {
 144                                 serve(channel);
 145                                 break;
 146                             }
 147                         }
 148                     } catch (IOException e) {
 149                         err.println("Error in connection: " + channel + ", " + e);
 150                     } finally {
 151                         err.println("Closed: " + channel);
 152                         close(channel);
 153                         readReady.countDown();
 154                     }
 155                 }
 156             } catch (ClosedByInterruptException ignored) {
 157             } catch (Exception e) {
 158                 e.printStackTrace(err);
 159             } finally {
 160                 close(ssc);
 161                 err.println("Stopped at: " + getURI());
 162             }
 163         });
 164         thread.setName("DummyWebSocketServer");
 165         thread.setDaemon(false);
 166     }
 167 
 168     protected void read(SocketChannel ch) throws IOException {
 169         // Read until the thread is interrupted or an error occurred
 170         // or the input is shutdown
 171         ByteBuffer b = ByteBuffer.allocate(65536);
 172         while (ch.read(b) != -1) {
 173             b.flip();
 174             if (read.remaining() < b.remaining()) {
 175                 int required = read.capacity() - read.remaining() + b.remaining();
 176                 int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
 177                 ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
 178                 newBuffer.put(read.flip());
 179                 read = newBuffer;
 180             }
 181             read.put(b);
 182             b.clear();
 183         }
 184     }
 185 
 186     protected void write(SocketChannel ch) throws IOException { }
 187 
 188     protected final void serve(SocketChannel channel)
 189             throws InterruptedException
 190     {
 191         Thread reader = new Thread(() -> {
 192             try {
 193                 read(channel);
 194             } catch (IOException ignored) { }
 195         });
 196         Thread writer = new Thread(() -> {
 197             try {
 198                 write(channel);
 199             } catch (IOException ignored) { }
 200         });
 201         reader.start();
 202         writer.start();
 203         try {
 204             reader.join();
 205         } finally {
 206             reader.interrupt();
 207             try {
 208                 writer.join();
 209             } finally {
 210                 writer.interrupt();
 211             }
 212         }
 213     }
 214 
 215     public ByteBuffer read() throws InterruptedException {
 216         readReady.await();
 217         return read.duplicate().asReadOnlyBuffer().flip();
 218     }
 219 
 220     public void open() throws IOException {
 221         err.println("Starting");
 222         if (!started.compareAndSet(false, true)) {
 223             throw new IllegalStateException("Already started");
 224         }
 225         ssc = ServerSocketChannel.open();
 226         try {
 227             ssc.configureBlocking(true);
 228             ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
 229             address = (InetSocketAddress) ssc.getLocalAddress();
 230             thread.start();
 231         } catch (IOException e) {
 232             close(ssc);
 233             throw e;
 234         }
 235         err.println("Started at: " + getURI());
 236     }
 237 
 238     @Override
 239     public void close() {
 240         err.println("Stopping: " + getURI());
 241         thread.interrupt();
 242         close(ssc);
 243     }
 244 
 245     URI getURI() {
 246         if (!started.get()) {
 247             throw new IllegalStateException("Not yet started");
 248         }
 249         return URI.create("ws://localhost:" + address.getPort());
 250     }
 251 
 252     private boolean readRequest(SocketChannel channel, StringBuilder request)
 253             throws IOException
 254     {
 255         ByteBuffer buffer = ByteBuffer.allocate(512);
 256         while (channel.read(buffer) != -1) {
 257             // read the complete HTTP request headers, there should be no body
 258             CharBuffer decoded;
 259             buffer.flip();
 260             try {
 261                 decoded = ISO_8859_1.newDecoder().decode(buffer);
 262             } catch (CharacterCodingException e) {
 263                 throw new UncheckedIOException(e);
 264             }
 265             request.append(decoded);
 266             if (Pattern.compile("\r\n\r\n").matcher(request).find())
 267                 return true;
 268             buffer.clear();
 269         }
 270         return false;
 271     }
 272 
 273     private void writeResponse(SocketChannel channel, List<String> response)
 274             throws IOException
 275     {
 276         String s = response.stream().collect(Collectors.joining("\r\n"))
 277                 + "\r\n\r\n";
 278         ByteBuffer encoded;
 279         try {
 280             encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
 281         } catch (CharacterCodingException e) {
 282             throw new UncheckedIOException(e);
 283         }
 284         while (encoded.hasRemaining()) {
 285             channel.write(encoded);
 286         }
 287     }
 288 
 289     private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
 290         return (request, credentials) -> {
 291             List<String> response = new LinkedList<>();
 292             Iterator<String> iterator = request.iterator();
 293             if (!iterator.hasNext()) {
 294                 throw new IllegalStateException("The request is empty");
 295             }
 296             String statusLine = iterator.next();
 297             if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
 298                 throw new IllegalStateException
 299                         ("Unexpected status line: " + request.get(0));
 300             }
 301             response.add("HTTP/1.1 101 Switching Protocols");
 302             Map<String, List<String>> requestHeaders = new HashMap<>();
 303             while (iterator.hasNext()) {
 304                 String header = iterator.next();
 305                 String[] split = header.split(": ");
 306                 if (split.length != 2) {
 307                     throw new IllegalStateException
 308                             ("Unexpected header: " + header
 309                                      + ", split=" + Arrays.toString(split));
 310                 }
 311                 requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
 312 
 313             }
 314             if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
 315                 throw new IllegalStateException("Subprotocols are not expected");
 316             }
 317             if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
 318                 throw new IllegalStateException("Extensions are not expected");
 319             }
 320             expectHeader(requestHeaders, "Connection", "Upgrade");
 321             response.add("Connection: Upgrade");
 322             expectHeader(requestHeaders, "Upgrade", "websocket");
 323             response.add("Upgrade: websocket");
 324             expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
 325             List<String> key = requestHeaders.get("Sec-WebSocket-Key");
 326             if (key == null || key.isEmpty()) {
 327                 throw new IllegalStateException("Sec-WebSocket-Key is missing");
 328             }
 329             if (key.size() != 1) {
 330                 throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
 331             }
 332             MessageDigest sha1 = null;
 333             try {
 334                 sha1 = MessageDigest.getInstance("SHA-1");
 335             } catch (NoSuchAlgorithmException e) {
 336                 throw new InternalError(e);
 337             }
 338             String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
 339             sha1.update(x.getBytes(ISO_8859_1));
 340             String v = Base64.getEncoder().encodeToString(sha1.digest());
 341             response.add("Sec-WebSocket-Accept: " + v);
 342 
 343             // check authorization credentials, if required by the server
 344             if (credentials != null && !authorized(credentials, requestHeaders)) {
 345                 response.clear();
 346                 response.add("HTTP/1.1 401 Unauthorized");
 347                 response.add("Content-Length: 0");
 348                 response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
 349             }
 350 
 351             return response;
 352         };
 353     }
 354 
 355     // Checks credentials in the request against those allowable by the server.
 356     private static boolean authorized(Credentials credentials,
 357                                       Map<String,List<String>> requestHeaders) {
 358         List<String> authorization = requestHeaders.get("Authorization");
 359         if (authorization == null)
 360             return false;
 361 
 362         if (authorization.size() != 1) {
 363             throw new IllegalStateException("Authorization unexpected count:" + authorization);
 364         }
 365         String header = authorization.get(0);
 366         if (!header.startsWith("Basic "))
 367             throw new IllegalStateException("Authorization not Basic: " + header);
 368 
 369         header = header.substring("Basic ".length());
 370         String values = new String(Base64.getDecoder().decode(header), UTF_8);
 371         int sep = values.indexOf(':');
 372         if (sep < 1) {
 373             throw new IllegalStateException("Authorization not colon: " +  values);
 374         }
 375         String name = values.substring(0, sep);
 376         String password = values.substring(sep + 1);
 377 
 378         if (name.equals(credentials.name()) && password.equals(credentials.password()))
 379             return true;
 380 
 381         return false;
 382     }
 383 
 384     protected static String expectHeader(Map<String, List<String>> headers,
 385                                          String name,
 386                                          String value) {
 387         List<String> v = headers.get(name);
 388         if (v == null) {
 389             throw new IllegalStateException(
 390                     format("Expected '%s' header, not present in %s",
 391                            name, headers));
 392         }
 393         if (!v.contains(value)) {
 394             throw new IllegalStateException(
 395                     format("Expected '%s: %s', actual: '%s: %s'",
 396                            name, value, name, v)
 397             );
 398         }
 399         return value;
 400     }
 401 
 402     private static void close(AutoCloseable... acs) {
 403         for (AutoCloseable ac : acs) {
 404             try {
 405                 ac.close();
 406             } catch (Exception ignored) { }
 407         }
 408     }
 409 }