/*
 * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server.
 *
 * Performs a simplified version of the WebSocket Opening Handshake over HTTP
 * (i.e. no proxying, cookies, etc.). Supports sequential connections, one at
 * a time, i.e. in order for a client to connect to the server the previous
 * client must disconnect first.
 *
 * Example client request (from RFC 6455):
 *
 *     GET /chat HTTP/1.1
 *     Host: server.example.com
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *     Origin: http://example.com
 *     Sec-WebSocket-Version: 13
 *
 * Response produced by the default mapping:
 *
 *     HTTP/1.1 101 Switching Protocols
 *     Connection: Upgrade
 *     Upgrade: websocket
 *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 *
 * The default mapping rejects requests that carry Sec-WebSocket-Protocol or
 * Sec-WebSocket-Extensions headers, and never sends Sec-WebSocket-Protocol
 * back. When the server is created with a user name, requests without
 * matching Basic credentials get "HTTP/1.1 401 Unauthorized" instead, and
 * the server waits for another request on the same connection.
 */
public class DummyWebSocketServer implements Closeable {

    /** The blank line that terminates the HTTP request headers. */
    private static final Pattern END_OF_HEADERS = Pattern.compile("\r\n\r\n");

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    private volatile InetSocketAddress address;
    // Bytes received after the handshake; replaced by a larger buffer on demand
    // in read(SocketChannel). Kept in "write mode" (limit == capacity).
    private ByteBuffer read = ByteBuffer.allocate(16384);
    // Released once the first accepted connection has been handled and closed.
    private final CountDownLatch readReady = new CountDownLatch(1);

    private static class Credentials {
        private final String name;
        private final String password;
        private Credentials(String name, String password) {
            this.name = name;
            this.password = password;
        }
        public String name() { return name; }
        public String password() { return password; }
    }

    /**
     * Creates a server that uses the default handshake mapping and does not
     * require credentials.
     */
    public DummyWebSocketServer() {
        this(defaultMapping(), null, null);
    }

    /**
     * Creates a server that uses the default handshake mapping and answers
     * requests lacking matching Basic credentials with a 401 response.
     *
     * @param username the expected user name
     * @param password the expected password
     */
    public DummyWebSocketServer(String username, String password) {
        this(defaultMapping(), username, password);
    }

