8218863: Better endpoint checks
authorxuelei
Wed, 27 Feb 2019 13:58:04 -0800
changeset 55706 e29d7fea0e4d
parent 55705 e0f5ad90737c
child 55707 b8a12f53226e
8218863: Better endpoint checks Reviewed-by: ahgross, jnimeh, mullan, rhalade
src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java
src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java
--- a/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java	Thu Feb 07 08:47:10 2019 -0500
+++ b/src/java.base/share/classes/sun/security/ssl/SSLContextImpl.java	Wed Feb 27 13:58:04 2019 -0800
@@ -1480,8 +1480,9 @@
         checkAdditionalTrust(chain, authType, engine, false);
     }
 
-    private void checkAdditionalTrust(X509Certificate[] chain, String authType,
-                Socket socket, boolean isClient) throws CertificateException {
+    private void checkAdditionalTrust(X509Certificate[] chain,
+            String authType, Socket socket,
+            boolean checkClientTrusted) throws CertificateException {
         if (socket != null && socket.isConnected() &&
                                     socket instanceof SSLSocket) {
 
@@ -1495,9 +1496,8 @@
             String identityAlg = sslSocket.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && !identityAlg.isEmpty()) {
-                String hostname = session.getPeerHost();
-                X509TrustManagerImpl.checkIdentity(
-                                    hostname, chain[0], identityAlg);
+                X509TrustManagerImpl.checkIdentity(session, chain,
+                                    identityAlg, checkClientTrusted);
             }
 
             // try the best to check the algorithm constraints
@@ -1519,12 +1519,13 @@
                 constraints = new SSLAlgorithmConstraints(sslSocket, true);
             }
 
-            checkAlgorithmConstraints(chain, constraints, isClient);
+            checkAlgorithmConstraints(chain, constraints, checkClientTrusted);
         }
     }
 
