/*
 * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.CharacterCodingException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static java.lang.System.Logger.Level.ERROR;
import static java.lang.System.Logger.Level.INFO;
import static java.lang.System.Logger.Level.TRACE;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Dummy WebSocket Server.
 *
 * Performs a simplified version of the WebSocket Opening Handshake over HTTP
 * (i.e. no proxying, cookies, etc.). Supports sequential connections, one at
 * a time, i.e. in order for a client to connect to the server the previous
 * client must disconnect first.
 *
 * The handshake response is produced by a mapping function from the lines of
 * the HTTP request (status line followed by header lines) to the lines of the
 * HTTP response. The default mapping (used by {@link #DummyWebSocketServer()})
 * accepts a request such as:
 *
 * <pre>
 *     GET / HTTP/1.1
 *     Host: server.example.com
 *     Upgrade: websocket
 *     Connection: Upgrade
 *     Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 *     Origin: http://example.com
 *     Sec-WebSocket-Version: 13
 * </pre>
 *
 * and responds with:
 *
 * <pre>
 *     HTTP/1.1 101 Switching Protocols
 *     Connection: Upgrade
 *     Upgrade: websocket
 *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * </pre>
 *
 * The default mapping throws {@code IllegalStateException} for a request to
 * any path other than {@code /}, for duplicate or malformed headers, and for
 * requests carrying {@code Sec-WebSocket-Protocol} or
 * {@code Sec-WebSocket-Extensions}. An exception thrown by a mapping
 * propagates out of the server thread: the connection is closed and the
 * server stops accepting further connections.
 */
public final class DummyWebSocketServer implements Closeable {

    private static final System.Logger log = System.getLogger(DummyWebSocketServer.class.getName());
    /** The blank line that terminates the HTTP request headers. */
    private static final Pattern END_OF_HEADERS = Pattern.compile("\r\n\r\n");
    private final AtomicBoolean started = new AtomicBoolean();
    private final Thread thread;
    private volatile ServerSocketChannel ssc;
    private volatile InetSocketAddress address;

    /** Creates a server that answers handshakes using the default mapping. */
    public DummyWebSocketServer() {
        this(defaultMapping());
    }

