1 /* |
|
2 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. |
|
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. |
|
4 * |
|
5 * This code is free software; you can redistribute it and/or modify it |
|
6 * under the terms of the GNU General Public License version 2 only, as |
|
7 * published by the Free Software Foundation. |
|
8 * |
|
9 * This code is distributed in the hope that it will be useful, but WITHOUT |
|
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
|
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License |
|
12 * version 2 for more details (a copy is included in the LICENSE file that |
|
13 * accompanied this code). |
|
14 * |
|
15 * You should have received a copy of the GNU General Public License version |
|
16 * 2 along with this work; if not, write to the Free Software Foundation, |
|
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. |
|
18 * |
|
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA |
|
20 * or visit www.oracle.com if you need additional information or have any |
|
21 * questions. |
|
22 */ |
|
23 |
|
24 import java.io.IOException; |
|
25 import java.io.UncheckedIOException; |
|
26 import java.net.InetSocketAddress; |
|
27 import java.net.URI; |
|
28 import java.nio.ByteBuffer; |
|
29 import java.nio.CharBuffer; |
|
30 import java.nio.channels.ServerSocketChannel; |
|
31 import java.nio.channels.SocketChannel; |
|
32 import java.nio.charset.CharacterCodingException; |
|
33 import java.nio.charset.StandardCharsets; |
|
34 import java.security.MessageDigest; |
|
35 import java.security.NoSuchAlgorithmException; |
|
36 import java.util.Arrays; |
|
37 import java.util.Base64; |
|
38 import java.util.HashMap; |
|
39 import java.util.Iterator; |
|
40 import java.util.LinkedList; |
|
41 import java.util.List; |
|
42 import java.util.Map; |
|
43 import java.util.concurrent.CompletableFuture; |
|
44 import java.util.function.Function; |
|
45 import java.util.regex.Pattern; |
|
46 import java.util.stream.Collectors; |
|
47 |
|
48 import static java.lang.String.format; |
|
49 import static java.util.Objects.requireNonNull; |
|
50 |
|
51 // |
|
52 // Performs a simple opening handshake and yields the channel. |
|
53 // |
|
54 // Client Request: |
|
55 // |
|
56 // GET /chat HTTP/1.1 |
|
57 // Host: server.example.com |
|
58 // Upgrade: websocket |
|
59 // Connection: Upgrade |
|
60 // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== |
|
61 // Origin: http://example.com |
|
62 // Sec-WebSocket-Protocol: chat, superchat |
|
63 // Sec-WebSocket-Version: 13 |
|
64 // |
|
65 // |
|
66 // Server Response: |
|
67 // |
|
68 // HTTP/1.1 101 Switching Protocols |
|
69 // Upgrade: websocket |
|
70 // Connection: Upgrade |
|
71 // Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= |
|
72 // Sec-WebSocket-Protocol: chat |
|
73 // |
|
final class HandshakePhase {

    //
    // Marks the end of the HTTP request head: the empty line that follows the
    // last header. Compiled once rather than on every read.
    //
    private static final Pattern END_OF_HEAD = Pattern.compile("\r\n\r\n");

    //
    // Size of the chunks in which the request is pulled from the channel
    //
    private static final int READ_CHUNK_SIZE = 512;

    //
    // Upper bound on the number of characters accumulated while looking for
    // the end of the request head; keeps a misbehaving peer from making the
    // server buffer data without limit
    //
    private static final int MAX_REQUEST_SIZE = 16 * 1024;

    //
    // Fixed string that is appended to the value of Sec-WebSocket-Key before
    // hashing to obtain the value of Sec-WebSocket-Accept
    //
    private static final String ACCEPT_KEY_SUFFIX
            = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    private final ServerSocketChannel ssc;

    //
    // Opens a server channel and binds it to the given address.
    //
    // Throws NullPointerException if the address is null and
    // UncheckedIOException if the channel cannot be opened or bound. If the
    // binding fails for whatever reason, the freshly opened channel is closed
    // instead of being leaked.
    //
    HandshakePhase(InetSocketAddress address) {
        requireNonNull(address);
        ServerSocketChannel channel = null;
        boolean bound = false;
        try {
            channel = ServerSocketChannel.open();
            channel.bind(address);
            bound = true;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } finally {
            // Reached with bound == false on any failure, checked or not
            if (!bound && channel != null) {
                try {
                    channel.close();
                } catch (IOException ignored) { }
            }
        }
        ssc = channel;
    }

