8145838: JShell: restrict RemoteAgent connection socket to localhost
authorjlahoda
Fri, 11 Nov 2016 12:54:47 +0100
changeset 41994 e43f670394ca
parent 41993 c8260c3ae93b
child 41995 1ac75bf2dc3a
8145838: JShell: restrict RemoteAgent connection socket to localhost Summary: Also reviewed by Chris Ries Reviewed-by: rfield
langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/JdiDefaultExecutionControl.java
langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/RemoteExecutionControl.java
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/JdiDefaultExecutionControl.java	Fri Nov 11 05:56:09 2016 +0000
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/JdiDefaultExecutionControl.java	Fri Nov 11 12:54:47 2016 +0100
@@ -24,18 +24,23 @@
  */
 package jdk.jshell.execution;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.ObjectInput;
 import java.io.ObjectOutput;
 import java.io.OutputStream;
+import java.net.InetAddress;
 import java.net.ServerSocket;
 import java.net.Socket;
+import java.security.SecureRandom;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Consumer;
+
 import com.sun.jdi.BooleanValue;
 import com.sun.jdi.ClassNotLoadedException;
 import com.sun.jdi.Field;
@@ -111,7 +116,7 @@
      */
     private static ExecutionControl create(ExecutionEnv env,
             boolean isLaunch, String host) throws IOException {
-        try (final ServerSocket listener = new ServerSocket(0)) {
+        try (final ServerSocket listener = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())) {
             // timeout after 60 seconds
             listener.setSoTimeout(60000);
             int port = listener.getLocalPort();
@@ -122,6 +127,14 @@
             VirtualMachine vm = jdii.vm();
             Process process = jdii.process();
 
+            OutputStream processOut = process.getOutputStream();
+            SecureRandom rng = new SecureRandom();
+            byte[] randomBytes = new byte[VERIFY_HASH_LEN];
+
+            rng.nextBytes(randomBytes);
+            processOut.write(randomBytes);
+            processOut.flush();
+
             List<Consumer<String>> deathListeners = new ArrayList<>();
             deathListeners.add(s -> env.closeDown());
             Util.detectJdiExitEvent(vm, s -> {
@@ -130,6 +143,8 @@
                 }
             });
 
+            ByteArrayOutputStream receivedRandomBytes = new ByteArrayOutputStream();
+
             // Set-up the commands/reslts on the socket.  Piggy-back snippet
             // output.
             Socket socket = listener.accept();
@@ -138,11 +153,35 @@
             Map<String, OutputStream> outputs = new HashMap<>();
             outputs.put("out", env.userOut());
             outputs.put("err", env.userErr());
+            outputs.put("echo", new OutputStream() {
+                @Override public void write(int b) throws IOException {
+                    synchronized (receivedRandomBytes) {
+                        receivedRandomBytes.write(b);
+                        receivedRandomBytes.notify();
+                    }
+                }
+            });
             Map<String, InputStream> input = new HashMap<>();
             input.put("in", env.userIn());
-            return remoteInputOutput(socket.getInputStream(), out, outputs, input, (objIn, objOut) -> new JdiDefaultExecutionControl(objOut, objIn, vm, process, deathListeners));
+            return remoteInputOutput(socket.getInputStream(), out, outputs, input, (objIn, objOut) -> {
+                synchronized (receivedRandomBytes) {
+                    while (receivedRandomBytes.size() < randomBytes.length) {
+                        try {
+                            receivedRandomBytes.wait();
+                        } catch (InterruptedException ex) {
+                            //ignore
+                        }
+                    }
+                    if (!Arrays.equals(receivedRandomBytes.toByteArray(), randomBytes)) {
+                        throw new IllegalStateException("Invalid connection!");
+                    }
+                }
+                return new JdiDefaultExecutionControl(objOut, objIn, vm, process, deathListeners);
+            });
         }
     }
+    //where:
+        private static final int VERIFY_HASH_LEN = 20;
 
     /**
      * Create an instance.
--- a/langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/RemoteExecutionControl.java	Fri Nov 11 05:56:09 2016 +0000
+++ b/langtools/src/jdk.jshell/share/classes/jdk/jshell/execution/RemoteExecutionControl.java	Fri Nov 11 12:54:47 2016 +0100
@@ -24,6 +24,7 @@
  */
 package jdk.jshell.execution;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.PrintStream;
@@ -57,6 +58,7 @@
      * @throws Exception any unexpected exception
      */
     public static void main(String[] args) throws Exception {
+        InputStream fd0 = System.in;
         String loopBack = null;
         Socket socket = new Socket(loopBack, Integer.parseInt(args[0]));
         InputStream inStream = socket.getInputStream();
@@ -64,6 +66,25 @@
         Map<String, Consumer<OutputStream>> outputs = new HashMap<>();
         outputs.put("out", st -> System.setOut(new PrintStream(st, true)));
         outputs.put("err", st -> System.setErr(new PrintStream(st, true)));
+        outputs.put("echo", st -> {
+            new Thread(() -> {
+                try {
+                    int read;
+
+                    while ((read = fd0.read()) != (-1)) {
+                        st.write(read);
+                    }
+                } catch (IOException ex) {
+                    ex.printStackTrace();
+                } finally {
+                    try {
+                        st.close();
+                    } catch (IOException ex) {
+                        ex.printStackTrace();
+                    }
+                }
+            }).start();
+        });
         Map<String, Consumer<InputStream>> input = new HashMap<>();
         input.put("in", st -> System.setIn(st));
         forwardExecutionControlAndIO(new RemoteExecutionControl(), inStream, outStream, outputs, input);