/*
 * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server.
 *
 * <p>Performs a simpler version of the WebSocket Opening Handshake over HTTP
 * (i.e. no proxying, cookies, etc.) Supports sequential connections, one at a
 * time, i.e. in order for a client to connect to the server the previous
 * client must disconnect first.
 *
 * <p>Expected client request:
 *
 * <pre>
 *   GET /chat HTTP/1.1
 *   Host: server.example.com
 *   Upgrade: websocket
 *   Connection: Upgrade
 *   Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *   Origin: http://example.com
 *   Sec-WebSocket-Version: 13
 * </pre>
 *
 * <p>This server's response (with the default mapping):
 *
 * <pre>
 *   HTTP/1.1 101 Switching Protocols
 *   Connection: Upgrade
 *   Upgrade: websocket
 *   Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * </pre>
 *
 * <p>The default mapping rejects requests that carry a
 * {@code Sec-WebSocket-Protocol} or {@code Sec-WebSocket-Extensions} header;
 * pass a custom mapping to the constructor to exercise subprotocols or
 * extensions.
 */
public final class DummyWebSocketServer implements Closeable {

    /** GUID appended to Sec-WebSocket-Key before hashing (RFC 6455). */
    private static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Blank line terminating the HTTP request header section. */
    private static final Pattern END_OF_HEADERS = Pattern.compile("\r\n\r\n");

    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    /** Bound local address; set by open() only after a successful bind. */
    private volatile InetSocketAddress address;

    /**
     * Creates a server that answers handshakes using the default mapping.
     * The server does not listen until {@link #open()} is called.
     */
    public DummyWebSocketServer() {
        this(defaultMapping());
    }

    /**
     * Creates a server that uses {@code mapping} to produce the handshake
     * response. The server does not listen until {@link #open()} is called.
     *
     * @param mapping maps the request lines (request line followed by the
     *        header lines, CRLF separators removed) to the response lines
     *        (status line followed by header lines); the terminating empty
     *        line is appended by the server
     * @throws NullPointerException if {@code mapping} is {@code null}
     */
    public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
        requireNonNull(mapping);
        thread = new Thread(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    err.println("Accepting next connection at: " + ssc);
                    SocketChannel channel = ssc.accept();
                    err.println("Accepted: " + channel);
                    // Only IOExceptions are handled per connection; a runtime
                    // exception from the mapping ends the accept loop.
                    try {
                        channel.configureBlocking(true);
                        StringBuilder request = new StringBuilder();
                        if (!readRequest(channel, request)) {
                            throw new IOException("Bad request:" + request);
                        }
                        List<String> strings = asList(request.toString().split("\r\n"));
                        List<String> response = mapping.apply(strings);
                        writeResponse(channel, response);
                        // Read until the thread is interrupted, an error occurs,
                        // or the input is shut down
                        ByteBuffer b = ByteBuffer.allocate(1024);
                        while (channel.read(b) != -1) {
                            b.clear();
                        }
                    } catch (IOException e) {
                        err.println("Error in connection: " + channel + ", " + e);
                    } finally {
                        err.println("Closed: " + channel);
                        close(channel);
                    }
                }
            } catch (ClosedByInterruptException ignored) {
                // close() interrupted a blocking accept(): normal shutdown
            } catch (IOException e) {
                err.println(e);
            } finally {
                close(ssc);
                err.println("Stopped at: " + getURI());
            }
        });
        thread.setName("DummyWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Binds the server to an ephemeral port on {@code localhost} and starts
     * the accepting thread.
     *
     * @throws IllegalStateException if this method has already been called
     * @throws IOException if the server socket cannot be opened or bound
     */
    public void open() throws IOException {
        err.println("Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ssc = ServerSocketChannel.open();
        try {
            ssc.configureBlocking(true);
            ssc.bind(new InetSocketAddress("localhost", 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            // The server is unusable; report it rather than pretend it started
            throw e;
        }
        err.println("Started at: " + getURI());
    }

    /**
     * Stops the server: interrupts the accepting thread and closes the server
     * socket. Safe to call on a server that was never (successfully) opened.
     */
    @Override
    public void close() {
        // address is non-null only after open() has bound the socket, which
        // also implies started == true, so getURI() cannot throw here
        InetSocketAddress bound = address;
        err.println("Stopping: " + (bound == null ? "(not bound)" : getURI()));
        thread.interrupt();
        close(ssc);
    }

    /**
     * Returns the {@code ws://host:port} URI that clients should connect to.
     *
     * @throws IllegalStateException if {@link #open()} has not been called
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://" + address.getHostName() + ":" + address.getPort());
    }

    /**
     * Reads from {@code channel} into {@code request} until the blank line
     * ending the HTTP header section has been seen.
     *
     * @return {@code true} if the header terminator was read, {@code false}
     *         if the stream ended first
     */
    private boolean readRequest(SocketChannel channel, StringBuilder request)
            throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // Read the complete HTTP request headers; there should be no body
            buffer.flip();
            // ISO-8859-1 maps every byte to a char, so decoding cannot fail
            request.append(ISO_8859_1.decode(buffer));
            // Search the accumulated text so a terminator split across reads is found
            if (END_OF_HEADERS.matcher(request).find()) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes the response lines, CRLF-separated and followed by the blank
     * line that ends the header section, encoded as ISO-8859-1.
     *
     * @throws UncheckedIOException if a line contains a non-ISO-8859-1 char
     */
    private void writeResponse(SocketChannel channel, List<String> response)
            throws IOException
    {
        String s = response.stream().collect(Collectors.joining("\r\n"))
                + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * Default handshake mapping: validates a plain {@code GET} upgrade request
     * and answers {@code 101 Switching Protocols} with the computed
     * {@code Sec-WebSocket-Accept}. Header names are matched case-sensitively.
     * Every validation failure is reported as {@link IllegalStateException}.
     */
    private static Function<List<String>, List<String>> defaultMapping() {
        return request -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, List<String>> requestHeaders = new HashMap<>();
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Split on the first ": " only, so values containing ": " are kept whole
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                     + ", split=" + Arrays.toString(split));
                }
                requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            List<String> key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null || key.isEmpty()) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            if (key.size() != 1) {
                throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                // Every Java platform is required to support SHA-1
                throw new InternalError(e);
            }
            // Accept value = Base64(SHA-1(key + GUID))
            String x = key.get(0) + WEBSOCKET_GUID;
            sha1.update(x.getBytes(ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);
            return response;
        };
    }

    /**
     * Checks that header {@code name} is present and has {@code value} among
     * its values.
     *
     * @param headers request headers keyed by exact (case-sensitive) name
     * @param name    header name to look up
     * @param value   value that must be present
     * @return {@code value}
     * @throws IllegalStateException if the header is absent or lacks {@code value}
     */
    protected static String expectHeader(Map<String, List<String>> headers,
                                         String name,
                                         String value) {
        List<String> v = headers.get(name);
        if (v == null || !v.contains(value)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return value;
    }

    /** Closes each non-null resource, ignoring failures (best-effort cleanup). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue;
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // best effort: nothing useful to do if closing fails
            }
        }
    }
}