test/jdk/com/sun/jdi/JdwpAllowTest.java
changeset 54884 8a6093c186a6
parent 53743 0076e23f333f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/jdk/com/sun/jdi/JdwpAllowTest.java	Wed May 15 11:06:33 2019 -0700
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) 2017, 2019, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+/*
+ * @test
+ * @summary Smoke test for JDWP hardening
+ * @library /test/lib
+ * @run driver JdwpAllowTest
+ */
+
+import java.io.IOException;
+
+import java.net.InetAddress;
+import java.net.Socket;
+import java.net.SocketException;
+
+import jdk.test.lib.Utils;
+import jdk.test.lib.apps.LingeredApp;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+
+public class JdwpAllowTest {
+
+    public static int handshake(int port) throws IOException {
+        // Connect to the debuggee and handshake
+        int res = -1;
+        Socket s = null;
+        try {
+            s = new Socket(localAddr, port);
+            s.getOutputStream().write("JDWP-Handshake".getBytes("UTF-8"));
+            byte[] buffer = new byte[24];
+            res = s.getInputStream().read(buffer);
+        }
+        catch (SocketException ex) {
+            ex.printStackTrace();
+            // pass
+        } finally {
+            if (s != null) {
+                s.close();
+            }
+        }
+        return res;
+    }
+
+    public static ArrayList<String> prepareCmd(String allowOpt) {
+         ArrayList<String> cmd = new ArrayList<>();
+
+         String jdwpArgs = "-agentlib:jdwp=transport=dt_socket,server=y," +
+                           "suspend=n,address=*:0"
+                            + (allowOpt == null ? "" : ",allow=" + allowOpt);
+         cmd.add(jdwpArgs);
+         return cmd;
+    }
+
+    private static Pattern listenRegexp = Pattern.compile("Listening for transport \\b(.+)\\b at address: \\b(\\d+)\\b");
+    private static int detectPort(LingeredApp app) {
+        long maxWaitTime = System.currentTimeMillis()
+                + Utils.adjustTimeout(10000);  // 10 seconds adjusted for TIMEOUT_FACTOR
+        while (true) {
+            String s = app.getProcessStdout();
+            Matcher m = listenRegexp.matcher(s);
+            if (m.find()) {
+                // m.group(1) is transport, m.group(2) is port
+                return Integer.parseInt(m.group(2));
+            }
+            if (System.currentTimeMillis() > maxWaitTime) {
+                throw new RuntimeException("Could not detect port from '" + s + "' (timeout)");
+            }
+            try {
+                if (app.getProcess().waitFor(500, TimeUnit.MILLISECONDS)) {
+                    throw new RuntimeException("Could not detect port from '" + s + "' (debuggee is terminated)");
+                }
+            } catch (InterruptedException e) {
+                // ignore
+            }
+        }
+    }
+
+    public static void positiveTest(String testName, String allowOpt)
+        throws InterruptedException, IOException {
+        System.err.println("\nStarting " + testName);
+        ArrayList<String> cmd = prepareCmd(allowOpt);
+
+        LingeredApp a = LingeredApp.startApp(cmd);
+        int res;
+        try {
+            res = handshake(detectPort(a));
+        } finally {
+            a.stopApp();
+        }
+        if (res < 0) {
+            throw new RuntimeException(testName + " FAILED");
+        }
+        System.err.println(testName + " PASSED");
+    }
+
+    public static void negativeTest(String testName, String allowOpt)
+        throws InterruptedException, IOException {
+        System.err.println("\nStarting " + testName);
+        ArrayList<String> cmd = prepareCmd(allowOpt);
+
+        LingeredApp a = LingeredApp.startApp(cmd);
+        int res;
+        try {
+            res = handshake(detectPort(a));
+        } finally {
+            a.stopApp();
+        }
+        if (res > 0) {
+            System.err.println(testName + ": res=" + res);
+            throw new RuntimeException(testName + " FAILED");
+        }
+        System.err.println(testName + ": returned a negative code as expected: " + res);
+        System.err.println(testName + " PASSED");
+    }
+
+    public static void badAllowOptionTest(String testName, String allowOpt)
+        throws InterruptedException, IOException {
+        System.err.println("\nStarting " + testName);
+        ArrayList<String> cmd = prepareCmd(allowOpt);
+
+        LingeredApp a;
+        try {
+            a = LingeredApp.startApp(cmd);
+        } catch (IOException ex) {
+            System.err.println(testName + ": caught expected IOException");
+            System.err.println(testName + " PASSED");
+            return;
+        }
+        // LingeredApp.startApp is expected to fail, but if not, terminate the app
+        a.stopApp();
+        throw new RuntimeException(testName + " FAILED");
+    }
+
+    /*
+     * Generate allow address by changing random bit in the local address
+     * and calculate 2 masks (prefix length) - one is matches original local address
+     * and another doesn't.
+     */
+    private static class MaskTest {
+        public final String localAddress;
+        public final String allowAddress;
+        public final int prefixLengthGood;
+        public final int prefixLengthBad;
+
+        public MaskTest(InetAddress addr) throws Exception {
+            localAddress = addr.getHostAddress();
+            byte[] bytes = addr.getAddress();
+            Random r = new Random();
+            // prefix length must be >= 1, so bitToChange must be >= 2
+            int bitToChange = r.nextInt(bytes.length * 8 - 3) + 2;
+            setBit(bytes, bitToChange, !getBit(bytes, bitToChange));
+            // clear rest of the bits for mask address
+            for (int i = bitToChange + 1; i < bytes.length * 8; i++) {
+                setBit(bytes, i, false);
+            }
+            allowAddress = InetAddress.getByAddress(bytes).getHostAddress();
+
+            prefixLengthBad = bitToChange;
+            prefixLengthGood = bitToChange - 1;
+        }
+
+        private static boolean getBit(byte[] bytes, int pos) {
+            return (bytes[pos / 8] & (1 << (7 - (pos % 8)))) != 0;
+        }
+
+        private static void setBit(byte[] bytes, int pos, boolean value) {
+            byte byteValue = (byte)(1 << (7 - (pos % 8)));
+            if (value) {
+                bytes[pos / 8] = (byte)(bytes[pos / 8] | byteValue);
+            } else {
+                bytes[pos / 8] &= (~byteValue);
+            }
+        }
+    }
+
+    private static String localAddr;
+    private static List<MaskTest> maskTests = new LinkedList<>();
+
+    private static void init() throws Exception {
+        InetAddress addrs[] = InetAddress.getAllByName("localhost");
+        if (addrs.length == 0) {
+            throw new RuntimeException("No addresses is returned for 'localhost'");
+        }
+        localAddr = addrs[0].getHostAddress();
+        System.err.println("localhost address: " + localAddr);
+
+        for (int i =  0; i < addrs.length; i++) {
+            maskTests.add(new MaskTest(addrs[i]));
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+        init();
+
+        // No allow option is the same as the allow option ',allow=*' is passed
+        positiveTest("DefaultTest", null);
+
+        // Explicit permission for connections from everywhere
+        positiveTest("ExplicitDefaultTest", "*");
+
+        positiveTest("AllowTest", localAddr);
+
+        positiveTest("MultiAllowTest", localAddr + "+10.0.0.0/8+172.16.0.0/12+192.168.0.0/24");
+
+        // Bad allow address
+        negativeTest("DenyTest", "0.0.0.0");
+
+        // Wrong separator ';' is used for allow option
+        badAllowOptionTest("MultiDenyTest", localAddr + ";192.168.0.0/24");
+
+        // Empty allow option
+        badAllowOptionTest("EmptyAllowOptionTest", "");
+
+        // Bad mix of allow option '*' with address value
+        badAllowOptionTest("ExplicitMultiDefault1Test", "*+" + localAddr);
+
+        // Bad mix of allow address value with '*'
+        badAllowOptionTest("ExplicitMultiDefault2Test", localAddr + "+*");
+
+        for (MaskTest test: maskTests) {
+            // override localAddr (to connect to required IPv4 or IPv6 address)
+            localAddr = test.localAddress;
+            positiveTest("PositiveMaskTest(" + test.localAddress + ")",
+                         test.allowAddress + "/" + test.prefixLengthGood);
+            positiveTest("NegativeMaskTest(" + test.localAddress + ")",
+                         test.allowAddress + "/" + test.prefixLengthBad);
+        }
+
+        System.err.println("\nTest PASSED");
+    }
+
+}