1 /*
   2  * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import java.io.Closeable;
  25 import java.io.IOException;
  26 import java.io.UncheckedIOException;
  27 import java.net.InetAddress;
  28 import java.net.InetSocketAddress;
  29 import java.net.StandardSocketOptions;
  30 import java.net.URI;
  31 import java.nio.ByteBuffer;
  32 import java.nio.CharBuffer;
  33 import java.nio.channels.ClosedByInterruptException;
  34 import java.nio.channels.ServerSocketChannel;
  35 import java.nio.channels.SocketChannel;
  36 import java.nio.charset.CharacterCodingException;
  37 import java.security.MessageDigest;
  38 import java.security.NoSuchAlgorithmException;
  39 import java.util.ArrayList;
  40 import java.util.Arrays;
  41 import java.util.Base64;
  42 import java.util.HashMap;
  43 import java.util.Iterator;
  44 import java.util.LinkedList;
  45 import java.util.List;
  46 import java.util.Map;
  47 import java.util.concurrent.CountDownLatch;
  48 import java.util.concurrent.atomic.AtomicBoolean;
  49 import java.util.function.Function;
  50 import java.util.regex.Pattern;
  51 import java.util.stream.Collectors;
  52 
  53 import static java.lang.String.format;
  54 import static java.lang.System.err;
  55 import static java.nio.charset.StandardCharsets.ISO_8859_1;
  56 import static java.util.Arrays.asList;
  57 import static java.util.Objects.requireNonNull;
  58 
  59 /**
  60  * Dummy WebSocket Server.
  61  *
  62  * Performs simpler version of the WebSocket Opening Handshake over HTTP (i.e.
  63  * no proxying, cookies, etc.) Supports sequential connections, one at a time,
  64  * i.e. in order for a client to connect to the server the previous client must
  65  * disconnect first.
  66  *
  67  * Expected client request:
  68  *
  69  *     GET /chat HTTP/1.1
  70  *     Host: server.example.com
  71  *     Upgrade: websocket
  72  *     Connection: Upgrade
  73  *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
  74  *     Origin: http://example.com
  75  *     Sec-WebSocket-Protocol: chat, superchat
  76  *     Sec-WebSocket-Version: 13
  77  *
  78  * This server response:
  79  *
  80  *     HTTP/1.1 101 Switching Protocols
  81  *     Upgrade: websocket
  82  *     Connection: Upgrade
  83  *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
  84  *     Sec-WebSocket-Protocol: chat
  85  */
  86 public class DummyWebSocketServer implements Closeable {
  87 
  88     private final AtomicBoolean started = new AtomicBoolean();
  89     private final Thread thread;
  90     private volatile ServerSocketChannel ssc;
  91     private volatile InetSocketAddress address;
  92     private ByteBuffer read = ByteBuffer.allocate(16384);
  93     private final CountDownLatch readReady = new CountDownLatch(1);
  94 
  95     public DummyWebSocketServer() {
  96         this(defaultMapping());
  97     }
  98 
  99     public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
 100         requireNonNull(mapping);
 101         thread = new Thread(() -> {
 102             try {
 103                 while (!Thread.currentThread().isInterrupted()) {
 104                     err.println("Accepting next connection at: " + ssc);
 105                     SocketChannel channel = ssc.accept();
 106                     err.println("Accepted: " + channel);
 107                     try {
 108                         channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
 109                         channel.configureBlocking(true);
 110                         StringBuilder request = new StringBuilder();
 111                         if (!readRequest(channel, request)) {
 112                             throw new IOException("Bad request:" + request);
 113                         }
 114                         List<String> strings = asList(request.toString().split("\r\n"));
 115                         List<String> response = mapping.apply(strings);
 116                         writeResponse(channel, response);
 117                         serve(channel);
 118                     } catch (IOException e) {
 119                         err.println("Error in connection: " + channel + ", " + e);
 120                     } finally {
 121                         err.println("Closed: " + channel);
 122                         close(channel);
 123                         readReady.countDown();
 124                     }
 125                 }
 126             } catch (ClosedByInterruptException ignored) {
 127             } catch (Exception e) {
 128                 err.println(e);
 129             } finally {
 130                 close(ssc);
 131                 err.println("Stopped at: " + getURI());
 132             }
 133         });
 134         thread.setName("DummyWebSocketServer");
 135         thread.setDaemon(false);
 136     }
 137 
 138     protected void read(SocketChannel ch) throws IOException {
 139         // Read until the thread is interrupted or an error occurred
 140         // or the input is shutdown
 141         ByteBuffer b = ByteBuffer.allocate(65536);
 142         while (ch.read(b) != -1) {
 143             b.flip();
 144             if (read.remaining() < b.remaining()) {
 145                 int required = read.capacity() - read.remaining() + b.remaining();
 146                 int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
 147                 ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
 148                 newBuffer.put(read.flip());
 149                 read = newBuffer;
 150             }
 151             read.put(b);
 152             b.clear();
 153         }
 154     }
 155 
 156     protected void write(SocketChannel ch) throws IOException { }
 157 
 158     protected final void serve(SocketChannel channel)
 159             throws InterruptedException
 160     {
 161         Thread reader = new Thread(() -> {
 162             try {
 163                 read(channel);
 164             } catch (IOException ignored) { }
 165         });
 166         Thread writer = new Thread(() -> {
 167             try {
 168                 write(channel);
 169             } catch (IOException ignored) { }
 170         });
 171         reader.start();
 172         writer.start();
 173         try {
 174             reader.join();
 175         } finally {
 176             reader.interrupt();
 177             try {
 178                 writer.join();
 179             } finally {
 180                 writer.interrupt();
 181             }
 182         }
 183     }
 184 
 185     public ByteBuffer read() throws InterruptedException {
 186         readReady.await();
 187         return read.duplicate().asReadOnlyBuffer().flip();
 188     }
 189 
 190     public void open() throws IOException {
 191         err.println("Starting");
 192         if (!started.compareAndSet(false, true)) {
 193             throw new IllegalStateException("Already started");
 194         }
 195         ssc = ServerSocketChannel.open();
 196         try {
 197             ssc.configureBlocking(true);
 198             ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
 199             address = (InetSocketAddress) ssc.getLocalAddress();
 200             thread.start();
 201         } catch (IOException e) {
 202             close(ssc);
 203             throw e;
 204         }
 205         err.println("Started at: " + getURI());
 206     }
 207 
 208     @Override
 209     public void close() {
 210         err.println("Stopping: " + getURI());
 211         thread.interrupt();
 212         close(ssc);
 213     }
 214 
 215     URI getURI() {
 216         if (!started.get()) {
 217             throw new IllegalStateException("Not yet started");
 218         }
 219         return URI.create("ws://localhost:" + address.getPort());
 220     }
 221 
 222     private boolean readRequest(SocketChannel channel, StringBuilder request)
 223             throws IOException
 224     {
 225         ByteBuffer buffer = ByteBuffer.allocate(512);
 226         while (channel.read(buffer) != -1) {
 227             // read the complete HTTP request headers, there should be no body
 228             CharBuffer decoded;
 229             buffer.flip();
 230             try {
 231                 decoded = ISO_8859_1.newDecoder().decode(buffer);
 232             } catch (CharacterCodingException e) {
 233                 throw new UncheckedIOException(e);
 234             }
 235             request.append(decoded);
 236             if (Pattern.compile("\r\n\r\n").matcher(request).find())
 237                 return true;
 238             buffer.clear();
 239         }
 240         return false;
 241     }
 242 
 243     private void writeResponse(SocketChannel channel, List<String> response)
 244             throws IOException
 245     {
 246         String s = response.stream().collect(Collectors.joining("\r\n"))
 247                 + "\r\n\r\n";
 248         ByteBuffer encoded;
 249         try {
 250             encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
 251         } catch (CharacterCodingException e) {
 252             throw new UncheckedIOException(e);
 253         }
 254         while (encoded.hasRemaining()) {
 255             channel.write(encoded);
 256         }
 257     }
 258 
 259     private static Function<List<String>, List<String>> defaultMapping() {
 260         return request -> {
 261             List<String> response = new LinkedList<>();
 262             Iterator<String> iterator = request.iterator();
 263             if (!iterator.hasNext()) {
 264                 throw new IllegalStateException("The request is empty");
 265             }
 266             String statusLine = iterator.next();
 267             if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
 268                 throw new IllegalStateException
 269                         ("Unexpected status line: " + request.get(0));
 270             }
 271             response.add("HTTP/1.1 101 Switching Protocols");
 272             Map<String, List<String>> requestHeaders = new HashMap<>();
 273             while (iterator.hasNext()) {
 274                 String header = iterator.next();
 275                 String[] split = header.split(": ");
 276                 if (split.length != 2) {
 277                     throw new IllegalStateException
 278                             ("Unexpected header: " + header
 279                                      + ", split=" + Arrays.toString(split));
 280                 }
 281                 requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
 282 
 283             }
 284             if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
 285                 throw new IllegalStateException("Subprotocols are not expected");
 286             }
 287             if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
 288                 throw new IllegalStateException("Extensions are not expected");
 289             }
 290             expectHeader(requestHeaders, "Connection", "Upgrade");
 291             response.add("Connection: Upgrade");
 292             expectHeader(requestHeaders, "Upgrade", "websocket");
 293             response.add("Upgrade: websocket");
 294             expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
 295             List<String> key = requestHeaders.get("Sec-WebSocket-Key");
 296             if (key == null || key.isEmpty()) {
 297                 throw new IllegalStateException("Sec-WebSocket-Key is missing");
 298             }
 299             if (key.size() != 1) {
 300                 throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
 301             }
 302             MessageDigest sha1 = null;
 303             try {
 304                 sha1 = MessageDigest.getInstance("SHA-1");
 305             } catch (NoSuchAlgorithmException e) {
 306                 throw new InternalError(e);
 307             }
 308             String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
 309             sha1.update(x.getBytes(ISO_8859_1));
 310             String v = Base64.getEncoder().encodeToString(sha1.digest());
 311             response.add("Sec-WebSocket-Accept: " + v);
 312             return response;
 313         };
 314     }
 315 
 316     protected static String expectHeader(Map<String, List<String>> headers,
 317                                          String name,
 318                                          String value) {
 319         List<String> v = headers.get(name);
 320         if (!v.contains(value)) {
 321             throw new IllegalStateException(
 322                     format("Expected '%s: %s', actual: '%s: %s'",
 323                            name, value, name, v)
 324             );
 325         }
 326         return value;
 327     }
 328 
 329     private static void close(AutoCloseable... acs) {
 330         for (AutoCloseable ac : acs) {
 331             try {
 332                 ac.close();
 333             } catch (Exception ignored) { }
 334         }
 335     }
 336 }