/*
 * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server.
 *
 * <p>Performs a simplified version of the WebSocket Opening Handshake over
 * HTTP (i.e. no proxying, cookies, etc.). Supports sequential connections,
 * one at a time, i.e. in order for a client to connect to the server the
 * previous client must disconnect first.
 *
 * <p>Example client request (from RFC 6455):
 *
 * <pre>
 * GET /chat HTTP/1.1
 * Host: server.example.com
 * Upgrade: websocket
 * Connection: Upgrade
 * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 * Origin: http://example.com
 * Sec-WebSocket-Protocol: chat, superchat
 * Sec-WebSocket-Version: 13
 * </pre>
 *
 * <p>Example server response:
 *
 * <pre>
 * HTTP/1.1 101 Switching Protocols
 * Upgrade: websocket
 * Connection: Upgrade
 * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * Sec-WebSocket-Protocol: chat
 * </pre>
 *
 * <p>Note that the default request-to-response mapping is stricter than the
 * example above. It rejects requests that carry {@code Sec-WebSocket-Protocol}
 * or {@code Sec-WebSocket-Extensions}, and it never sends
 * {@code Sec-WebSocket-Protocol}. Pass a custom mapping to the constructor to
 * change the handshake.
 *
 * <p>After a successful handshake, every byte received from the client is
 * appended to an internal buffer (see {@link #read()}). Subclasses may
 * override {@link #read(SocketChannel)} and {@link #write(SocketChannel)} to
 * customize what happens on the connection.
 */
public class DummyWebSocketServer implements Closeable {

    /** Marks the end of the HTTP request header section. */
    private static final Pattern HEADERS_END = Pattern.compile("\r\n\r\n");

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    private volatile InetSocketAddress address;
    // Bytes received after the handshake, accumulated across connections.
    // Kept in "write mode" (position == number of bytes stored).
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // Released when the first connection is closed, however it ends
    private final CountDownLatch readReady = new CountDownLatch(1);

    /** Creates a server that uses the built-in handshake mapping. */
    public DummyWebSocketServer() {
        this(defaultMapping());
    }

