8027318: Lambda Metafactory: generate serialization-hostile read/writeObject methods for non-serializable lambdas
authorbriangoetz
Tue, 29 Oct 2013 12:31:27 -0400
changeset 21417 58b329cea7f2
parent 21416 7d9198dd107c
child 21418 88cb3367643e
8027318: Lambda Metafactory: generate serialization-hostile read/writeObject methods for non-serializable lambdas Reviewed-by: rfield, psandoz
jdk/src/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java
jdk/test/java/util/stream/test/org/openjdk/tests/java/lang/invoke/SerializedLambdaTest.java
--- a/jdk/src/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java	Wed Oct 30 13:51:07 2013 -0700
+++ b/jdk/src/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java	Tue Oct 29 12:31:27 2013 -0400
@@ -30,6 +30,7 @@
 import sun.security.action.GetPropertyAction;
 
 import java.io.FilePermission;
+import java.io.Serializable;
 import java.lang.reflect.Constructor;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
@@ -56,8 +57,13 @@
 
     //Serialization support
     private static final String NAME_SERIALIZED_LAMBDA = "java/lang/invoke/SerializedLambda";
+    private static final String NAME_NOT_SERIALIZABLE_EXCEPTION = "java/io/NotSerializableException";
     private static final String DESCR_METHOD_WRITE_REPLACE = "()Ljava/lang/Object;";
+    private static final String DESCR_METHOD_WRITE_OBJECT = "(Ljava/io/ObjectOutputStream;)V";
+    private static final String DESCR_METHOD_READ_OBJECT = "(Ljava/io/ObjectInputStream;)V";
     private static final String NAME_METHOD_WRITE_REPLACE = "writeReplace";
+    private static final String NAME_METHOD_READ_OBJECT = "readObject";
+    private static final String NAME_METHOD_WRITE_OBJECT = "writeObject";
     private static final String DESCR_CTOR_SERIALIZED_LAMBDA
             = MethodType.methodType(void.class,
                                     Class.class,
@@ -65,6 +71,10 @@
                                     int.class, String.class, String.class, String.class,
                                     String.class,
                                     Object[].class).toMethodDescriptorString();
+    private static final String DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION
+            = MethodType.methodType(void.class, String.class).toMethodDescriptorString();
+    private static final String[] SER_HOSTILE_EXCEPTIONS = new String[] {NAME_NOT_SERIALIZABLE_EXCEPTION};
+
 
     // Used to ensure that each spun class name is unique
     private static final AtomicInteger counter = new AtomicInteger(0);
@@ -239,14 +249,16 @@
     private Class<?> spinInnerClass() throws LambdaConversionException {
         String[] interfaces;
         String samIntf = samBase.getName().replace('.', '/');
+        boolean accidentallySerializable = !isSerializable && Serializable.class.isAssignableFrom(samBase);
         if (markerInterfaces.length == 0) {
             interfaces = new String[]{samIntf};
         } else {
             // Assure no duplicate interfaces (ClassFormatError)
             Set<String> itfs = new LinkedHashSet<>(markerInterfaces.length + 1);
             itfs.add(samIntf);
-            for (int i = 0; i < markerInterfaces.length; i++) {
-                itfs.add(markerInterfaces[i].getName().replace('.', '/'));
+            for (Class<?> markerInterface : markerInterfaces) {
+                itfs.add(markerInterface.getName().replace('.', '/'));
+                accidentallySerializable |= !isSerializable && Serializable.class.isAssignableFrom(markerInterface);
             }
             interfaces = itfs.toArray(new String[itfs.size()]);
         }
@@ -283,7 +295,9 @@
         }
 
         if (isSerializable)
-            generateWriteReplace();
+            generateSerializationFriendlyMethods();
+        else if (accidentallySerializable)
+            generateSerializationHostileMethods();
 
         cw.visitEnd();
 
@@ -334,9 +348,9 @@
     }
 
     /**
-     * Generate the writeReplace method (if needed for serialization)
+     * Generate a writeReplace method that supports serialization
      */
