jdk/src/java.httpclient/share/classes/java/net/http/WSFrame.java
changeset 37874 02589df0999a
child 39730 196f4e25d9f5
equal deleted inserted replaced
37858:7c04fcb12bd4 37874:02589df0999a
       
     1 /*
       
     2  * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
       
     3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
       
     4  *
       
     5  * This code is free software; you can redistribute it and/or modify it
       
     6  * under the terms of the GNU General Public License version 2 only, as
       
     7  * published by the Free Software Foundation.  Oracle designates this
       
     8  * particular file as subject to the "Classpath" exception as provided
       
     9  * by Oracle in the LICENSE file that accompanied this code.
       
    10  *
       
    11  * This code is distributed in the hope that it will be useful, but WITHOUT
       
    12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
       
    13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
       
    14  * version 2 for more details (a copy is included in the LICENSE file that
       
    15  * accompanied this code).
       
    16  *
       
    17  * You should have received a copy of the GNU General Public License version
       
    18  * 2 along with this work; if not, write to the Free Software Foundation,
       
    19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
       
    20  *
       
    21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
       
    22  * or visit www.oracle.com if you need additional information or have any
       
    23  * questions.
       
    24  */
       
    25 package java.net.http;
       
    26 
       
    27 import java.nio.ByteBuffer;
       
    28 
       
    29 import static java.lang.String.format;
       
    30 import static java.net.http.WSFrame.Opcode.ofCode;
       
    31 import static java.net.http.WSUtils.dump;
       
    32 
       
    33 /*
       
    34  * A collection of utilities for reading, writing, and masking frames.
       
    35  */
       
    36 final class WSFrame {
       
    37 
       
    38     private WSFrame() { }
       
    39 
       
    40     static final int MAX_HEADER_SIZE_BYTES = 2 + 8 + 4;
       
    41 
       
    42     enum Opcode {
       
    43 
       
    44         CONTINUATION   (0x0),
       
    45         TEXT           (0x1),
       
    46         BINARY         (0x2),
       
    47         NON_CONTROL_0x3(0x3),
       
    48         NON_CONTROL_0x4(0x4),
       
    49         NON_CONTROL_0x5(0x5),
       
    50         NON_CONTROL_0x6(0x6),
       
    51         NON_CONTROL_0x7(0x7),
       
    52         CLOSE          (0x8),
       
    53         PING           (0x9),
       
    54         PONG           (0xA),
       
    55         CONTROL_0xB    (0xB),
       
    56         CONTROL_0xC    (0xC),
       
    57         CONTROL_0xD    (0xD),
       
    58         CONTROL_0xE    (0xE),
       
    59         CONTROL_0xF    (0xF);
       
    60 
       
    61         private static final Opcode[] opcodes;
       
    62 
       
    63         static {
       
    64             Opcode[] values = values();
       
    65             opcodes = new Opcode[values.length];
       
    66             for (Opcode c : values) {
       
    67                 assert opcodes[c.code] == null
       
    68                         : WSUtils.dump(c, c.code, opcodes[c.code]);
       
    69                 opcodes[c.code] = c;
       
    70             }
       
    71         }
       
    72 
       
    73         private final byte code;
       
    74         private final char shiftedCode;
       
    75         private final String description;
       
    76 
       
    77         Opcode(int code) {
       
    78             this.code = (byte) code;
       
    79             this.shiftedCode = (char) (code << 8);
       
    80             this.description = format("%x (%s)", code, name());
       
    81         }
       
    82 
       
    83         boolean isControl() {
       
    84             return (code & 0x8) != 0;
       
    85         }
       
    86 
       
    87         static Opcode ofCode(int code) {
       
    88             return opcodes[code & 0xF];
       
    89         }
       
    90 
       
    91         @Override
       
    92         public String toString() {
       
    93             return description;
       
    94         }
       
    95     }
       
    96 
       
    97     /*
       
    98      * A utility to mask payload data.
       
    99      */
       
   100     static final class Masker {
       
   101 
       
   102         private final ByteBuffer acc = ByteBuffer.allocate(8);
       
   103         private final int[] maskBytes = new int[4];
       
   104         private int offset;
       
   105         private long maskLong;
       
   106 
       
   107         /*
       
   108          * Sets up the mask.
       
   109          */
       
   110         Masker mask(int value) {
       
   111             acc.clear().putInt(value).putInt(value).flip();
       
   112             for (int i = 0; i < maskBytes.length; i++) {
       
   113                 maskBytes[i] = acc.get(i);
       
   114             }
       
   115             offset = 0;
       
   116             maskLong = acc.getLong(0);
       
   117             return this;
       
   118         }
       
   119 
       
   120         /*
       
   121          * Reads as many bytes as possible from the given input buffer, writing
       
   122          * the resulting masked bytes to the given output buffer.
       
   123          *
       
   124          * src.remaining() <= dst.remaining() // TODO: do we need this restriction?
       
   125          * 'src' and 'dst' can be the same ByteBuffer
       
   126          */
       
   127         Masker applyMask(ByteBuffer src, ByteBuffer dst) {
       
   128             if (src.remaining() > dst.remaining()) {
       
   129                 throw new IllegalArgumentException(dump(src, dst));
       
   130             }
       
   131             begin(src, dst);
       
   132             loop(src, dst);
       
   133             end(src, dst);
       
   134             return this;
       
   135         }
       
   136 
       
   137         // Applying the remaining of the mask (strictly not more than 3 bytes)
       
   138         // byte-wise
       
   139         private void begin(ByteBuffer src, ByteBuffer dst) {
       
   140             if (offset > 0) {
       
   141                 for (int i = src.position(), j = dst.position();
       
   142                      offset < 4 && i <= src.limit() - 1 && j <= dst.limit() - 1;
       
   143                      i++, j++, offset++) {
       
   144                     dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
       
   145                     dst.position(j + 1);
       
   146                     src.position(i + 1);
       
   147                 }
       
   148                 offset &= 3;
       
   149             }
       
   150         }
       
   151 
       
   152         private void loop(ByteBuffer src, ByteBuffer dst) {
       
   153             int i = src.position();
       
   154             int j = dst.position();
       
   155             final int srcLim = src.limit() - 8;
       
   156             final int dstLim = dst.limit() - 8;
       
   157             for (; i <= srcLim && j <= dstLim; i += 8, j += 8) {
       
   158                 dst.putLong(j, (src.getLong(i) ^ maskLong));
       
   159             }
       
   160             if (i > src.limit()) {
       
   161                 src.position(i - 8);
       
   162             } else {
       
   163                 src.position(i);
       
   164             }
       
   165             if (j > dst.limit()) {
       
   166                 dst.position(j - 8);
       
   167             } else {
       
   168                 dst.position(j);
       
   169             }
       
   170         }
       
   171 
       
   172         // Applying the mask to the remaining bytes byte-wise (don't make any
       
   173         // assumptions on how many, hopefully not more than 7 for 64bit arch)
       
   174         private void end(ByteBuffer src, ByteBuffer dst) {
       
   175             for (int i = src.position(), j = dst.position();
       
   176                  i <= src.limit() - 1 && j <= dst.limit() - 1;
       
   177                  i++, j++, offset = (offset + 1) & 3) { // offset cycle through 0..3
       
   178                 dst.put(j, (byte) (src.get(i) ^ maskBytes[offset]));
       
   179                 src.position(i + 1);
       
   180                 dst.position(j + 1);
       
   181             }
       
   182         }
       
   183     }
       
   184 
       
   185     /*
       
   186      * A builder of frame headers, capable of writing to a given buffer.
       
   187      *
       
   188      * The builder does not enforce any protocol-level rules, it simply writes
       
   189      * a header structure to the buffer. The order of calls to intermediate
       
   190      * methods is not significant.
       
   191      */
       
   192     static final class HeaderBuilder {
       
   193 
       
   194         private char firstChar;
       
   195         private long payloadLen;
       
   196         private int maskingKey;
       
   197         private boolean mask;
       
   198 
       
   199         HeaderBuilder fin(boolean value) {
       
   200             if (value) {
       
   201                 firstChar |=  0b10000000_00000000;
       
   202             } else {
       
   203                 firstChar &= ~0b10000000_00000000;
       
   204             }
       
   205             return this;
       
   206         }
       
   207 
       
   208         HeaderBuilder rsv1(boolean value) {
       
   209             if (value) {
       
   210                 firstChar |=  0b01000000_00000000;
       
   211             } else {
       
   212                 firstChar &= ~0b01000000_00000000;
       
   213             }
       
   214             return this;
       
   215         }
       
   216 
       
   217         HeaderBuilder rsv2(boolean value) {
       
   218             if (value) {
       
   219                 firstChar |=  0b00100000_00000000;
       
   220             } else {
       
   221                 firstChar &= ~0b00100000_00000000;
       
   222             }
       
   223             return this;
       
   224         }
       
   225 
       
   226         HeaderBuilder rsv3(boolean value) {
       
   227             if (value) {
       
   228                 firstChar |=  0b00010000_00000000;
       
   229             } else {
       
   230                 firstChar &= ~0b00010000_00000000;
       
   231             }
       
   232             return this;
       
   233         }
       
   234 
       
   235         HeaderBuilder opcode(Opcode value) {
       
   236             firstChar = (char) ((firstChar & 0xF0FF) | value.shiftedCode);
       
   237             return this;
       
   238         }
       
   239 
       
   240         HeaderBuilder payloadLen(long value) {
       
   241             payloadLen = value;
       
   242             firstChar &= 0b11111111_10000000; // Clear previous payload length leftovers
       
   243             if (payloadLen < 126) {
       
   244                 firstChar |= payloadLen;
       
   245             } else if (payloadLen < 65535) {
       
   246                 firstChar |= 126;
       
   247             } else {
       
   248                 firstChar |= 127;
       
   249             }
       
   250             return this;
       
   251         }
       
   252 
       
   253         HeaderBuilder mask(int value) {
       
   254             firstChar |= 0b00000000_10000000;
       
   255             maskingKey = value;
       
   256             mask = true;
       
   257             return this;
       
   258         }
       
   259 
       
   260         HeaderBuilder noMask() {
       
   261             firstChar &= ~0b00000000_10000000;
       
   262             mask = false;
       
   263             return this;
       
   264         }
       
   265 
       
   266         /*
       
   267          * Writes the header to the given buffer.
       
   268          *
       
   269          * The buffer must have at least MAX_HEADER_SIZE_BYTES remaining. The
       
   270          * buffer's position is incremented by the number of bytes written.
       
   271          */
       
   272         void build(ByteBuffer buffer) {
       
   273             buffer.putChar(firstChar);
       
   274             if (payloadLen >= 126) {
       
   275                 if (payloadLen < 65535) {
       
   276                     buffer.putChar((char) payloadLen);
       
   277                 } else {
       
   278                     buffer.putLong(payloadLen);
       
   279                 }
       
   280             }
       
   281             if (mask) {
       
   282                 buffer.putInt(maskingKey);
       
   283             }
       
   284         }
       
   285     }
       
   286 
       
   287     /*
       
   288      * A consumer of frame parts.
       
   289      *
       
   290      * Guaranteed to be called in the following order by the Frame.Reader:
       
   291      *
       
   292      *     fin rsv1 rsv2 rsv3 opcode mask payloadLength maskingKey? payloadData+ endFrame
       
   293      */
       
   294     interface Consumer {
       
   295 
       
   296         void fin(boolean value);
       
   297 
       
   298         void rsv1(boolean value);
       
   299 
       
   300         void rsv2(boolean value);
       
   301 
       
   302         void rsv3(boolean value);
       
   303 
       
   304         void opcode(Opcode value);
       
   305 
       
   306         void mask(boolean value);
       
   307 
       
   308         void payloadLen(long value);
       
   309 
       
   310         void maskingKey(int value);
       
   311 
       
   312         /*
       
   313          * Called when a part of the payload is ready to be consumed.
       
   314          *
       
   315          * Though may not yield a complete payload in a single invocation, i.e.
       
   316          *
       
   317          *     data.remaining() < payloadLen
       
   318          *
       
   319          * the sum of `data.remaining()` passed to all invocations of this
       
   320          * method will be equal to 'payloadLen', reported in
       
   321          * `void payloadLen(long value)`
       
   322          *
       
   323          * No unmasking is done.
       
   324          */
       
   325         void payloadData(WSShared<ByteBuffer> data, boolean isLast);
       
   326 
       
   327         void endFrame(); // TODO: remove (payloadData(isLast=true)) should be enough
       
   328     }
       
   329 
       
   330     /*
       
   331      * A Reader of Frames.
       
   332      *
       
   333      * No protocol-level rules are enforced, only frame structure.
       
   334      */
       
   335     static final class Reader {
       
   336 
       
   337         private static final int AWAITING_FIRST_BYTE  =  1;
       
   338         private static final int AWAITING_SECOND_BYTE =  2;
       
   339         private static final int READING_16_LENGTH    =  4;
       
   340         private static final int READING_64_LENGTH    =  8;
       
   341         private static final int READING_MASK         = 16;
       
   342         private static final int READING_PAYLOAD      = 32;
       
   343 
       
   344         // A private buffer used to simplify multi-byte integers reading
       
   345         private final ByteBuffer accumulator = ByteBuffer.allocate(8);
       
   346         private int state = AWAITING_FIRST_BYTE;
       
   347         private boolean mask;
       
   348         private long payloadLength;
       
   349 
       
   350         /*
       
   351          * Reads at most one frame from the given buffer invoking the consumer's
       
   352          * methods corresponding to the frame elements found.
       
   353          *
       
   354          * As much of the frame's payload, if any, is read. The buffers position
       
   355          * is updated to reflect the number of bytes read.
       
   356          *
       
   357          * Throws WSProtocolException if the frame is malformed.
       
   358          */
       
   359         void readFrame(WSShared<ByteBuffer> shared, Consumer consumer) {
       
   360             ByteBuffer input = shared.buffer();
       
   361             loop:
       
   362             while (true) {
       
   363                 byte b;
       
   364                 switch (state) {
       
   365                     case AWAITING_FIRST_BYTE:
       
   366                         if (!input.hasRemaining()) {
       
   367                             break loop;
       
   368                         }
       
   369                         b = input.get();
       
   370                         consumer.fin( (b & 0b10000000) != 0);
       
   371                         consumer.rsv1((b & 0b01000000) != 0);
       
   372                         consumer.rsv2((b & 0b00100000) != 0);
       
   373                         consumer.rsv3((b & 0b00010000) != 0);
       
   374                         consumer.opcode(ofCode(b));
       
   375                         state = AWAITING_SECOND_BYTE;
       
   376                         continue loop;
       
   377                     case AWAITING_SECOND_BYTE:
       
   378                         if (!input.hasRemaining()) {
       
   379                             break loop;
       
   380                         }
       
   381                         b = input.get();
       
   382                         consumer.mask(mask = (b & 0b10000000) != 0);
       
   383                         byte p1 = (byte) (b & 0b01111111);
       
   384                         if (p1 < 126) {
       
   385                             assert p1 >= 0 : p1;
       
   386                             consumer.payloadLen(payloadLength = p1);
       
   387                             state = mask ? READING_MASK : READING_PAYLOAD;
       
   388                         } else if (p1 < 127) {
       
   389                             state = READING_16_LENGTH;
       
   390                         } else {
       
   391                             state = READING_64_LENGTH;
       
   392                         }
       
   393                         continue loop;
       
   394                     case READING_16_LENGTH:
       
   395                         if (!input.hasRemaining()) {
       
   396                             break loop;
       
   397                         }
       
   398                         b = input.get();
       
   399                         if (accumulator.put(b).position() < 2) {
       
   400                             continue loop;
       
   401                         }
       
   402                         payloadLength = accumulator.flip().getChar();
       
   403                         if (payloadLength < 126) {
       
   404                             throw notMinimalEncoding(payloadLength, 2);
       
   405                         }
       
   406                         consumer.payloadLen(payloadLength);
       
   407                         accumulator.clear();
       
   408                         state = mask ? READING_MASK : READING_PAYLOAD;
       
   409                         continue loop;
       
   410                     case READING_64_LENGTH:
       
   411                         if (!input.hasRemaining()) {
       
   412                             break loop;
       
   413                         }
       
   414                         b = input.get();
       
   415                         if (accumulator.put(b).position() < 8) {
       
   416                             continue loop;
       
   417                         }
       
   418                         payloadLength = accumulator.flip().getLong();
       
   419                         if (payloadLength < 0) {
       
   420                             throw negativePayload(payloadLength);
       
   421                         } else if (payloadLength < 65535) {
       
   422                             throw notMinimalEncoding(payloadLength, 8);
       
   423                         }
       
   424                         consumer.payloadLen(payloadLength);
       
   425                         accumulator.clear();
       
   426                         state = mask ? READING_MASK : READING_PAYLOAD;
       
   427                         continue loop;
       
   428                     case READING_MASK:
       
   429                         if (!input.hasRemaining()) {
       
   430                             break loop;
       
   431                         }
       
   432                         b = input.get();
       
   433                         if (accumulator.put(b).position() != 4) {
       
   434                             continue loop;
       
   435                         }
       
   436                         consumer.maskingKey(accumulator.flip().getInt());
       
   437                         accumulator.clear();
       
   438                         state = READING_PAYLOAD;
       
   439                         continue loop;
       
   440                     case READING_PAYLOAD:
       
   441                         // This state does not require any bytes to be available
       
   442                         // in the input buffer in order to proceed
       
   443                         boolean fullyRead;
       
   444                         int limit;
       
   445                         if (payloadLength <= input.remaining()) {
       
   446                             limit = input.position() + (int) payloadLength;
       
   447                             payloadLength = 0;
       
   448                             fullyRead = true;
       
   449                         } else {
       
   450                             limit = input.limit();
       
   451                             payloadLength -= input.remaining();
       
   452                             fullyRead = false;
       
   453                         }
       
   454                         // FIXME: consider a case where payloadLen != 0,
       
   455                         // but input.remaining() == 0
       
   456                         //
       
   457                         // There shouldn't be an invocation of payloadData with
       
   458                         // an empty buffer, as it would be an artifact of
       
   459                         // reading
       
   460                         consumer.payloadData(shared.share(input.position(), limit), fullyRead);
       
   461                         // Update the position manually, since reading the
       
   462                         // payload doesn't advance buffer's position
       
   463                         input.position(limit);
       
   464                         if (fullyRead) {
       
   465                             consumer.endFrame();
       
   466                             state = AWAITING_FIRST_BYTE;
       
   467                         }
       
   468                         break loop;
       
   469                     default:
       
   470                         throw new InternalError(String.valueOf(state));
       
   471                 }
       
   472             }
       
   473         }
       
   474 
       
   475         private static WSProtocolException negativePayload(long payloadLength) {
       
   476             return new WSProtocolException
       
   477                     ("5.2.", format("Negative 64-bit payload length %s", payloadLength));
       
   478         }
       
   479 
       
   480         private static WSProtocolException notMinimalEncoding(long payloadLength, int numBytes) {
       
   481             return new WSProtocolException
       
   482                     ("5.2.", format("Payload length (%s) is not encoded with minimal number (%s) of bytes",
       
   483                             payloadLength, numBytes));
       
   484         }
       
   485     }
       
   486 }