-    private void checkAdditionalTrust(X509Certificate[] chain, String authType,
-            SSLEngine engine, boolean isClient) throws CertificateException {
+    private void checkAdditionalTrust(X509Certificate[] chain,
+            String authType, SSLEngine engine,
+            boolean checkClientTrusted) throws CertificateException {
         if (engine != null) {
             SSLSession session = engine.getHandshakeSession();
             if (session == null) {
@@ -1535,9 +1536,8 @@
             String identityAlg = engine.getSSLParameters().
                                         getEndpointIdentificationAlgorithm();
             if (identityAlg != null && !identityAlg.isEmpty()) {
-                String hostname = session.getPeerHost();
-                X509TrustManagerImpl.checkIdentity(
-                                    hostname, chain[0], identityAlg);
+                X509TrustManagerImpl.checkIdentity(session, chain,
+                                    identityAlg, checkClientTrusted);
             }
 
             // try the best to check the algorithm constraints
@@ -1559,13 +1559,13 @@
                 constraints = new SSLAlgorithmConstraints(engine, true);
             }
 
-            checkAlgorithmConstraints(chain, constraints, isClient);
+            checkAlgorithmConstraints(chain, constraints, checkClientTrusted);
         }
     }
 
     private void checkAlgorithmConstraints(X509Certificate[] chain,
             AlgorithmConstraints constraints,
-            boolean isClient) throws CertificateException {
+            boolean checkClientTrusted) throws CertificateException {
         try {
             // Does the certificate chain end with a trusted certificate?
             int checkedLength = chain.length - 1;
@@ -1584,7 +1584,7 @@
             if (checkedLength >= 0) {
                 AlgorithmChecker checker =
                     new AlgorithmChecker(constraints, null,
-                            (isClient ? Validator.VAR_TLS_CLIENT :
+                            (checkClientTrusted ? Validator.VAR_TLS_CLIENT :
                                         Validator.VAR_TLS_SERVER));
                 checker.init(false);
                 for (int i = checkedLength; i >= 0; i--) {
--- a/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java	Thu Feb 07 08:47:10 2019 -0500
+++ b/src/java.base/share/classes/sun/security/ssl/X509TrustManagerImpl.java	Wed Feb 27 13:58:04 2019 -0800
@@ -145,7 +145,7 @@
     }
 
     private Validator checkTrustedInit(X509Certificate[] chain,
-                                        String authType, boolean isClient) {
+            String authType, boolean checkClientTrusted) {
         if (chain == null || chain.length == 0) {
             throw new IllegalArgumentException(
                 "null or zero-length certificate chain");
@@ -157,7 +157,7 @@
         }
 
         Validator v = null;
-        if (isClient) {
+        if (checkClientTrusted) {
             v = clientValidator;
             if (v == null) {
                 validatorLock.lock();
@@ -192,9 +192,10 @@
         return v;
     }
 
-    private void checkTrusted(X509Certificate[] chain, String authType,
-                Socket socket, boolean isClient) throws CertificateException {
-        Validator v = checkTrustedInit(chain, authType, isClient);
+    private void checkTrusted(X509Certificate[] chain,
+            String authType, Socket socket,
+            boolean checkClientTrusted) throws CertificateException {
+        Validator v = checkTrustedInit(chain, authType, checkClientTrusted);
 
         X509Certificate[] trustedChain = null;
         if ((socket != null) && socket.isConnected() &&
@@ -223,28 +224,23 @@
 
             // Grab any stapled OCSP responses for use in validation
             List<byte[]> responseList = Collections.emptyList();
-            if (!isClient && isExtSession) {
+            if (!checkClientTrusted && isExtSession) {
                 responseList =
                         ((ExtendedSSLSession)session).getStatusResponses();
             }
             trustedChain = v.validate(chain, null, responseList,
-                    constraints, isClient ? null : authType);
-
-            // check if EE certificate chains to a public root CA (as
-            // pre-installed in cacerts)
-            boolean chainsToPublicCA = AnchorCertificates.contains(
-                    trustedChain[trustedChain.length-1]);
+                    constraints, checkClientTrusted ? null : authType);
 
             // check endpoint identity
             String identityAlg = sslSocket.getSSLParameters().
                     getEndpointIdentificationAlgorithm();
             if (identityAlg != null && !identityAlg.isEmpty()) {
-                checkIdentity(session, trustedChain[0], identityAlg, isClient,
-                        getRequestedServerNames(socket), chainsToPublicCA);
+                checkIdentity(session,
+                        trustedChain, identityAlg, checkClientTrusted);
             }
         } else {
             trustedChain = v.validate(chain, null, Collections.emptyList(),
-                    null, isClient ? null : authType);
+                    null, checkClientTrusted ? null : authType);
         }
 
         if (SSLLogger.isOn && SSLLogger.isOn("ssl,trustmanager")) {
@@ -253,9 +249,10 @@
         }
     }
 
-    private void checkTrusted(X509Certificate[] chain, String authType,
-            SSLEngine engine, boolean isClient) throws CertificateException {
-        Validator v = checkTrustedInit(chain, authType, isClient);
+    private void checkTrusted(X509Certificate[] chain,
+            String authType, SSLEngine engine,
+            boolean checkClientTrusted) throws CertificateException {
+        Validator v = checkTrustedInit(chain, authType, checkClientTrusted);
 
         X509Certificate[] trustedChain = null;
         if (engine != null) {
@@ -281,28 +278,23 @@
 
             // Grab any stapled OCSP responses for use in validation
             List<byte[]> responseList = Collections.emptyList();
-            if (!isClient && isExtSession) {
+            if (!checkClientTrusted && isExtSession) {
                 responseList =
                         ((ExtendedSSLSession)session).getStatusResponses();
             }
             trustedChain = v.validate(chain, null, responseList,
-                    constraints, isClient ? null : authType);
-
-            // check if EE certificate chains to a public root CA (as
-            // pre-installed in cacerts)
-            boolean chainsToPublicCA = AnchorCertificates.contains(
-                    trustedChain[trustedChain.length-1]);
+                    constraints, checkClientTrusted ? null : authType);
 
             // check endpoint identity
             String identityAlg = engine.getSSLParameters().
                     getEndpointIdentificationAlgorithm();
             if (identityAlg != null && !identityAlg.isEmpty()) {
-                checkIdentity(session, trustedChain[0], identityAlg, isClient,
-                        getRequestedServerNames(engine), chainsToPublicCA);
+                checkIdentity(session, trustedChain,
+                        identityAlg, checkClientTrusted);
             }
         } else {
             trustedChain = v.validate(chain, null, Collections.emptyList(),
-                    null, isClient ? null : authType);
+                    null, checkClientTrusted ? null : authType);
         }
 
         if (SSLLogger.isOn && SSLLogger.isOn("ssl,trustmanager")) {
@@ -360,14 +352,8 @@
     static List<SNIServerName> getRequestedServerNames(Socket socket) {
         if (socket != null && socket.isConnected() &&
                                         socket instanceof SSLSocket) {
-
-            SSLSocket sslSocket = (SSLSocket)socket;
-            SSLSession session = sslSocket.getHandshakeSession();
-
-            if (session != null && (session instanceof ExtendedSSLSession)) {
-                ExtendedSSLSession extSession = (ExtendedSSLSession)session;
-                return extSession.getRequestedServerNames();
-            }
+            return getRequestedServerNames(
+                    ((SSLSocket)socket).getHandshakeSession());
         }
 
         return Collections.<SNIServerName>emptyList();
@@ -376,12 +362,16 @@
     // Also used by X509KeyManagerImpl
     static List<SNIServerName> getRequestedServerNames(SSLEngine engine) {
         if (engine != null) {
-            SSLSession session = engine.getHandshakeSession();
+            return getRequestedServerNames(engine.getHandshakeSession());
+        }
 
-            if (session != null && (session instanceof ExtendedSSLSession)) {
-                ExtendedSSLSession extSession = (ExtendedSSLSession)session;
-                return extSession.getRequestedServerNames();
-            }
+        return Collections.<SNIServerName>emptyList();
+    }
+
+    private static List<SNIServerName> getRequestedServerNames(
+            SSLSession session) {
+        if (session != null && (session instanceof ExtendedSSLSession)) {
+            return ((ExtendedSSLSession)session).getRequestedServerNames();
         }
 
         return Collections.<SNIServerName>emptyList();
@@ -402,23 +392,28 @@
      * the identity checking aginst the server_name extension if present, and
      * may failove to peer host checking.
      */
-    private static void checkIdentity(SSLSession session,
-            X509Certificate cert,
+    static void checkIdentity(SSLSession session,
+            X509Certificate[] trustedChain,
             String algorithm,
-            boolean isClient,
-            List<SNIServerName> sniNames,
-            boolean chainsToPublicCA) throws CertificateException {
+            boolean checkClientTrusted) throws CertificateException {
+
+        // check if EE certificate chains to a public root CA (as
+        // pre-installed in cacerts)
+        boolean chainsToPublicCA = AnchorCertificates.contains(
+                trustedChain[trustedChain.length - 1]);
 
         boolean identifiable = false;
         String peerHost = session.getPeerHost();
-        if (isClient) {
-            String hostname = getHostNameInSNI(sniNames);
-            if (hostname != null) {
+        if (!checkClientTrusted) {
+            List<SNIServerName> sniNames = getRequestedServerNames(session);
+            String sniHostName = getHostNameInSNI(sniNames);
+            if (sniHostName != null) {
                 try {
-                    checkIdentity(hostname, cert, algorithm, chainsToPublicCA);
+                    checkIdentity(sniHostName,
+                            trustedChain[0], algorithm, chainsToPublicCA);
                     identifiable = true;
                 } catch (CertificateException ce) {
-                    if (hostname.equalsIgnoreCase(peerHost)) {
+                    if (sniHostName.equalsIgnoreCase(peerHost)) {
                         throw ce;
                     }
 
@@ -428,7 +423,8 @@
         }
 
         if (!identifiable) {
-            checkIdentity(peerHost, cert, algorithm, chainsToPublicCA);
+            checkIdentity(peerHost,
+                    trustedChain[0], algorithm, chainsToPublicCA);
         }
     }