read side key limits JDK-8145252-TLS13-branch
authorascarpino
Tue, 19 Jun 2018 15:53:35 -0700
branchJDK-8145252-TLS13-branch
changeset 56784 6210466cf1ac
parent 56782 b472b5917a1b
child 56794 1cc2f6afa943
read side key limits
src/java.base/share/classes/sun/security/ssl/InputRecord.java
src/java.base/share/classes/sun/security/ssl/KeyUpdate.java
src/java.base/share/classes/sun/security/ssl/OutputRecord.java
src/java.base/share/classes/sun/security/ssl/SSLCipher.java
src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java
src/java.base/share/classes/sun/security/ssl/SSLEngineOutputRecord.java
src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java
src/java.base/share/classes/sun/security/ssl/SSLSocketOutputRecord.java
src/java.base/share/classes/sun/security/ssl/TransportContext.java
test/jdk/java/net/httpclient/MockServer.java
test/jdk/sun/security/ssl/SSLEngineImpl/SSLEngineKeyLimit.java
test/jdk/sun/security/ssl/SSLSocketImpl/SSLSocketKeyLimit.java
--- a/src/java.base/share/classes/sun/security/ssl/InputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/InputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -42,6 +42,8 @@
  */
 abstract class InputRecord implements Record, Closeable {
     SSLReadCipher       readCipher;
+    // Needed for KeyUpdate, used after Handshake.Finished
+    TransportContext            tc;
 
     final HandshakeHash handshakeHash;
     boolean             isClosed;
--- a/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/KeyUpdate.java	Tue Jun 19 15:53:35 2018 -0700
@@ -238,6 +238,7 @@
                 // Update the write key and IV.
                 handshakeProducer.produce(hc,
                     new KeyUpdateMessage(hc, KeyUpdateRequest.NOTREQUESTED));
+                return;
             }
 
             // clean handshake context
--- a/src/java.base/share/classes/sun/security/ssl/OutputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/OutputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -42,7 +42,7 @@
 abstract class OutputRecord
         extends ByteArrayOutputStream implements Record, Closeable {
     SSLWriteCipher              writeCipher;
-    // Needed for KeyUpdate
+    // Needed for KeyUpdate, used after Handshake.Finished
     TransportContext            tc;
 
     final HandshakeHash         handshakeHash;
--- a/src/java.base/share/classes/sun/security/ssl/SSLCipher.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLCipher.java	Tue Jun 19 15:53:35 2018 -0700
@@ -564,6 +564,8 @@
     abstract static class SSLReadCipher {
         final Authenticator authenticator;
         final ProtocolVersion protocolVersion;
+        boolean keyLimitEnabled = false;
+        long keyLimitCountdown = 0;
         SecretKey baseSecret;
 
         SSLReadCipher(Authenticator authenticator,
@@ -606,6 +608,20 @@
         boolean isNullCipher() {
             return false;
         }
+
+        /**
+         * Check if processed bytes have reached the key usage limit.
+         * If key usage limit is not be monitored, return false.
+         */
+        public boolean atKeyLimit() {
+            if (keyLimitCountdown >= 0) {
+                return false;
+            }
+
+            // Turn off limit checking as KeyUpdate will be occurring
+            keyLimitEnabled = false;
+            return true;
+        }
     }
 
     interface WriteCipherGenerator {
@@ -1801,6 +1817,16 @@
                 this.iv = ((IvParameterSpec)params).getIV();
                 this.random = random;
 
+                keyLimitCountdown = cipherLimits.getOrDefault(
+                        algorithm.toUpperCase() + ":" + tag[0], 0L);
+                if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                    SSLLogger.fine("KeyLimit read side: algorithm = " +
+                            algorithm.toUpperCase() + ":" + tag[0] +
+                            "\ncountdown value = " + keyLimitCountdown);
+                }
+                if (keyLimitCountdown > 0) {
+                    keyLimitEnabled = true;
+                }
                 // DON'T initialize the cipher for AEAD!
             }
 
@@ -1888,6 +1914,9 @@
                     SSLLogger.fine(
                             "Plaintext after DECRYPTION", bb.duplicate());
                 }