    /**
     * Creates a server with a custom handshake mapping.
     *
     * @param mapping  turns the request lines (request line first, then one
     *                 entry per header line) and the configured credentials
     *                 ({@code null} when {@code username} is {@code null})
     *                 into the response lines. A response whose first line
     *                 starts with "HTTP/1.1 401" makes the server read another
     *                 request on the same connection; any other response hands
     *                 the connection to {@link #serve(SocketChannel)}
     * @param username the expected user name, or {@code null} for none
     * @param password the expected password
     * @throws NullPointerException if {@code mapping} is {@code null}
     */
    public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
                                String username,
                                String password) {
        requireNonNull(mapping);
        Credentials credentials = username != null ?
                new Credentials(username, password) : null;

        thread = new Thread(() -> acceptLoop(mapping, credentials));
        thread.setName("DummyWebSocketServer");
        // Non-daemon: the JVM will not exit while the server runs; call close().
        thread.setDaemon(false);
    }

    // Body of the server thread: accepts and handles connections one at a time
    // until the thread is interrupted (see close()) or an unexpected error occurs.
    private void acceptLoop(BiFunction<List<String>,Credentials,List<String>> mapping,
                            Credentials credentials) {
        try {
            while (!Thread.currentThread().isInterrupted()) {
                err.println("Accepting next connection at: " + ssc);
                SocketChannel channel = ssc.accept();
                err.println("Accepted: " + channel);
                try {
                    handshakeAndServe(channel, mapping, credentials);
                } catch (IOException e) {
                    err.println("Error in connection: " + channel + ", " + e);
                } finally {
                    err.println("Closed: " + channel);
                    close(channel);
                    readReady.countDown();
                }
            }
        } catch (ClosedByInterruptException | InterruptedException ignored) {
            // close() interrupts this thread: a normal shutdown, not an error.
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            e.printStackTrace(err);
        } finally {
            close(ssc);
            err.println("Stopped at: " + getURI());
        }
    }

    // Reads a request and writes the mapped response, repeating after every
    // 401 response; the first non-401 response hands the channel to serve().
    private void handshakeAndServe(SocketChannel channel,
                                   BiFunction<List<String>,Credentials,List<String>> mapping,
                                   Credentials credentials)
            throws IOException, InterruptedException {
        channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
        channel.configureBlocking(true);
        while (true) {
            StringBuilder request = new StringBuilder();
            if (!readRequest(channel, request)) {
                throw new IOException("Bad request:[" + request + "]");
            }
            List<String> strings = asList(request.toString().split("\r\n"));
            List<String> response = mapping.apply(strings, credentials);
            writeResponse(channel, response);

            if (!response.get(0).startsWith("HTTP/1.1 401")) {
                serve(channel);
                return;
            }
            err.println("Sent 401 Authentication response " + channel);
        }
    }

    /**
     * Collects everything the client sends after the handshake into the
     * buffer returned by {@link #read()}. Runs on the reader thread started by
     * {@link #serve(SocketChannel)}; subclasses may override it.
     *
     * @param ch the upgraded connection
     * @throws IOException if reading fails (including interruption)
     */
    protected void read(SocketChannel ch) throws IOException {
        // Read until the thread is interrupted or an error occurred
        // or the input is shutdown
        ByteBuffer b = ByteBuffer.allocate(65536);
        while (ch.read(b) != -1) {
            b.flip();
            if (read.remaining() < b.remaining()) {
                // Grow to the next power of two holding what has been
                // accumulated so far (capacity - remaining == position)
                // plus the new chunk.
                int required = read.capacity() - read.remaining() + b.remaining();
                int log2required = 32 - Integer.numberOfLeadingZeros(required - 1);
                ByteBuffer newBuffer = ByteBuffer.allocate(1 << log2required);
                newBuffer.put(read.flip());
                read = newBuffer;
            }
            read.put(b);
            b.clear();
        }
    }

    /**
     * Hook for sending data to the client after the handshake. Runs on the
     * writer thread started by {@link #serve(SocketChannel)}. Does nothing by
     * default.
     *
     * @param ch the upgraded connection
     * @throws IOException if writing fails
     */
    protected void write(SocketChannel ch) throws IOException { }

    /**
     * Serves an upgraded connection. Runs {@link #read(SocketChannel)} and
     * {@link #write(SocketChannel)} on two new threads, waits for the reader
     * and then for the writer, interrupting each one once the wait on it ends.
     * IOExceptions thrown by either are ignored.
     *
     * @param channel the upgraded connection
     * @throws InterruptedException if interrupted while waiting
     */
    protected final void serve(SocketChannel channel)
            throws InterruptedException {
        Thread reader = new Thread(() -> {
            try {
                read(channel);
            } catch (IOException ignored) { }
        });
        Thread writer = new Thread(() -> {
            try {
                write(channel);
            } catch (IOException ignored) { }
        });
        reader.start();
        writer.start();
        try {
            reader.join();
        } finally {
            reader.interrupt();
            try {
                writer.join();
            } finally {
                writer.interrupt();
            }
        }
    }

    /**
     * Waits until the first connection has been handled and closed, then
     * returns a read-only view of the bytes collected by
     * {@link #read(SocketChannel)}, positioned at zero.
     *
     * @return the bytes the client sent after the handshake
     * @throws InterruptedException if interrupted while waiting
     */
    public ByteBuffer read() throws InterruptedException {
        readReady.await();
        return read.duplicate().asReadOnlyBuffer().flip();
    }

    /**
     * Binds the server to an ephemeral port on the loopback address and
     * starts the server thread.
     *
     * @throws IllegalStateException if this method was already called
     * @throws IOException if the server socket cannot be opened or bound
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ssc = ServerSocketChannel.open();
        try {
            ssc.configureBlocking(true);
            ssc.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Interrupts the server thread and closes the server socket. Safe to call
     * on a server that was never (successfully) opened. Does not wait for the
     * server thread to finish.
     */
    @Override
    public void close() {
        err.println("Stopping: " + describeAddress());
        thread.interrupt();
        close(ssc);
    }

    // The server URI when bound, otherwise a placeholder. Unlike getURI() it
    // never throws, so close() works before open() or after a failed open().
    private String describeAddress() {
        return started.get() && address != null
                ? getURI().toString()
                : "<not started>";
    }

    /**
     * Returns the {@code ws://localhost:<port>} URI of the bound server.
     *
     * @throws IllegalStateException if {@link #open()} was not called
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://localhost:" + address.getPort());
    }

    // Appends ISO-8859-1 decoded input to `request` until the blank line
    // ending the headers has been received (true) or the input ends (false).
    // Bytes that arrive after the blank line in the same read are kept in
    // `request`; they are not handed to read(SocketChannel).
    private boolean readRequest(SocketChannel channel, StringBuilder request)
            throws IOException {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            CharBuffer decoded;
            buffer.flip();
            try {
                decoded = ISO_8859_1.newDecoder().decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            request.append(decoded);
            if (END_OF_HEADERS.matcher(request).find()) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    // Writes the response lines joined by CRLF and terminated by a blank line.
    private void writeResponse(SocketChannel channel, List<String> response)
            throws IOException {
        String s = String.join("\r\n", response) + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    // Default handshake: validates the request line and the Upgrade,
    // Connection, Sec-WebSocket-Version and Sec-WebSocket-Key headers, computes
    // Sec-WebSocket-Accept, and replaces the response with a 401 when
    // credentials are configured but not matched. Throws IllegalStateException
    // on any unexpected request.
    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
        return (request, credentials) -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            // HTTP header field names are case-insensitive (RFC 7230, 3.2).
            Map<String, List<String>> requestHeaders =
                    new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Limit 2: the value itself may contain ": " or be empty.
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                     + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            // Sec-WebSocket-Accept = base64(SHA-1(key + fixed GUID)), RFC 6455.
            String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);

            // check authorization credentials, if required by the server
            if (credentials != null && !authorized(credentials, requestHeaders)) {
                response.clear();
                response.add("HTTP/1.1 401 Unauthorized");
                response.add("Content-Length: 0");
                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
            }

            return response;
        };
    }

    // Checks credentials in the request against those allowable by the server.
    // Returns false when there is no Authorization header; throws
    // IllegalStateException when the header is malformed or not Basic.
    private static boolean authorized(Credentials credentials,
                                      Map<String,List<String>> requestHeaders) {
        List<String> authorization = requestHeaders.get("Authorization");
        if (authorization == null) {
            return false;
        }
        if (authorization.size() != 1) {
            throw new IllegalStateException("Authorization unexpected count:" + authorization);
        }
        String header = authorization.get(0);
        if (!header.startsWith("Basic ")) {
            throw new IllegalStateException("Authorization not Basic: " + header);
        }
        header = header.substring("Basic ".length());
        String values = new String(Base64.getDecoder().decode(header), UTF_8);
        int sep = values.indexOf(':');
        if (sep < 1) {
            throw new IllegalStateException("Authorization not colon: " + values);
        }
        String name = values.substring(0, sep);
        String password = values.substring(sep + 1);
        return name.equals(credentials.name()) && password.equals(credentials.password());
    }

    /**
     * Checks that header {@code name} is present and one of its values equals
     * {@code value}.
     *
     * @param headers the request headers, by name
     * @param name    the header name to look up
     * @param value   the value that must be present
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks the value
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null) {
            throw new IllegalStateException(
                    format("Expected '%s' header, not present in %s",
                           name, headers));
        }
        if (!v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    // Best-effort close of each non-null resource; failures are ignored.
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue; // e.g. close() before open() leaves ssc unset
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // shutting down or the peer is gone: nothing useful to do
            }
        }
    }
}