test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java
branchhttp-client-branch
changeset 55764 34d7cc00f87a
parent 47266 b841be61b9d9
child 49765 ee6f7a61f3a5
child 56045 5c6e3b76d2ad
equal deleted inserted replaced
55763:634d8e14c172 55764:34d7cc00f87a
    32 import java.nio.channels.ServerSocketChannel;
    32 import java.nio.channels.ServerSocketChannel;
    33 import java.nio.channels.SocketChannel;
    33 import java.nio.channels.SocketChannel;
    34 import java.nio.charset.CharacterCodingException;
    34 import java.nio.charset.CharacterCodingException;
    35 import java.security.MessageDigest;
    35 import java.security.MessageDigest;
    36 import java.security.NoSuchAlgorithmException;
    36 import java.security.NoSuchAlgorithmException;
       
    37 import java.util.ArrayList;
    37 import java.util.Arrays;
    38 import java.util.Arrays;
    38 import java.util.Base64;
    39 import java.util.Base64;
    39 import java.util.HashMap;
    40 import java.util.HashMap;
    40 import java.util.Iterator;
    41 import java.util.Iterator;
    41 import java.util.LinkedList;
    42 import java.util.LinkedList;
    45 import java.util.function.Function;
    46 import java.util.function.Function;
    46 import java.util.regex.Pattern;
    47 import java.util.regex.Pattern;
    47 import java.util.stream.Collectors;
    48 import java.util.stream.Collectors;
    48 
    49 
    49 import static java.lang.String.format;
    50 import static java.lang.String.format;
    50 import static java.lang.System.Logger.Level.ERROR;
    51 import static java.lang.System.err;
    51 import static java.lang.System.Logger.Level.INFO;
       
    52 import static java.lang.System.Logger.Level.TRACE;
       
    53 import static java.nio.charset.StandardCharsets.ISO_8859_1;
    52 import static java.nio.charset.StandardCharsets.ISO_8859_1;
    54 import static java.util.Arrays.asList;
    53 import static java.util.Arrays.asList;
    55 import static java.util.Objects.requireNonNull;
    54 import static java.util.Objects.requireNonNull;
    56 
    55 
    57 /**
    56 /**
    81  *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
    80  *     Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
    82  *     Sec-WebSocket-Protocol: chat
    81  *     Sec-WebSocket-Protocol: chat
    83  */
    82  */
    84 public final class DummyWebSocketServer implements Closeable {
    83 public final class DummyWebSocketServer implements Closeable {
    85 
    84 
    86     private final static System.Logger log = System.getLogger(DummyWebSocketServer.class.getName());
       
    87     private final AtomicBoolean started = new AtomicBoolean();
    85     private final AtomicBoolean started = new AtomicBoolean();
    88     private final Thread thread;
    86     private final Thread thread;
    89     private volatile ServerSocketChannel ssc;
    87     private volatile ServerSocketChannel ssc;
    90     private volatile InetSocketAddress address;
    88     private volatile InetSocketAddress address;
    91 
    89 
    96     public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
    94     public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
    97         requireNonNull(mapping);
    95         requireNonNull(mapping);
    98         thread = new Thread(() -> {
    96         thread = new Thread(() -> {
    99             try {
    97             try {
   100                 while (!Thread.currentThread().isInterrupted()) {
    98                 while (!Thread.currentThread().isInterrupted()) {
   101                     log.log(INFO, "Accepting next connection at: " + ssc);
    99                     err.println("Accepting next connection at: " + ssc);
   102                     SocketChannel channel = ssc.accept();
   100                     SocketChannel channel = ssc.accept();
   103                     log.log(INFO, "Accepted: " + channel);
   101                     err.println("Accepted: " + channel);
   104                     try {
   102                     try {
   105                         channel.configureBlocking(true);
   103                         channel.configureBlocking(true);
   106                         StringBuilder request = new StringBuilder();
   104                         StringBuilder request = new StringBuilder();
   107                         if (!readRequest(channel, request)) {
   105                         if (!readRequest(channel, request)) {
   108                             throw new IOException("Bad request:" + request);
   106                             throw new IOException("Bad request:" + request);
   115                         ByteBuffer b = ByteBuffer.allocate(1024);
   113                         ByteBuffer b = ByteBuffer.allocate(1024);
   116                         while (channel.read(b) != -1) {
   114                         while (channel.read(b) != -1) {
   117                             b.clear();
   115                             b.clear();
   118                         }
   116                         }
   119                     } catch (IOException e) {
   117                     } catch (IOException e) {
   120                         log.log(TRACE, () -> "Error in connection: " + channel, e);
   118                         err.println("Error in connection: " + channel + ", " + e);
   121                     } finally {
   119                     } finally {
   122                         log.log(INFO, "Closed: " + channel);
   120                         err.println("Closed: " + channel);
   123                         close(channel);
   121                         close(channel);
   124                     }
   122                     }
   125                 }
   123                 }
   126             } catch (ClosedByInterruptException ignored) {
   124             } catch (ClosedByInterruptException ignored) {
   127             } catch (IOException e) {
   125             } catch (IOException e) {
   128                 log.log(ERROR, e);
   126                 err.println(e);
   129             } finally {
   127             } finally {
   130                 close(ssc);
   128                 close(ssc);
   131                 log.log(INFO, "Stopped at: " + getURI());
   129                 err.println("Stopped at: " + getURI());
   132             }
   130             }
   133         });
   131         });
   134         thread.setName("DummyWebSocketServer");
   132         thread.setName("DummyWebSocketServer");
   135         thread.setDaemon(false);
   133         thread.setDaemon(false);
   136     }
   134     }
   137 
   135 
   138     public void open() throws IOException {
   136     public void open() throws IOException {
   139         log.log(INFO, "Starting");
   137         err.println("Starting");
   140         if (!started.compareAndSet(false, true)) {
   138         if (!started.compareAndSet(false, true)) {
   141             throw new IllegalStateException("Already started");
   139             throw new IllegalStateException("Already started");
   142         }
   140         }
   143         ssc = ServerSocketChannel.open();
   141         ssc = ServerSocketChannel.open();
   144         try {
   142         try {
   147             address = (InetSocketAddress) ssc.getLocalAddress();
   145             address = (InetSocketAddress) ssc.getLocalAddress();
   148             thread.start();
   146             thread.start();
   149         } catch (IOException e) {
   147         } catch (IOException e) {
   150             close(ssc);
   148             close(ssc);
   151         }
   149         }
   152         log.log(INFO, "Started at: " + getURI());
   150         err.println("Started at: " + getURI());
   153     }
   151     }
   154 
   152 
   155     @Override
   153     @Override
   156     public void close() {
   154     public void close() {
   157         log.log(INFO, "Stopping: " + getURI());
   155         err.println("Stopping: " + getURI());
   158         thread.interrupt();
   156         thread.interrupt();
   159         close(ssc);
   157         close(ssc);
   160     }
   158     }
   161 
   159 
   162     URI getURI() {
   160     URI getURI() {
   208             List<String> response = new LinkedList<>();
   206             List<String> response = new LinkedList<>();
   209             Iterator<String> iterator = request.iterator();
   207             Iterator<String> iterator = request.iterator();
   210             if (!iterator.hasNext()) {
   208             if (!iterator.hasNext()) {
   211                 throw new IllegalStateException("The request is empty");
   209                 throw new IllegalStateException("The request is empty");
   212             }
   210             }
   213             if (!"GET / HTTP/1.1".equals(iterator.next())) {
   211             String statusLine = iterator.next();
       
   212             if (!(statusLine.startsWith("GET /") && statusLine.endsWith(" HTTP/1.1"))) {
   214                 throw new IllegalStateException
   213                 throw new IllegalStateException
   215                         ("Unexpected status line: " + request.get(0));
   214                         ("Unexpected status line: " + request.get(0));
   216             }
   215             }
   217             response.add("HTTP/1.1 101 Switching Protocols");
   216             response.add("HTTP/1.1 101 Switching Protocols");
   218             Map<String, String> requestHeaders = new HashMap<>();
   217             Map<String, List<String>> requestHeaders = new HashMap<>();
   219             while (iterator.hasNext()) {
   218             while (iterator.hasNext()) {
   220                 String header = iterator.next();
   219                 String header = iterator.next();
   221                 String[] split = header.split(": ");
   220                 String[] split = header.split(": ");
   222                 if (split.length != 2) {
   221                 if (split.length != 2) {
   223                     throw new IllegalStateException
   222                     throw new IllegalStateException
   224                             ("Unexpected header: " + header
   223                             ("Unexpected header: " + header
   225                                      + ", split=" + Arrays.toString(split));
   224                                      + ", split=" + Arrays.toString(split));
   226                 }
   225                 }
   227                 if (requestHeaders.put(split[0], split[1]) != null) {
   226                 requestHeaders.computeIfAbsent(split[0], k -> new ArrayList<>()).add(split[1]);
   228                     throw new IllegalStateException
   227 
   229                             ("Duplicating headers: " + Arrays.toString(split));
       
   230                 }
       
   231             }
   228             }
   232             if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
   229             if (requestHeaders.containsKey("Sec-WebSocket-Protocol")) {
   233                 throw new IllegalStateException("Subprotocols are not expected");
   230                 throw new IllegalStateException("Subprotocols are not expected");
   234             }
   231             }
   235             if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
   232             if (requestHeaders.containsKey("Sec-WebSocket-Extensions")) {
   238             expectHeader(requestHeaders, "Connection", "Upgrade");
   235             expectHeader(requestHeaders, "Connection", "Upgrade");
   239             response.add("Connection: Upgrade");
   236             response.add("Connection: Upgrade");
   240             expectHeader(requestHeaders, "Upgrade", "websocket");
   237             expectHeader(requestHeaders, "Upgrade", "websocket");
   241             response.add("Upgrade: websocket");
   238             response.add("Upgrade: websocket");
   242             expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
   239             expectHeader(requestHeaders, "Sec-WebSocket-Version", "13");
   243             String key = requestHeaders.get("Sec-WebSocket-Key");
   240             List<String> key = requestHeaders.get("Sec-WebSocket-Key");
   244             if (key == null) {
   241             if (key == null || key.isEmpty()) {
   245                 throw new IllegalStateException("Sec-WebSocket-Key is missing");
   242                 throw new IllegalStateException("Sec-WebSocket-Key is missing");
       
   243             }
       
   244             if (key.size() != 1) {
       
   245                 throw new IllegalStateException("Sec-WebSocket-Key has too many values : " + key);
   246             }
   246             }
   247             MessageDigest sha1 = null;
   247             MessageDigest sha1 = null;
   248             try {
   248             try {
   249                 sha1 = MessageDigest.getInstance("SHA-1");
   249                 sha1 = MessageDigest.getInstance("SHA-1");
   250             } catch (NoSuchAlgorithmException e) {
   250             } catch (NoSuchAlgorithmException e) {
   251                 throw new InternalError(e);
   251                 throw new InternalError(e);
   252             }
   252             }
   253             String x = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
   253             String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
   254             sha1.update(x.getBytes(ISO_8859_1));
   254             sha1.update(x.getBytes(ISO_8859_1));
   255             String v = Base64.getEncoder().encodeToString(sha1.digest());
   255             String v = Base64.getEncoder().encodeToString(sha1.digest());
   256             response.add("Sec-WebSocket-Accept: " + v);
   256             response.add("Sec-WebSocket-Accept: " + v);
   257             return response;
   257             return response;
   258         };
   258         };
   259     }
   259     }
   260 
   260 
   261     protected static String expectHeader(Map<String, String> headers,
   261     protected static String expectHeader(Map<String, List<String>> headers,
   262                                          String name,
   262                                          String name,
   263                                          String value) {
   263                                          String value) {
   264         String v = headers.get(name);
   264         List<String> v = headers.get(name);
   265         if (!value.equals(v)) {
   265         if (!v.contains(value)) {
   266             throw new IllegalStateException(
   266             throw new IllegalStateException(
   267                     format("Expected '%s: %s', actual: '%s: %s'",
   267                     format("Expected '%s: %s', actual: '%s: %s'",
   268                            name, value, name, v)
   268                            name, value, name, v)
   269             );
   269             );
   270         }
   270         }
   271         return v;
   271         return value;
   272     }
   272     }
   273 
   273 
   274     private static void close(AutoCloseable... acs) {
   274     private static void close(AutoCloseable... acs) {
   275         for (AutoCloseable ac : acs) {
   275         for (AutoCloseable ac : acs) {
   276             try {
   276             try {