    /**
     * Creates a server that answers each opening handshake with
     * {@code mapping.apply(requestLines)}. The server does not listen until
     * {@link #open()} is called.
     *
     * @param mapping maps the request lines (split on CRLF) to the response
     *                lines; the response lines are joined with CRLF and
     *                terminated by an empty line
     * @throws NullPointerException if {@code mapping} is null
     */
    public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
        requireNonNull(mapping);
        thread = new Thread(() -> serve(mapping));
        thread.setName("DummyWebSocketServer");
        thread.setDaemon(false);
    }

    /**
     * Server-thread body: accepts and handles connections one at a time
     * until the thread is interrupted or the server socket fails, then
     * closes the server socket.
     */
    private void serve(Function<List<String>, List<String>> mapping) {
        try {
            while (!Thread.currentThread().isInterrupted()) {
                log.log(INFO, "Accepting next connection at: " + ssc);
                SocketChannel channel = ssc.accept();
                log.log(INFO, "Accepted: " + channel);
                try {
                    handle(channel, mapping);
                } catch (IOException e) {
                    log.log(TRACE, () -> "Error in connection: " + channel, e);
                } finally {
                    log.log(INFO, "Closed: " + channel);
                    close(channel);
                }
            }
        } catch (ClosedByInterruptException ignored) {
            // accept() was interrupted by close(): normal shutdown
        } catch (IOException e) {
            log.log(ERROR, e);
        } finally {
            close(ssc);
            log.log(INFO, "Stopped at: " + getURI());
        }
    }

    /**
     * Performs the opening handshake on {@code channel}, then discards any
     * further input until end-of-stream, interruption, or an I/O error.
     */
    private void handle(SocketChannel channel,
                        Function<List<String>, List<String>> mapping)
        throws IOException
    {
        channel.configureBlocking(true);
        StringBuilder request = new StringBuilder();
        if (!readRequest(channel, request)) {
            throw new IOException("Bad request:" + request);
        }
        List<String> strings = asList(request.toString().split("\r\n"));
        List<String> response = mapping.apply(strings);
        writeResponse(channel, response);
        // Read until the thread is interrupted or an error occurred
        // or the input is shutdown
        ByteBuffer b = ByteBuffer.allocate(1024);
        while (channel.read(b) != -1) {
            b.clear();
        }
    }

    /**
     * Binds the server to an ephemeral port on "localhost" and starts the
     * server thread.
     *
     * @throws IllegalStateException if this method has already been called
     * @throws IOException if the server socket cannot be opened or bound
     */
    public void open() throws IOException {
        log.log(INFO, "Starting");
        if (!started.compareAndSet(false, true)) {
            throw new IllegalStateException("Already started");
        }
        ssc = ServerSocketChannel.open();
        try {
            ssc.configureBlocking(true);
            ssc.bind(new InetSocketAddress("localhost", 0));
            address = (InetSocketAddress) ssc.getLocalAddress();
            thread.start();
        } catch (IOException e) {
            close(ssc);
            // Propagate the failure: otherwise the caller would believe the
            // server is running, and getURI() below would fail because
            // address was never assigned.
            throw e;
        }
        log.log(INFO, "Started at: " + getURI());
    }

    /**
     * Stops the server by interrupting the server thread and closing the
     * server socket. Does not wait for the server thread to terminate.
     * Safe to call on a server that was never (successfully) opened.
     */
    @Override
    public void close() {
        InetSocketAddress bound = address;
        // address is only assigned after started is set, so getURI() is safe here
        log.log(INFO, "Stopping: " + (bound == null ? "<not bound>" : getURI()));
        thread.interrupt();
        close(ssc);
    }

    /**
     * Returns the {@code ws://host:port} URI clients should connect to.
     *
     * @throws IllegalStateException if {@link #open()} has not been called
     */
    URI getURI() {
        if (!started.get()) {
            throw new IllegalStateException("Not yet started");
        }
        return URI.create("ws://" + address.getHostName() + ":" + address.getPort());
    }

    /**
     * Reads from {@code channel} into {@code request} until the end of the
     * HTTP headers (an empty line) has been received.
     *
     * @return {@code true} if the end of the headers was seen, {@code false}
     *         if the stream ended first
     */
    private boolean readRequest(SocketChannel channel, StringBuilder request)
        throws IOException
    {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (channel.read(buffer) != -1) {
            // read the complete HTTP request headers, there should be no body
            buffer.flip();
            // ISO-8859-1 maps every byte to exactly one char, so decoding each
            // chunk on its own can neither fail nor split a character
            request.append(ISO_8859_1.decode(buffer));
            if (END_OF_HEADERS.matcher(request).find()) {
                return true;
            }
            buffer.clear();
        }
        return false;
    }

    /**
     * Writes {@code response} to {@code channel} as CRLF-separated lines
     * followed by an empty line, encoded as ISO-8859-1.
     *
     * @throws UncheckedIOException if a response line contains a character
     *         that ISO-8859-1 cannot encode
     */
    private void writeResponse(SocketChannel channel, List<String> response)
        throws IOException
    {
        String s = response.stream().collect(Collectors.joining("\r\n"))
                + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded = ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        while (encoded.hasRemaining()) {
            channel.write(encoded);
        }
    }

    /**
     * Returns the mapping used by {@link #DummyWebSocketServer()}: validates
     * a plain WebSocket (version 13) upgrade request for {@code /} and
     * answers it with a 101 response carrying the matching
     * {@code Sec-WebSocket-Accept}.
     */
    private static Function<List<String>, List<String>> defaultMapping() {
        return request -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            if (!"GET / HTTP/1.1".equals(iterator.next())) {
                throw new IllegalStateException
                        ("Unexpected status line: " + request.get(0));
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, String> requestHeaders = new HashMap<>();
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Split on the first ": " only; a header value may itself
                // contain ": " or be empty
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                     + ", split=" + Arrays.toString(split));
                }
                if (requestHeaders.put(split[0], split[1]) != null) {
                    throw new IllegalStateException
                            ("Duplicating headers: " + Arrays.toString(split));
                }
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            String key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            response.add("Sec-WebSocket-Accept: " + acceptKey(key));
            return response;
        };
    }

    /**
     * Computes the {@code Sec-WebSocket-Accept} value for a
     * {@code Sec-WebSocket-Key} as defined by RFC 6455:
     * base64(SHA-1(key + GUID)).
     */
    private static String acceptKey(String key) {
        MessageDigest sha1;
        try {
            sha1 = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            throw new InternalError(e);
        }
        String x = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        sha1.update(x.getBytes(ISO_8859_1));
        return Base64.getEncoder().encodeToString(sha1.digest());
    }

    /**
     * Checks that header {@code name} is present in {@code headers} with
     * exactly the value {@code value} (case-sensitive comparison).
     *
     * @return the header's value
     * @throws IllegalStateException if the header is absent or differs
     */
    protected static String expectHeader(Map<String, String> headers,
                                         String name,
                                         String value) {
        String v = headers.get(name);
        if (!value.equals(v)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return v;
    }

    /** Closes each non-null resource, ignoring any exception (best effort). */
    private static void close(AutoCloseable... acs) {
        for (AutoCloseable ac : acs) {
            if (ac == null) {
                continue; // e.g. ssc when the server was never opened
            }
            try {
                ac.close();
            } catch (Exception ignored) {
                // best-effort cleanup
            }
        }
    }
}