-    private void generateWriteReplace() {
+    private void generateSerializationFriendlyMethods() {
         TypeConvertingMethodAdapter mv
                 = new TypeConvertingMethodAdapter(
                     cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
@@ -376,6 +390,37 @@
     }
 
     /**
+     * Generate a readObject/writeObject method that is hostile to serialization
+     */
+    private void generateSerializationHostileMethods() {
+        MethodVisitor mv = cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
+                                          NAME_METHOD_WRITE_OBJECT, DESCR_METHOD_WRITE_OBJECT,
+                                          null, SER_HOSTILE_EXCEPTIONS);
+        mv.visitCode();
+        mv.visitTypeInsn(NEW, NAME_NOT_SERIALIZABLE_EXCEPTION);
+        mv.visitInsn(DUP);
+        mv.visitLdcInsn("Non-serializable lambda");
+        mv.visitMethodInsn(INVOKESPECIAL, NAME_NOT_SERIALIZABLE_EXCEPTION, NAME_CTOR,
+                           DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION);
+        mv.visitInsn(ATHROW);
+        mv.visitMaxs(-1, -1);
+        mv.visitEnd();
+
+        mv = cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
+                            NAME_METHOD_READ_OBJECT, DESCR_METHOD_READ_OBJECT,
+                            null, SER_HOSTILE_EXCEPTIONS);
+        mv.visitCode();
+        mv.visitTypeInsn(NEW, NAME_NOT_SERIALIZABLE_EXCEPTION);
+        mv.visitInsn(DUP);
+        mv.visitLdcInsn("Non-serializable lambda");
+        mv.visitMethodInsn(INVOKESPECIAL, NAME_NOT_SERIALIZABLE_EXCEPTION, NAME_CTOR,
+                           DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION);
+        mv.visitInsn(ATHROW);
+        mv.visitMaxs(-1, -1);
+        mv.visitEnd();
+    }
+
+    /**
      * This class generates a method body which calls the lambda implementation
      * method, converting arguments, as needed.
      */
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/lang/invoke/SerializedLambdaTest.java	Wed Oct 30 13:51:07 2013 -0700
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/lang/invoke/SerializedLambdaTest.java	Tue Oct 29 12:31:27 2013 -0400
@@ -22,9 +22,18 @@
  */
 package org.openjdk.tests.java.lang.invoke;
 
-import org.testng.annotations.Test;
-
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.NotSerializableException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.lang.invoke.CallSite;
+import java.lang.invoke.LambdaMetafactory;
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicLong;
@@ -34,6 +43,8 @@
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 
+import org.testng.annotations.Test;
+
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
@@ -277,4 +288,66 @@
         LongConsumer lc = (LongConsumer & Serializable) a::addAndGet;
         assertSerial(lc, plc -> { plc.accept(3); });
     }
+
+    // Tests of direct use of metafactories
+
+    private static boolean foo(Object s) { return s != null && ((String) s).length() > 0; }
+    private static final MethodType predicateMT = MethodType.methodType(boolean.class, Object.class);
+    private static final MethodType stringPredicateMT = MethodType.methodType(boolean.class, String.class);
+    private static final Consumer<Predicate<String>> fooAsserter = x -> {
+        assertTrue(x.test("foo"));
+        assertFalse(x.test(""));
+        assertFalse(x.test(null));
+    };
+
+    // standard MF: nonserializable supertype
+    public void testDirectStdNonser() throws Throwable {
+        MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+        // Standard metafactory, non-serializable target: not serializable
+        CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
+                                                    "test", MethodType.methodType(Predicate.class),
+                                                    predicateMT, fooMH, stringPredicateMT);
+        Predicate<String> p = (Predicate<String>) cs.getTarget().invokeExact();
+        assertNotSerial(p, fooAsserter);
+    }
+
+    // standard MF: serializable supertype
+    public void testDirectStdSer() throws Throwable {
+        MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+        // Standard metafactory, serializable target: not serializable
+        CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
+                                                    "test", MethodType.methodType(SerPredicate.class),
+                                                    predicateMT, fooMH, stringPredicateMT);
+        assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+    }
+
+    // alt MF: nonserializable supertype
+    public void testAltStdNonser() throws Throwable {
+        MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+        // Alt metafactory, non-serializable target: not serializable
+        CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+                                                       "test", MethodType.methodType(Predicate.class),
+                                                       predicateMT, fooMH, stringPredicateMT, 0);
+        assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+    }
+
+    // alt MF: serializable supertype
+    public void testAltStdSer() throws Throwable {
+        MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+        // Alt metafactory, serializable target, no FLAG_SERIALIZABLE: not serializable
+        CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+                                                       "test", MethodType.methodType(SerPredicate.class),
+                                                       predicateMT, fooMH, stringPredicateMT, 0);
+        assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+
+        // Alt metafactory, serializable marker, no FLAG_SERIALIZABLE: not serializable
+        cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+                                              "test", MethodType.methodType(Predicate.class),
+                                              predicateMT, fooMH, stringPredicateMT, LambdaMetafactory.FLAG_MARKERS, 1, Serializable.class);
+        assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+    }
 }