src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java
changeset 58635 06d7236d6ef6
parent 47216 71c04702a3d5
--- a/src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java	Thu Jul 18 07:25:17 2019 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java	Thu Jan 17 10:44:17 2019 -0500
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -29,6 +29,7 @@
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.ObjectInput;
+import java.io.ObjectInputFilter;
 import java.io.ObjectOutput;
 import java.io.StreamCorruptedException;
 import java.rmi.RemoteException;
@@ -36,6 +37,9 @@
 import java.rmi.UnmarshalException;
 import java.rmi.server.ObjID;
 import java.rmi.server.RemoteCall;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
 import sun.rmi.runtime.Log;
 import sun.rmi.server.UnicastRef;
 import sun.rmi.transport.tcp.TCPEndpoint;
@@ -50,6 +54,7 @@
     private ConnectionInputStream in = null;
     private ConnectionOutputStream out = null;
     private Connection conn;
+    private ObjectInputFilter filter = null;
     private boolean resultStarted = false;
     private Exception serverException = null;
 
@@ -123,6 +128,13 @@
         }
     }
 
+    public void setObjectInputFilter(ObjectInputFilter filter) {
+        if (in != null) {
+            throw new IllegalStateException("set filter must occur before calling getInputStream");
+        }
+        this.filter = filter;
+    }
+
     /**
      * Get the InputStream the stub/skeleton should get results/arguments
      * from.
@@ -132,6 +144,12 @@
             Transport.transportLog.log(Log.VERBOSE, "getting input stream");
 
             in = new ConnectionInputStream(conn.getInputStream());
+            if (filter != null) {
+                AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
+                    in.setObjectInputFilter(filter);
+                    return null;
+                });
+            }
         }
         return in;
     }
@@ -251,6 +269,7 @@
             try {
                 ex = in.readObject();
             } catch (Exception e) {
+                discardPendingRefs();
                 throw new UnmarshalException("Error unmarshaling return", e);
             }
 
@@ -259,6 +278,7 @@
             if (ex instanceof Exception) {
                 exceptionReceivedFromServer((Exception) ex);
             } else {
+                discardPendingRefs();
                 throw new UnmarshalException("Return type not Exception");
             }
             // Exception is thrown before fallthrough can occur