32 import java.nio.channels.ServerSocketChannel; |
32 import java.nio.channels.ServerSocketChannel; |
33 import java.nio.channels.SocketChannel; |
33 import java.nio.channels.SocketChannel; |
34 import java.nio.charset.CharacterCodingException; |
34 import java.nio.charset.CharacterCodingException; |
35 import java.security.MessageDigest; |
35 import java.security.MessageDigest; |
36 import java.security.NoSuchAlgorithmException; |
36 import java.security.NoSuchAlgorithmException; |
|
37 import java.util.ArrayList; |
37 import java.util.Arrays; |
38 import java.util.Arrays; |
38 import java.util.Base64; |
39 import java.util.Base64; |
39 import java.util.HashMap; |
40 import java.util.HashMap; |
40 import java.util.Iterator; |
41 import java.util.Iterator; |
41 import java.util.LinkedList; |
42 import java.util.LinkedList; |
45 import java.util.function.Function; |
46 import java.util.function.Function; |
46 import java.util.regex.Pattern; |
47 import java.util.regex.Pattern; |
47 import java.util.stream.Collectors; |
48 import java.util.stream.Collectors; |
48 |
49 |
49 import static java.lang.String.format; |
50 import static java.lang.String.format; |
50 import static java.lang.System.Logger.Level.ERROR; |
51 import static java.lang.System.err; |
51 import static java.lang.System.Logger.Level.INFO; |
|
52 import static java.lang.System.Logger.Level.TRACE; |
|
53 import static java.nio.charset.StandardCharsets.ISO_8859_1; |
52 import static java.nio.charset.StandardCharsets.ISO_8859_1; |
54 import static java.util.Arrays.asList; |
53 import static java.util.Arrays.asList; |
55 import static java.util.Objects.requireNonNull; |
54 import static java.util.Objects.requireNonNull; |
56 |
55 |
57 /** |
56 /** |
81 * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= |
80 * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= |
82 * Sec-WebSocket-Protocol: chat |
81 * Sec-WebSocket-Protocol: chat |
83 */ |
82 */ |
84 public final class DummyWebSocketServer implements Closeable { |
83 public final class DummyWebSocketServer implements Closeable { |
85 |
84 |
86 private final static System.Logger log = System.getLogger(DummyWebSocketServer.class.getName()); |
|
87 private final AtomicBoolean started = new AtomicBoolean(); |
85 private final AtomicBoolean started = new AtomicBoolean(); |
88 private final Thread thread; |
86 private final Thread thread; |
89 private volatile ServerSocketChannel ssc; |
87 private volatile ServerSocketChannel ssc; |
90 private volatile InetSocketAddress address; |
88 private volatile InetSocketAddress address; |
91 |
89 |
96 public DummyWebSocketServer(Function<List<String>, List<String>> mapping) { |
94 public DummyWebSocketServer(Function<List<String>, List<String>> mapping) { |
97 requireNonNull(mapping); |
95 requireNonNull(mapping); |
98 thread = new Thread(() -> { |
96 thread = new Thread(() -> { |
99 try { |
97 try { |
100 while (!Thread.currentThread().isInterrupted()) { |
98 while (!Thread.currentThread().isInterrupted()) { |
101 log.log(INFO, "Accepting next connection at: " + ssc); |
99 err.println("Accepting next connection at: " + ssc); |
102 SocketChannel channel = ssc.accept(); |
100 SocketChannel channel = ssc.accept(); |
103 log.log(INFO, "Accepted: " + channel); |
101 err.println("Accepted: " + channel); |
104 try { |
102 try { |
105 channel.configureBlocking(true); |
103 channel.configureBlocking(true); |
106 StringBuilder request = new StringBuilder(); |
104 StringBuilder request = new StringBuilder(); |
107 if (!readRequest(channel, request)) { |
105 if (!readRequest(channel, request)) { |
108 throw new IOException("Bad request:" + request); |
106 throw new IOException("Bad request:" + request); |
115 ByteBuffer b = ByteBuffer.allocate(1024); |
113 ByteBuffer b = ByteBuffer.allocate(1024); |
116 while (channel.read(b) != -1) { |
114 while (channel.read(b) != -1) { |
117 b.clear(); |
115 b.clear(); |
118 } |
116 } |
119 } catch (IOException e) { |
117 } catch (IOException e) { |
120 log.log(TRACE, () -> "Error in connection: " + channel, e); |
118 err.println("Error in connection: " + channel + ", " + e); |
121 } finally { |
119 } finally { |
122 log.log(INFO, "Closed: " + channel); |
120 err.println("Closed: " + channel); |
123 close(channel); |
121 close(channel); |
124 } |
122 } |
125 } |
123 } |
126 } catch (ClosedByInterruptException ignored) { |
124 } catch (ClosedByInterruptException ignored) { |
127 } catch (IOException e) { |
125 } catch (IOException e) { |
128 log.log(ERROR, e); |
126 err.println(e); |
129 } finally { |
127 } finally { |
130 close(ssc); |
128 close(ssc); |
131 log.log(INFO, "Stopped at: " + getURI()); |
129 err.println("Stopped at: " + getURI()); |
132 } |
130 } |
133 }); |
131 }); |
134 thread.setName("DummyWebSocketServer"); |
132 thread.setName("DummyWebSocketServer"); |
135 thread.setDaemon(false); |
133 thread.setDaemon(false); |
136 } |
134 } |
137 |
135 |
138 public void open() throws IOException { |
136 public void open() throws IOException { |
139 log.log(INFO, "Starting"); |
137 err.println("Starting"); |
140 if (!started.compareAndSet(false, true)) { |
138 if (!started.compareAndSet(false, true)) { |
141 throw new IllegalStateException("Already started"); |
139 throw new IllegalStateException("Already started"); |
142 } |
140 } |
143 ssc = ServerSocketChannel.open(); |
141 ssc = ServerSocketChannel.open(); |
144 try { |
142 try { |
208 List<String> response = new LinkedList<>(); |
206 List<String> response = new LinkedList<>(); |
209 Iterator<String> iterator = request.iterator(); |
207 Iterator<String> iterator = request.iterator(); |
210 if (!iterator.hasNext()) { |
208 if (!iterator.hasNext()) { |
211 throw new IllegalStateException("The request is empty"); |
209 throw new IllegalStateException("The request is empty"); |
212 } |
210 } |
213 if (!"GET / HTTP/1.1".equals(iterator.next())) { |
211 String statusLine = iterator.next(); |
|
212 if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) { |
214 throw new IllegalStateException |
213 throw new IllegalStateException |
215 ("Unexpected status line: " + request.get(0)); |
214 ("Unexpected status line: " + request.get(0)); |
216 } |
215 } |
217 response.add("HTTP/1.1 101 Switching Protocols"); |
216 response.add("HTTP/1.1 101 Switching Protocols"); |
218 Map<String, String> requestHeaders = new HashMap<>(); |
217 Map<String, List<String>> requestHeaders = new HashMap<>(); |
219 while (iterator.hasNext()) { |
218 while (iterator.hasNext()) { |
220 String header = iterator.next(); |
219 String header = iterator.next(); |
221 String[] split = header.split(": "); |
220 String[] split = header.split(": "); |
222 if (split.length != 2) { |
221 if (split.length != 2) { |
223 throw new IllegalStateException |
222 throw new IllegalStateException |
224 ("Unexpected header: " + header |
223 ("Unexpected header: " + header |
225 + ", split=" + Arrays.toString(split)); |
224 + ", split=" + Arrays.toString(split)); |
226 } |
225 } |
227 if (requestHeaders.put(split[0], split[1]) != null) { |
226 requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]); |
228 throw new IllegalStateException |
227 |
229 ("Duplicating headers: " + Arrays.toString(split)); |
|
230 } |
|
231 } |
228 } |
232 if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) { |
229 if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) { |
233 throw new IllegalStateException("Subprotocols are not expected"); |
230 throw new IllegalStateException("Subprotocols are not expected"); |
234 } |
231 } |
235 if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) { |
232 if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) { |
238 expectHeader(requestHeaders, "Connection", "Upgrade"); |
235 expectHeader(requestHeaders, "Connection", "Upgrade"); |
239 response.add("Connection: Upgrade"); |
236 response.add("Connection: Upgrade"); |
240 expectHeader(requestHeaders, "Upgrade", "websocket"); |
237 expectHeader(requestHeaders, "Upgrade", "websocket"); |
241 response.add("Upgrade: websocket"); |
238 response.add("Upgrade: websocket"); |
242 expectHeader(requestHeaders, "Sec-WebSocket-Version", "13"); |
239 expectHeader(requestHeaders, "Sec-WebSocket-Version", "13"); |
243 String key = requestHeaders.get("Sec-WebSocket-Key"); |
240 List<String> key = requestHeaders.get("Sec-WebSocket-Key"); |
244 if (key == null) { |
241 if (key == null || key.isEmpty()) { |
245 throw new IllegalStateException("Sec-WebSocket-Key is missing"); |
242 throw new IllegalStateException("Sec-WebSocket-Key is missing"); |
|
243 } |
|
244 if (key.size() != 1) { |
|
245 throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key); |
246 } |
246 } |
247 MessageDigest sha1 = null; |
247 MessageDigest sha1 = null; |
248 try { |
248 try { |
249 sha1 = MessageDigest.getInstance("SHA-1"); |
249 sha1 = MessageDigest.getInstance("SHA-1"); |
250 } catch (NoSuchAlgorithmException e) { |
250 } catch (NoSuchAlgorithmException e) { |
251 throw new InternalError(e); |
251 throw new InternalError(e); |
252 } |
252 } |
253 String x = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
253 String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
254 sha1.update(x.getBytes(ISO_8859_1)); |
254 sha1.update(x.getBytes(ISO_8859_1)); |
255 String v = Base64.getEncoder().encodeToString(sha1.digest()); |
255 String v = Base64.getEncoder().encodeToString(sha1.digest()); |
256 response.add("Sec-WebSocket-Accept: " + v); |
256 response.add("Sec-WebSocket-Accept: " + v); |
257 return response; |
257 return response; |
258 }; |
258 }; |
259 } |
259 } |
260 |
260 |
261 protected static String expectHeader(Map<String, String> headers, |
261 protected static String expectHeader(Map<String, List<String>> headers, |
262 String name, |
262 String name, |
263 String value) { |
263 String value) { |
264 String v = headers.get(name); |
264 List<String> v = headers.get(name); |
265 if (!value.equals(v)) { |
265 if (!v.contains(value)) { |
266 throw new IllegalStateException( |
266 throw new IllegalStateException( |
267 format("Expected '%s: %s', actual: '%s: %s'", |
267 format("Expected '%s: %s', actual: '%s: %s'", |
268 name, value, name, v) |
268 name, value, name, v) |
269 ); |
269 ); |
270 } |
270 } |
271 return v; |
271 return value; |
272 } |
272 } |
273 |
273 |
274 private static void close(AutoCloseable... acs) { |
274 private static void close(AutoCloseable... acs) { |
275 for (AutoCloseable ac : acs) { |
275 for (AutoCloseable ac : acs) { |
276 try { |
276 try { |