< prev index next >
1 /*
2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23
24 import java.io.IOException;
25 import java.io.UncheckedIOException;
26 import java.net.InetSocketAddress;
27 import java.net.URI;
28 import java.nio.ByteBuffer;
29 import java.nio.CharBuffer;
30 import java.nio.channels.ServerSocketChannel;
31 import java.nio.channels.SocketChannel;
32 import java.nio.charset.CharacterCodingException;
33 import java.nio.charset.StandardCharsets;
34 import java.security.MessageDigest;
35 import java.security.NoSuchAlgorithmException;
36 import java.util.Arrays;
37 import java.util.Base64;
38 import java.util.HashMap;
39 import java.util.Iterator;
40 import java.util.LinkedList;
41 import java.util.List;
42 import java.util.Map;
43 import java.util.concurrent.CompletableFuture;
44 import java.util.function.Function;
45 import java.util.regex.Pattern;
46 import java.util.stream.Collectors;
47
48 import static java.lang.String.format;
49 import static java.util.Objects.requireNonNull;
50
//
// Performs a simple server-side WebSocket opening handshake on a single
// accepted connection and yields the connected channel.
//
// Client Request (example from RFC 6455; note that the default handshake
// below expects the request target to be "/"):
//
// GET /chat HTTP/1.1
// Host: server.example.com
// Upgrade: websocket
// Connection: Upgrade
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Origin: http://example.com
// Sec-WebSocket-Protocol: chat, superchat
// Sec-WebSocket-Version: 13
//
//
// Server Response:
//
// HTTP/1.1 101 Switching Protocols
// Upgrade: websocket
// Connection: Upgrade
// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
// Sec-WebSocket-Protocol: chat
//
final class HandshakePhase {

    // Upper bound (in characters) on the request head. If the terminating
    // empty line has not been seen by then, the request is rejected instead
    // of being buffered forever.
    private static final int MAX_REQUEST_SIZE = 64 * 1024;

    // GUID appended to Sec-WebSocket-Key when computing Sec-WebSocket-Accept
    private static final String WEBSOCKET_GUID =
            "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private final ServerSocketChannel ssc;

    //
    // Opens a server socket channel bound to the given address.
    //
    // Throws NullPointerException if address is null, and
    // UncheckedIOException if the channel cannot be opened or bound.
    //
    HandshakePhase(InetSocketAddress address) {
        requireNonNull(address);
        ServerSocketChannel channel = null;
        try {
            channel = ServerSocketChannel.open();
            channel.bind(address);
        } catch (IOException e) {
            // Do not leak the opened channel if bind fails
            if (channel != null) {
                try {
                    channel.close();
                } catch (IOException ignored) { }
            }
            throw new UncheckedIOException(e);
        }
        ssc = channel;
    }

    //
    // Accepts one connection, reads the request head, and passes its lines
    // (split on CRLF, trailing empty lines dropped) to `mapping`. It then
    // writes the returned lines back as the response head.
    //
    // The returned CF completes normally with the connected channel after
    // the handshake has been performed. It completes exceptionally, with the
    // channel closed, if reading, mapping or writing fails.
    //
    CompletableFuture<SocketChannel> afterHandshake(
            Function<List<String>, List<String>> mapping) {
        return CompletableFuture.supplyAsync(
                () -> {
                    SocketChannel socketChannel = accept();
                    try {
                        StringBuilder request = new StringBuilder();
                        if (!readRequest(socketChannel, request)) {
                            throw new IllegalStateException(
                                    "Request head is not terminated within "
                                            + MAX_REQUEST_SIZE + " characters");
                        }
                        List<String> strings = Arrays.asList(
                                request.toString().split("\r\n")
                        );
                        List<String> response = mapping.apply(strings);
                        writeResponse(socketChannel, response);
                        return socketChannel;
                    } catch (Throwable t) {
                        try {
                            socketChannel.close();
                        } catch (IOException ignored) { }
                        throw t;
                    }
                });
    }

    //
    // Performs the default handshake. It requires "GET / HTTP/1.1",
    // "Connection: Upgrade", "Upgrade: websocket" and
    // "Sec-WebSocket-Version: 13", plus a Sec-WebSocket-Key header.
    // Subprotocols, extensions and duplicate header names are rejected.
    // Header names are matched case-sensitively.
    //
    CompletableFuture<SocketChannel> afterHandshake() {
        return afterHandshake((request) -> {
            List<String> response = new LinkedList<>();
            Iterator<String> iterator = request.iterator();
            if (!iterator.hasNext()) {
                throw new IllegalStateException("The request is empty");
            }
            String statusLine = iterator.next();
            if (!"GET / HTTP/1.1".equals(statusLine)) {
                throw new IllegalStateException
                        ("Unexpected status line: " + statusLine);
            }
            response.add("HTTP/1.1 101 Switching Protocols");
            Map<String, String> requestHeaders = new HashMap<>();
            while (iterator.hasNext()) {
                String header = iterator.next();
                // Split on the first separator only: values may contain ": "
                String[] split = header.split(": ", 2);
                if (split.length != 2) {
                    throw new IllegalStateException
                            ("Unexpected header: " + header
                                    + ", split=" + Arrays.toString(split));
                }
                if (requestHeaders.put(split[0], split[1]) != null) {
                    throw new IllegalStateException
                            ("Duplicating headers: " + Arrays.toString(split));
                }
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
                throw new IllegalStateException("Subprotocols are not expected");
            }
            if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
                throw new IllegalStateException("Extensions are not expected");
            }
            expectHeader(requestHeaders, "Connection", "Upgrade");
            response.add("Connection: Upgrade");
            expectHeader(requestHeaders, "Upgrade", "websocket");
            response.add("Upgrade: websocket");
            expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
            String key = requestHeaders.get("Sec-WebSocket-Key");
            if (key == null) {
                throw new IllegalStateException("Sec-WebSocket-Key is missing");
            }
            MessageDigest sha1;
            try {
                sha1 = MessageDigest.getInstance("SHA-1");
            } catch (NoSuchAlgorithmException e) {
                throw new InternalError(e);
            }
            String x = key + WEBSOCKET_GUID;
            sha1.update(x.getBytes(StandardCharsets.ISO_8859_1));
            String v = Base64.getEncoder().encodeToString(sha1.digest());
            response.add("Sec-WebSocket-Accept: " + v);
            return response;
        });
    }

    //
    // Returns the value of header `name`. Throws IllegalStateException if
    // the header is absent or its value is not exactly `value`.
    //
    private String expectHeader(Map<String, String> headers,
                                String name,
                                String value) {
        String v = headers.get(name);
        if (!value.equals(v)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                            name, value, name, v)
            );
        }
        return v;
    }

    //
    // Returns a ws:// URI pointing at the address the server is bound to.
    //
    URI getURI() {
        InetSocketAddress a;
        try {
            a = (InetSocketAddress) ssc.getLocalAddress();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        String host = a.getHostName();
        // getHostName may return a literal address. An IPv6 literal must be
        // bracketed, or the URI authority cannot be parsed.
        if (host.indexOf(':') >= 0 && !host.startsWith("[")) {
            host = "[" + host + "]";
        }
        return URI.create("ws://" + host + ":" + a.getPort());
    }

    //
    // Reads at least one byte into `buffer`. Throws IllegalStateException on
    // EOF and UncheckedIOException on I/O failure.
    //
    private int read(SocketChannel socketChannel, ByteBuffer buffer) {
        try {
            int num = socketChannel.read(buffer);
            if (num == -1) {
                throw new IllegalStateException("Unexpected EOF");
            }
            assert socketChannel.isBlocking() && num > 0;
            return num;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    //
    // Accepts one connection in blocking mode. If configuring the channel
    // fails, the channel is closed and UncheckedIOException is thrown.
    //
    private SocketChannel accept() {
        SocketChannel socketChannel = null;
        try {
            socketChannel = ssc.accept();
            socketChannel.configureBlocking(true);
        } catch (IOException e) {
            if (socketChannel != null) {
                try {
                    socketChannel.close();
                } catch (IOException ignored) { }
            }
            throw new UncheckedIOException(e);
        }
        return socketChannel;
    }

    //
    // Reads from the channel and appends the decoded text to `request` until
    // the request head's terminating "\r\n\r\n" has been received. A single
    // read may return only part of the head, so reading repeats as needed.
    //
    // Returns true once the terminator is seen. Returns false if the head
    // grows past MAX_REQUEST_SIZE without one.
    //
    private boolean readRequest(SocketChannel socketChannel,
                                StringBuilder request) {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        while (request.length() <= MAX_REQUEST_SIZE) {
            buffer.clear();
            read(socketChannel, buffer);
            buffer.flip();
            CharBuffer decoded;
            try {
                // ISO-8859-1 is a single-byte charset, so every chunk
                // decodes on its own
                decoded =
                        StandardCharsets.ISO_8859_1.newDecoder().decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            // The terminator may straddle the previous chunk boundary
            int from = Math.max(0, request.length() - 3);
            request.append(decoded);
            if (request.indexOf("\r\n\r\n", from) >= 0) {
                return true;
            }
        }
        return false;
    }

    //
    // Joins the response lines with CRLF, appends the empty line that ends
    // the head, encodes the result as ISO-8859-1, and writes it fully.
    //
    private void writeResponse(SocketChannel socketChannel,
                               List<String> response) {
        String s = String.join("\r\n", response) + "\r\n\r\n";
        ByteBuffer encoded;
        try {
            encoded =
                    StandardCharsets.ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        write(socketChannel, encoded);
    }

    //
    // Writes all remaining bytes of `buffer`. On failure the channel is
    // closed and UncheckedIOException is thrown.
    //
    private void write(SocketChannel socketChannel, ByteBuffer buffer) {
        try {
            while (buffer.hasRemaining()) {
                socketChannel.write(buffer);
            }
        } catch (IOException e) {
            try {
                socketChannel.close();
            } catch (IOException ignored) { }
            throw new UncheckedIOException(e);
        }
    }
}
< prev index next >