    //
    // Returned CF completes normally after the handshake has been performed
    //
    // The task accepts a single connection, reads the request head, splits it
    // into lines (the request line followed by the headers), passes the lines
    // to the mapping and writes the lines returned by the mapping back as the
    // response. The CF yields the accepted channel. If anything goes wrong
    // after the connection has been accepted, the channel is closed and the
    // CF completes exceptionally.
    //
    // Note: supplyAsync is used without an explicit executor, so the blocking
    // accept/read/write calls run on the default asynchronous executor.
    //
    CompletableFuture<SocketChannel> afterHandshake(
            Function<List<String>, List<String>> mapping) {
        return CompletableFuture.supplyAsync(
                () -> {
                    SocketChannel socketChannel = accept();
                    try {
                        StringBuilder request = new StringBuilder();
                        if (!readRequest(socketChannel, request)) {
                            throw new IllegalStateException(
                                    "End of the request head not found in the first "
                                            + request.length() + " characters");
                        }
                        // String.split drops trailing empty strings, so the
                        // terminating empty line does not show up in the list
                        List<String> requestLines = Arrays.asList(
                                request.toString().split("\r\n")
                        );
                        List<String> responseLines = mapping.apply(requestLines);
                        writeResponse(socketChannel, responseLines);
                        return socketChannel;
                    } catch (Throwable t) {
                        closeQuietly(socketChannel);
                        throw t;
                    }
                });
    }

    //
    // Same as afterHandshake(mapping) with a mapping that performs a strict
    // check of the request and produces a "101 Switching Protocols" response
    //
    CompletableFuture<SocketChannel> afterHandshake() {
        return afterHandshake(HandshakePhase::defaultResponse);
    }