    /**
     * Creates a server that uses the given handshake mapping.
     *
     * @param mapping turns the request lines (request line followed by header
     *        lines, without the trailing blank line) into response lines
     * @throws NullPointerException if {@code mapping} is null
     */
    public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
        requireNonNull(mapping);
        thread = new Thread(() -> acceptConnections(mapping));
        thread.setName("DummyWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Server-thread main loop. Accepts and serves one connection at a time
     * until the thread is interrupted, the server socket is closed, or an
     * unexpected exception (e.g. a rejected handshake) stops it.
     */
    private void acceptConnections(Function<List<String>, List<String>> mapping) {
        try {
            while (!Thread.currentThread().isInterrupted()) {
                err.println("Accepting next connection at: " + ssc);
                SocketChannel channel = ssc.accept();
                err.println("Accepted: " + channel);
                try {
                    handleConnection(channel, mapping);
                } catch (IOException e) {
                    err.println("Error in connection: " + channel + ", " + e);
                } finally {
                    err.println("Closed: " + channel);
                    close(channel);
                    readReady.countDown();
                }
            }
        } catch (AsynchronousCloseException | InterruptedException ignored) {
            // Orderly shutdown. close() interrupts this thread and closes ssc,
            // which surfaces here as one of these two exceptions.
        } catch (Exception e) {
            err.println(e);
        } finally {
            close(ssc);
            err.println("Stopped at: " + getURI());
        }
    }

    /**
     * Performs the opening handshake on {@code channel}, then serves it until
     * the client stops sending.
     *
     * @throws IOException if the request is incomplete or I/O fails
     * @throws InterruptedException if interrupted while serving
     */
    private void handleConnection(SocketChannel channel,
                                  Function<List<String>, List<String>> mapping)
            throws IOException, InterruptedException
    {
        channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
        channel.configureBlocking(true);
        StringBuilder request = new StringBuilder();
        if (!readRequest(channel, request)) {
            throw new IOException("Bad request:" + request);
        }
        List<String> strings = asList(request.toString().split("\r\n"));
        List<String> response = mapping.apply(strings);
        writeResponse(channel, response);
        serve(channel);
    }

    /**
     * Reads from {@code ch} until end of stream, appending everything to the
     * internal receive buffer. When the buffer is too small, it is replaced
     * by one whose capacity is the next power of two that fits the data.
     *
     * @param ch the connected channel
     * @throws IOException if reading fails
     */
    protected void read(SocketChannel ch) throws IOException {
        // Read until the thread is interrupted or an error occurred
        // or the input is shutdown
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // bytes already stored + bytes just read
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /**
     * Hook for writing to the client after the handshake. It runs concurrently
     * with {@link #read(SocketChannel)}. The default implementation writes
     * nothing.
     *
     * @param ch the connected channel
     * @throws IOException if writing fails
     */
    protected void write(SocketChannel ch) throws IOException { }

    /**
     * Runs {@link #read(SocketChannel)} and {@link #write(SocketChannel)} on
     * two new threads and waits for the reader to finish, then for the
     * writer. IOExceptions thrown by either hook are ignored.
     *
     * @param channel the connected channel
     * @throws InterruptedException if interrupted while waiting
     */
    protected final void serve(SocketChannel channel)
            throws InterruptedException
    {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) { }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) { }
        });
        reader.start();
        writer.start();
        try {
            reader.join();
        } finally {
            // Only has an effect if join() above was interrupted
            reader.interrupt();
            try {
                writer.join();
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until the first connection has been closed, then returns a
     * read-only view of all bytes received after the handshake(s) so far.
     * The view has position 0 and limit equal to the number of bytes received.
     *
     * @return the received bytes
     * @throws InterruptedException if interrupted while waiting
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /**
     * Binds to an ephemeral port on the loopback address and starts the
     * server thread. If binding fails, the server is left in the
     * "not started" state.
     *
     * @throws IOException if the server socket cannot be opened or bound
     * @throws IllegalStateException if already started
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        try {
            ssc = ServerSocketChannel.open();
            ssc.configureBlocking(true);
            ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            // Roll back so getURI()/close() do not see a half-started server
            started.set(false);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server: interrupts the server thread and closes the server
     * socket. It is safe to call this even if {@link #open()} was never
     * called or failed.
     */
    @Override
    public void close() {
        if (started.get()) {
            err.println("Stopping: " + getURI());
        }
        thread.interrupt();
        close(ssc);
    }

    /**
     * Returns the {@code ws://localhost:<port>} URI of this server.
     *
     * @throws IllegalStateException if not started
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://localhost:" + address.getPort());
    }

    /**
     * Reads ISO-8859-1 text from {@code channel} into {@code request} until
     * the header terminator ({@code CRLF CRLF}) has been seen or the stream
     * ends.
     *
     * @return {@code true} if the terminator was found, {@code false} on EOF
     */
    private boolean readRequest(SocketChannel channel, StringBuilder request)
            throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        CharsetDecoder decoder = ISO_8859_1.newDecoder();
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            buffer.flip();
            CharBuffer decoded;
            try {
                decoded = decoder.decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            request.append(decoded);
            if (HEADERS_END.matcher(request).find()) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes {@code response} lines joined by CRLF and terminated by a blank
     * line, encoded as ISO-8859-1.
     */
    private void writeResponse(SocketChannel channel, List<String> response)
            throws IOException
    {
        String s = response.stream().collect(Collectors.joining("\r\n"))
                + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * The built-in handshake. It validates a {@code GET ... HTTP/1.1} upgrade
     * request (version 13, no subprotocols, no extensions, exactly one key)
     * and answers with {@code 101 Switching Protocols} and the matching
     * {@code Sec-WebSocket-Accept}. Any violation throws
     * {@link IllegalStateException}.
     */
    private static Function<List<String>, List<String>> defaultMapping() {
        return request -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            // HTTP header field names are case-insensitive (RFC 7230, 3.2)
            Map<String, List<String>> requestHeaders =
                    new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Split on the first ": " only; the value may itself contain ": "
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                     + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            // Accept value = base64(SHA-1(key + fixed GUID)), RFC 6455, 4.2.2
            String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);
            return response;
        };
    }

    /**
     * Checks that header {@code name} is present and has {@code value} among
     * its values.
     *
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        // A missing header must fail with a descriptive ISE, not an NPE
        if (v == null || !v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Closes each non-null argument, ignoring any exception (best effort). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue;
            }
            try {
                ac.close();
            } catch (Exception ignored) { }
        }
    }
}