+                if (keyLimitEnabled) {
+                    keyLimitCountdown -= len;
+                }
 
                 return new Plaintext(contentType,
                         ProtocolVersion.NONE.major, ProtocolVersion.NONE.minor,
@@ -1945,9 +1974,9 @@
                 keyLimitCountdown = cipherLimits.getOrDefault(
                         algorithm.toUpperCase() + ":" + tag[0], 0L);
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
-                    SSLLogger.fine("algorithm = " + algorithm.toUpperCase() +
-                            ":" + tag[0] + "\ncountdown value = " +
-                            keyLimitCountdown);
+                    SSLLogger.fine("KeyLimit write side: algorithm = "
+                            + algorithm.toUpperCase() + ":" + tag[0] +
+                            "\ncountdown value = " + keyLimitCountdown);
                 }
                 if (keyLimitCountdown > 0) {
                     keyLimitEnabled = true;
--- a/src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLEngineInputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -34,6 +34,8 @@
 import javax.net.ssl.SSLHandshakeException;
 import javax.net.ssl.SSLProtocolException;
 import sun.security.ssl.SSLCipher.SSLReadCipher;
+import sun.security.ssl.KeyUpdate.KeyUpdateMessage;
+import sun.security.ssl.KeyUpdate.KeyUpdateRequest;
 
 /**
  * {@code InputRecord} implementation for {@code SSLEngine}.
@@ -331,6 +333,20 @@
             return plaintexts.toArray(new Plaintext[0]);
         }
 
+        // KeyLimit check during application data.
+        // atKeyLimit() inactive when limits not checked, tc set when limits
+        // are active.
+
+        if (readCipher.atKeyLimit()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.fine("KeyUpdate: triggered, read side.");
+            }
+
+            PostHandshakeContext p = new PostHandshakeContext(tc);
+            KeyUpdate.handshakeProducer.produce(p,
+                    new KeyUpdateMessage(p, KeyUpdateRequest.REQUESTED));
+        }
+
         return new Plaintext[] {
             new Plaintext(contentType,
                 majorVersion, minorVersion, -1, -1L, fragment)
--- a/src/java.base/share/classes/sun/security/ssl/SSLEngineOutputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLEngineOutputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -249,9 +249,11 @@
                 isFirstAppOutputRecord = false;
             }
 
+            // atKeyLimit() inactive when limits not checked, tc set when limits
+            // are active.
             if (writeCipher.atKeyLimit()) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
-                    SSLLogger.fine("KeyUpdate: triggered");
+                    SSLLogger.fine("KeyUpdate: triggered, write side.");
                 }
 
                 PostHandshakeContext p = new PostHandshakeContext(tc);
--- a/src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLSocketInputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -38,6 +38,8 @@
 import javax.net.ssl.SSLProtocolException;
 
 import sun.security.ssl.SSLCipher.SSLReadCipher;
+import sun.security.ssl.KeyUpdate.KeyUpdateMessage;
+import sun.security.ssl.KeyUpdate.KeyUpdateRequest;
 
 /**
  * {@code InputRecord} implementation for {@code SSLSocket}.
@@ -346,10 +348,23 @@
             return plaintexts.toArray(new Plaintext[0]);
         }
 
+        // KeyLimit check during application data.
+        // atKeyLimit() inactive when limits not checked, tc set when limits
+        // are active.
+
+        if (readCipher.atKeyLimit()) {
+            if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
+                SSLLogger.fine("KeyUpdate: triggered, read side.");
+            }
+
+            PostHandshakeContext p = new PostHandshakeContext(tc);
+            KeyUpdate.handshakeProducer.produce(p,
+                    new KeyUpdateMessage(p, KeyUpdateRequest.REQUESTED));
+        }
+
         return new Plaintext[] {
                 new Plaintext(contentType,
                     majorVersion, minorVersion, -1, -1L, fragment)
-                    // recordEpoch, recordSeq, plaintext);
             };
     }
 
--- a/src/java.base/share/classes/sun/security/ssl/SSLSocketOutputRecord.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/SSLSocketOutputRecord.java	Tue Jun 19 15:53:35 2018 -0700
@@ -305,9 +305,11 @@
 
             offset += fragLen;
 
+            // atKeyLimit() inactive when limits not checked, tc set when limits
+            // are active.
             if (writeCipher.atKeyLimit()) {
                 if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
-                    SSLLogger.fine("KeyUpdate: triggered");
+                    SSLLogger.fine("KeyUpdate: triggered, write side.");
                 }
 
                 PostHandshakeContext p = new PostHandshakeContext(tc);
--- a/src/java.base/share/classes/sun/security/ssl/TransportContext.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/src/java.base/share/classes/sun/security/ssl/TransportContext.java	Tue Jun 19 15:53:35 2018 -0700
@@ -605,6 +605,7 @@
     HandshakeStatus finishHandshake() {
         if (protocolVersion.useTLS13PlusSpec()) {
             outputRecord.tc = this;
+            inputRecord.tc = this;
             cipherSuite = handshakeContext.negotiatedCipherSuite;
             inputRecord.readCipher.baseSecret = handshakeContext.baseReadSecret;
             outputRecord.writeCipher.baseSecret = handshakeContext.baseWriteSecret;
--- a/test/jdk/java/net/httpclient/MockServer.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/test/jdk/java/net/httpclient/MockServer.java	Tue Jun 19 15:53:35 2018 -0700
@@ -183,7 +183,8 @@
             } catch (IOException |InterruptedException e1) {
                 cleanup();
             } catch (Throwable t) {
-                System.out.println("X: " + t);
+                System.out.println("Exception: " + t);
+                t.printStackTrace();
                 cleanup();
             }
         }
--- a/test/jdk/sun/security/ssl/SSLEngineImpl/SSLEngineKeyLimit.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/test/jdk/sun/security/ssl/SSLEngineImpl/SSLEngineKeyLimit.java	Tue Jun 19 15:53:35 2018 -0700
@@ -112,7 +112,7 @@
             System.setProperty("test.java.opts",
                     "-Dtest.src=" + System.getProperty("test.src") +
                             " -Dtest.jdk=" + System.getProperty("test.jdk") +
-                            " -Djavax.net.debug=ssl" +
+                            " -Djavax.net.debug=ssl,handshake" +
                             " -Djava.security.properties=" + f.getName());
 
             System.out.println("test.java.opts: " +
@@ -127,6 +127,8 @@
                     output.shouldNotContain("KeyUpdate: write key updated");
                     output.shouldNotContain("KeyUpdate: read key updated");
                 } else {
+                    output.shouldContain("KeyUpdate: triggered, read side");
+                    output.shouldContain("KeyUpdate: triggered, write side");
                     output.shouldContain("KeyUpdate: write key updated");
                     output.shouldContain("KeyUpdate: read key updated");
                 }
@@ -220,7 +222,7 @@
         }
         print("Write-side. ");
 
-        while (i++ < 120) {
+        while (i++ < 150) {
             while (sc) {
                 if (readdone) {
                     return;
@@ -378,7 +380,7 @@
             readdone = true;
             System.out.println(e.getMessage());
             e.printStackTrace();
-            print("Total data read = " + totalDataLen);
+            System.out.println("Total data read = " + totalDataLen);
         }
     }
 
@@ -442,7 +444,7 @@
     static class Client extends SSLEngineKeyLimit implements Runnable {
         Client() throws Exception {
             super();
-            eng = initContext().createSSLEngine("client", 80);
+            eng = initContext().createSSLEngine();
             eng.setUseClientMode(true);
         }
 
--- a/test/jdk/sun/security/ssl/SSLSocketImpl/SSLSocketKeyLimit.java	Tue Jun 19 09:05:57 2018 -0700
+++ b/test/jdk/sun/security/ssl/SSLSocketImpl/SSLSocketKeyLimit.java	Tue Jun 19 15:53:35 2018 -0700
@@ -63,16 +63,12 @@
 import sun.security.util.HexDumpEncoder;
 
 public class SSLSocketKeyLimit {
-
-    SSLSocket svr, c;
-    SSLServerSocketFactory ssf;
-    SSLServerSocket ss;
-    SSLSocketFactory sf;
-    InputStream in;
-    OutputStream out;
+    SSLSocket socket;
+    private InputStream in;
+    private OutputStream out;
 
     static boolean serverReady = false;
-    static int serverPort = 12345;
+    static int serverPort = 0;
 
     static String pathToStores = "../../../../javax/net/ssl/etc/";
     static String keyStoreFile = "keystore";
@@ -83,9 +79,7 @@
     int totalDataLen = 0;
     static boolean done = false;
 
-        SSLSocketKeyLimit() {
-        in = new ByteArrayInputStream(new byte[dataLen]);
-        out = new ByteArrayOutputStream();
+    SSLSocketKeyLimit() {
     }
 
     SSLContext initContext() throws Exception {
@@ -125,7 +119,7 @@
             System.setProperty("test.java.opts",
                     "-Dtest.src=" + System.getProperty("test.src") +
                             " -Dtest.jdk=" + System.getProperty("test.jdk") +
-                            " -Djavax.net.debug=ssl " +
+                            " -Djavax.net.debug=ssl,handshake " +
                             " -Djava.security.properties=" + f.getName());
 
             System.out.println("test.java.opts: " +
@@ -140,6 +134,8 @@
                     output.shouldNotContain("KeyUpdate: write key updated");
                     output.shouldNotContain("KeyUpdate: read key updated");
                 } else {
+                    output.shouldContain("KeyUpdate: triggered, read side");
+                    output.shouldContain("KeyUpdate: triggered, write side");
                     output.shouldContain("KeyUpdate: write key updated");
                     output.shouldContain("KeyUpdate: read key updated");
                 }
@@ -175,7 +171,6 @@
             Thread.sleep(100);
         }
         new Client().run();
-        ts.interrupt();
         ts.join(10000);  // 10sec
         System.exit(0);
     }
@@ -184,11 +179,12 @@
         int i = 0;
         in = s.getInputStream();
         out = s.getOutputStream();
-        System.out.print("Write-side writing... ");
         while (i++ < 150) {
             out.write(data, 0, dataLen);
+            System.out.print("W");
+            in.readNBytes(1);
+            System.out.print("R");
         }
-        out.flush();
         out.write(0x0D);
         out.flush();
 
@@ -196,20 +192,24 @@
         while (!done) {
             Thread.sleep(100);
         }
+        out.close();
+        in.close();
     }
 
 
     void read(SSLSocket s) throws Exception {
         byte[] buf = new byte[dataLen];
         int len;
-        int i = 0;
+        byte i = 0;
         try {
-            System.out.println("connected " + s.getSession().getCipherSuite());
+            System.out.println("Server: connected " + s.getSession().getCipherSuite());
             in = s.getInputStream();
             out = s.getOutputStream();
             while (true) {
                 len = in.read(buf, 0, dataLen);
-                System.out.print(".");
+                System.out.print("r");
+                out.write(i++);
+                System.out.print("w");
                 for (byte b: buf) {
                     if (b == 0x0A || b == 0x0D) {
                         continue;
@@ -219,10 +219,9 @@
                 }
 
                 if (len > 0 && buf[len-1] == 0x0D) {
-                    System.out.print("got end byte");
+                    System.out.println("got end byte");
                     break;
                 }
-                out.write(i++);
                 totalDataLen += len;
             }
         } catch (Exception e) {
@@ -230,12 +229,16 @@
             e.printStackTrace();
         } finally {
             // Tell write side that we are done reading
+            out.close();
+            in.close();
             done = true;
         }
         System.out.println("\nTotalDataLen = " + totalDataLen);
     }
 
     static class Server extends SSLSocketKeyLimit implements Runnable {
+        private SSLServerSocketFactory ssf;
+        private SSLServerSocket ss;
         Server() {
             super();
             try {
@@ -249,18 +252,17 @@
         }
 
         public void run() {
-
             try {
                 serverReady = true;
-                System.out.println("Server waiting... ");
-                svr = (SSLSocket) ss.accept();
+                System.out.println("Server waiting... port: " + serverPort);
+                socket = (SSLSocket) ss.accept();
                 if (serverwrite) {
-                    write(svr);
+                    write(socket);
                 } else {
-                    read(svr);
+                    read(socket);
                 }
 
-                svr.close();
+                socket.close();
             } catch (Exception e) {
                 System.out.println("server: " + e.getMessage());
                 e.printStackTrace();
@@ -271,6 +273,8 @@
 
 
     static class Client extends SSLSocketKeyLimit implements Runnable {
+        private SSLSocketFactory sf;
+
         Client() {
             super();
         }
@@ -278,15 +282,15 @@
         public void run() {
             try {
                 sf = initContext().getSocketFactory();
-                System.out.print("Client connecting... ");
-                c = (SSLSocket)sf.createSocket("localhost", serverPort);
-                System.out.println("connected. " + c.getSession().getCipherSuite());
+                System.out.println("Client: connecting... port: " + serverPort);
+                socket = (SSLSocket)sf.createSocket("localhost", serverPort);
+                System.out.println("Client: connected." + socket.getSession().getCipherSuite());
 
                 // Opposite of what the server does
                 if (!serverwrite) {
-                    write(c);
+                    write(socket);
                 } else {
-                    read(c);
+                    read(socket);
                 }
 
             } catch (Exception e) {