src/java.xml/share/classes/jdk/xml/internal/SecuritySupport.java
changeset 47312 d4f959806fe9
parent 47216 71c04702a3d5
child 48944 25aa8b9f1dae
--- a/src/java.xml/share/classes/jdk/xml/internal/SecuritySupport.java	Wed Oct 04 10:44:21 2017 -0700
+++ b/src/java.xml/share/classes/jdk/xml/internal/SecuritySupport.java	Wed Oct 04 10:54:18 2017 -0700
@@ -29,7 +29,9 @@
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.io.InputStream;
+import java.net.URL;
 import java.security.AccessController;
+import java.security.CodeSource;
 import java.security.PrivilegedAction;
 import java.security.PrivilegedActionException;
 import java.security.PrivilegedExceptionAction;
@@ -82,7 +84,7 @@
     public static String getSystemProperty(final String propName) {
         return
         AccessController.doPrivileged(
-                (PrivilegedAction<String>) () -> (String)System.getProperty(propName));
+                (PrivilegedAction<String>) () -> System.getProperty(propName));
     }
 
     /**
@@ -220,6 +222,12 @@
                 -> f.exists()));
     }
 
+    /**
+     * Creates and returns a new FileInputStream from a file.
+     * @param file the specified file
+     * @return the FileInputStream
+     * @throws FileNotFoundException if the file is not found
+     */
     public static FileInputStream getFileInputStream(final File file)
             throws FileNotFoundException {
         try {
@@ -231,6 +239,16 @@
     }
 
     /**
+     * Returns the resource as a stream.
+     * @param name the resource name
+     * @return the resource stream
+     */
+    public static InputStream getResourceAsStream(final String name) {
+        return AccessController.doPrivileged((PrivilegedAction<InputStream>) () ->
+                SecuritySupport.class.getResourceAsStream("/"+name));
+    }
+
+    /**
      * Gets a resource bundle using the specified base name, the default locale, and the caller's class loader.
      * @param bundle the base name of the resource bundle, a fully qualified class name
      * @return a resource bundle for the given base name and the default locale
@@ -259,4 +277,179 @@
             }
         });
     }
+
+    /**
+     * Checks whether the file exists.
+     * @param f the specified file
+     * @return true if the file exists, false otherwise
+     */
+    public static boolean doesFileExist(final File f) {
+        return (AccessController.doPrivileged((PrivilegedAction<Boolean>) () -> f.exists()));
+    }
+
+    /**
+     * Checks the LastModified attribute of a file.
+     * @param f the specified file
+     * @return the LastModified attribute
+     */
+    static long getLastModified(final File f) {
+        return (AccessController.doPrivileged((PrivilegedAction<Long>) () -> f.lastModified()));
+    }
+
+    /**
+     * Strip off path from an URI
+     *
+     * @param uri an URI with full path
+     * @return the file name only
+     */
+    public static String sanitizePath(String uri) {
+        if (uri == null) {
+            return "";
+        }
+        int i = uri.lastIndexOf("/");
+        if (i > 0) {
+            return uri.substring(i+1, uri.length());
+        }
+        return "";
+    }
+
+    /**
+     * Check the protocol used in the systemId against allowed protocols
+     *
+     * @param systemId the Id of the URI
+     * @param allowedProtocols a list of allowed protocols separated by comma
+     * @param accessAny keyword to indicate allowing any protocol
+     * @return the name of the protocol if rejected, null otherwise
+     */
+    public static String checkAccess(String systemId, String allowedProtocols,
+            String accessAny) throws IOException {
+        if (systemId == null || (allowedProtocols != null &&
+                allowedProtocols.equalsIgnoreCase(accessAny))) {
+            return null;
+        }
+
+        String protocol;
+        if (!systemId.contains(":")) {
+            protocol = "file";
+        } else {
+            URL url = new URL(systemId);
+            protocol = url.getProtocol();
+            if (protocol.equalsIgnoreCase("jar")) {
+                String path = url.getPath();
+                protocol = path.substring(0, path.indexOf(":"));
+            } else if (protocol.equalsIgnoreCase("jrt")) {
+                // if the systemId is "jrt" then allow access if "file" allowed
+                protocol = "file";
+            }
+        }
+
+        if (isProtocolAllowed(protocol, allowedProtocols)) {
+            //access allowed
+            return null;
+        } else {
+            return protocol;
+        }
+    }
+
+    /**
+     * Check if the protocol is in the allowed list of protocols. The check
+     * is case-insensitive while ignoring whitespaces.
+     *
+     * @param protocol a protocol
+     * @param allowedProtocols a list of allowed protocols
+     * @return true if the protocol is in the list
+     */
+    private static boolean isProtocolAllowed(String protocol, String allowedProtocols) {
+         if (allowedProtocols == null) {
+             return false;
+         }
+         String temp[] = allowedProtocols.split(",");
+         for (String t : temp) {
+             t = t.trim();
+             if (t.equalsIgnoreCase(protocol)) {
+                 return true;
+             }
+         }
+         return false;
+     }
+
+    public static ClassLoader getContextClassLoader() {
+        return AccessController.doPrivileged((PrivilegedAction<ClassLoader>) () -> {
+            ClassLoader cl = Thread.currentThread().getContextClassLoader();
+            if (cl == null)
+                cl = ClassLoader.getSystemClassLoader();
+            return cl;
+        });
+    }
+
+
+    public static ClassLoader getSystemClassLoader() {
+        return AccessController.doPrivileged((PrivilegedAction<ClassLoader>) () -> {
+            ClassLoader cl = null;
+            try {
+                cl = ClassLoader.getSystemClassLoader();
+            } catch (SecurityException ex) {
+            }
+            return cl;
+        });
+    }
+
+    public static ClassLoader getParentClassLoader(final ClassLoader cl) {
+        return AccessController.doPrivileged((PrivilegedAction<ClassLoader>) () -> {
+            ClassLoader parent = null;
+            try {
+                parent = cl.getParent();
+            } catch (SecurityException ex) {
+            }
+
+            // eliminate loops in case of the boot
+            // ClassLoader returning itself as a parent
+            return (parent == cl) ? null : parent;
+        });
+    }
+
+
+    // Used for debugging purposes
+    public static String getClassSource(Class<?> cls) {
+        return AccessController.doPrivileged((PrivilegedAction<String>) () -> {
+            CodeSource cs = cls.getProtectionDomain().getCodeSource();
+            if (cs != null) {
+                URL loc = cs.getLocation();
+                return loc != null ? loc.toString() : "(no location)";
+            } else {
+                return "(no code source)";
+            }
+        });
+    }
+
+    // ----------------  For SAX ----------------------
+    /**
+     * Returns the current thread's context class loader, or the system class loader
+     * if the context class loader is null.
+     * @return the current thread's context class loader, or the system class loader
+     * @throws SecurityException
+     */
+    public static ClassLoader getClassLoader() throws SecurityException{
+        return AccessController.doPrivileged((PrivilegedAction<ClassLoader>)() -> {
+            ClassLoader cl = Thread.currentThread().getContextClassLoader();
+            if (cl == null) {
+                cl = ClassLoader.getSystemClassLoader();
+            }
+
+            return cl;
+        });
+    }
+
+    public static InputStream getResourceAsStream(final ClassLoader cl, final String name)
+    {
+        return AccessController.doPrivileged((PrivilegedAction<InputStream>) () -> {
+            InputStream ris;
+            if (cl == null) {
+                ris = SecuritySupport.class.getResourceAsStream(name);
+            } else {
+                ris = cl.getResourceAsStream(name);
+            }
+            return ris;
+        });
+    }
 }