    //
    // Validates the lines of an opening handshake request and builds the lines
    // of the corresponding response.
    //
    // The check is deliberately strict: the request line must be exactly
    // "GET / HTTP/1.1", each header must have the form "name: value", headers
    // must not repeat, and neither subprotocols nor extensions may be
    // requested. Any violation results in IllegalStateException.
    //
    private static List<String> defaultResponse(List<String> request) {
        List<String> response = new LinkedList<>();
        Iterator<String> lines = request.iterator();
        if (!lines.hasNext()) {
            throw new IllegalStateException("The request is empty");
        }
        if (!"GET / HTTP/1.1".equals(lines.next())) {
            throw new IllegalStateException
                    ("Unexpected status line: " + request.get(0));
        }
        response.add("HTTP/1.1 101 Switching Protocols");
        Map<String, String> requestHeaders = new HashMap<>();
        while (lines.hasNext()) {
            String header = lines.next();
            String[] nameAndValue = header.split(": ");
            if (nameAndValue.length != 2) {
                throw new IllegalStateException
                        ("Unexpected header: " + header
                                 + ", split=" + Arrays.toString(nameAndValue));
            }
            if (requestHeaders.put(nameAndValue[0], nameAndValue[1]) != null) {
                throw new IllegalStateException
                        ("Duplicating headers: " + Arrays.toString(nameAndValue));
            }
        }
        if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
            throw new IllegalStateException("Subprotocols are not expected");
        }
        if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
            throw new IllegalStateException("Extensions are not expected");
        }
        expectHeader(requestHeaders, "Connection", "Upgrade");
        response.add("Connection: Upgrade");
        expectHeader(requestHeaders, "Upgrade", "websocket");
        response.add("Upgrade: websocket");
        expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
        String key = requestHeaders.get("Sec-WebSocket-Key");
        if (key == null) {
            throw new IllegalStateException("Sec-WebSocket-Key is missing");
        }
        MessageDigest sha1;
        try {
            sha1 = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            throw new InternalError(e);
        }
        // Sec-WebSocket-Accept = Base64(SHA-1(key + fixed suffix))
        String toHash = key + ACCEPT_KEY_SUFFIX;
        sha1.update(toHash.getBytes(StandardCharsets.ISO_8859_1));
        String accept = Base64.getEncoder().encodeToString(sha1.digest());
        response.add("Sec-WebSocket-Accept: " + accept);
        return response;
    }

    //
    // Returns the value of the header with the given name if it is equal to
    // the expected value, otherwise (including the case where the header is
    // absent) throws IllegalStateException
    //
    private static String expectHeader(Map<String, String> headers,
                                       String name,
                                       String value) {
        String v = headers.get(name);
        if (!value.equals(v)) {
            throw new IllegalStateException(
                    format("Expected '%s: %s', actual: '%s: %s'",
                           name, value, name, v)
            );
        }
        return v;
    }

    //
    // Returns a "ws://host:port" URI built from the local address of the
    // server channel. Throws UncheckedIOException if the local address cannot
    // be obtained.
    //
    // NOTE(review): getHostName() may perform a reverse name lookup for the
    // bound address; confirm this is acceptable for the callers.
    //
    URI getURI() {
        InetSocketAddress a;
        try {
            a = (InetSocketAddress) ssc.getLocalAddress();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return URI.create("ws://" + a.getHostName() + ":" + a.getPort());
    }

    //
    // Performs one read from the channel into the buffer and returns the
    // number of bytes read. Throws IllegalStateException on end-of-stream and
    // UncheckedIOException on an I/O error.
    //
    private int read(SocketChannel socketChannel, ByteBuffer buffer) {
        try {
            int num = socketChannel.read(buffer);
            if (num == -1) {
                throw new IllegalStateException("Unexpected EOF");
            }
            assert socketChannel.isBlocking() && num > 0;
            return num;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    //
    // Accepts a single connection and puts the resulting channel into the
    // blocking mode. Throws UncheckedIOException on an I/O error, closing the
    // channel first if it has already been accepted.
    //
    private SocketChannel accept() {
        SocketChannel socketChannel = null;
        try {
            socketChannel = ssc.accept();
            socketChannel.configureBlocking(true);
        } catch (IOException e) {
            if (socketChannel != null) {
                closeQuietly(socketChannel);
            }
            throw new UncheckedIOException(e);
        }
        return socketChannel;
    }

    //
    // Reads from the channel, appending the decoded characters to the request,
    // until the request contains the end of the head (an empty line).
    //
    // A single read is not guaranteed to deliver the whole request: the
    // request may be larger than the buffer or may arrive in several pieces.
    // Hence the loop.
    //
    // Returns true once the end of the head has been seen, false if it has not
    // been seen after MAX_REQUEST_SIZE characters. End-of-stream before the
    // end of the head results in IllegalStateException (see read).
    //
    private boolean readRequest(SocketChannel socketChannel,
                                StringBuilder request) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(READ_CHUNK_SIZE);
        while (true) {
            // The terminator may straddle two chunks, so resume the search a
            // few characters before the newly appended data
            int searchFrom = Math.max(0, request.length() - 3);
            buffer.clear();
            read(socketChannel, buffer);
            buffer.flip();
            CharBuffer decoded;
            try {
                // ISO_8859_1 maps each byte to exactly one char, so decoding
                // chunk by chunk cannot split a character
                decoded =
                        StandardCharsets.ISO_8859_1.newDecoder().decode(buffer);
            } catch (CharacterCodingException e) {
                throw new UncheckedIOException(e);
            }
            request.append(decoded);
            if (END_OF_HEAD.matcher(request).find(searchFrom)) {
                return true;
            }
            if (request.length() >= MAX_REQUEST_SIZE) {
                return false;
            }
        }
    }

    //
    // Joins the response lines with CRLF, terminates the result with an empty
    // line and writes it to the channel encoded as ISO_8859_1. Throws
    // UncheckedIOException if the text cannot be encoded or written.
    //
    private void writeResponse(SocketChannel socketChannel,
                               List<String> response) {
        String s = response.stream()
                .collect(Collectors.joining("\r\n", "", "\r\n\r\n"));
        ByteBuffer encoded;
        try {
            encoded =
                    StandardCharsets.ISO_8859_1.newEncoder().encode(CharBuffer.wrap(s));
        } catch (CharacterCodingException e) {
            throw new UncheckedIOException(e);
        }
        write(socketChannel, encoded);
    }

    //
    // Writes the whole buffer to the channel. On an I/O error closes the
    // channel and throws UncheckedIOException.
    //
    private void write(SocketChannel socketChannel, ByteBuffer buffer) {
        try {
            // A single write may transfer only a part of the buffer
            while (buffer.hasRemaining()) {
                socketChannel.write(buffer);
            }
        } catch (IOException e) {
            closeQuietly(socketChannel);
            throw new UncheckedIOException(e);
        }
    }

    //
    // Closes the channel on a best-effort basis. A failure to close is ignored
    // on purpose: this is only used while another, more relevant, error is
    // being reported.
    //
    private static void closeQuietly(SocketChannel socketChannel) {
        try {
            socketChannel.close();
        } catch (IOException ignored) { }
    }
}
|