8027318: Lambda Metafactory: generate serialization-hostile read/writeObject methods for non-serializable lambdas
Reviewed-by: rfield, psandoz
--- a/jdk/src/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java Wed Oct 30 13:51:07 2013 -0700
+++ b/jdk/src/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java Tue Oct 29 12:31:27 2013 -0400
@@ -30,6 +30,7 @@
import sun.security.action.GetPropertyAction;
import java.io.FilePermission;
+import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedAction;
@@ -56,8 +57,13 @@
//Serialization support
private static final String NAME_SERIALIZED_LAMBDA = "java/lang/invoke/SerializedLambda";
+ private static final String NAME_NOT_SERIALIZABLE_EXCEPTION = "java/io/NotSerializableException";
private static final String DESCR_METHOD_WRITE_REPLACE = "()Ljava/lang/Object;";
+ private static final String DESCR_METHOD_WRITE_OBJECT = "(Ljava/io/ObjectOutputStream;)V";
+ private static final String DESCR_METHOD_READ_OBJECT = "(Ljava/io/ObjectInputStream;)V";
private static final String NAME_METHOD_WRITE_REPLACE = "writeReplace";
+ private static final String NAME_METHOD_READ_OBJECT = "readObject";
+ private static final String NAME_METHOD_WRITE_OBJECT = "writeObject";
private static final String DESCR_CTOR_SERIALIZED_LAMBDA
= MethodType.methodType(void.class,
Class.class,
@@ -65,6 +71,10 @@
int.class, String.class, String.class, String.class,
String.class,
Object[].class).toMethodDescriptorString();
+ private static final String DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION
+ = MethodType.methodType(void.class, String.class).toMethodDescriptorString();
+ private static final String[] SER_HOSTILE_EXCEPTIONS = new String[] {NAME_NOT_SERIALIZABLE_EXCEPTION};
+
// Used to ensure that each spun class name is unique
private static final AtomicInteger counter = new AtomicInteger(0);
@@ -239,14 +249,16 @@
private Class<?> spinInnerClass() throws LambdaConversionException {
String[] interfaces;
String samIntf = samBase.getName().replace('.', '/');
+ boolean accidentallySerializable = !isSerializable && Serializable.class.isAssignableFrom(samBase);
if (markerInterfaces.length == 0) {
interfaces = new String[]{samIntf};
} else {
// Assure no duplicate interfaces (ClassFormatError)
Set<String> itfs = new LinkedHashSet<>(markerInterfaces.length + 1);
itfs.add(samIntf);
- for (int i = 0; i < markerInterfaces.length; i++) {
- itfs.add(markerInterfaces[i].getName().replace('.', '/'));
+ for (Class<?> markerInterface : markerInterfaces) {
+ itfs.add(markerInterface.getName().replace('.', '/'));
+ accidentallySerializable |= !isSerializable && Serializable.class.isAssignableFrom(markerInterface);
}
interfaces = itfs.toArray(new String[itfs.size()]);
}
@@ -283,7 +295,9 @@
}
if (isSerializable)
- generateWriteReplace();
+ generateSerializationFriendlyMethods();
+ else if (accidentallySerializable)
+ generateSerializationHostileMethods();
cw.visitEnd();
@@ -334,9 +348,9 @@
}
/**
- * Generate the writeReplace method (if needed for serialization)
+ * Generate a writeReplace method that supports serialization
*/
- private void generateWriteReplace() {
+ private void generateSerializationFriendlyMethods() {
TypeConvertingMethodAdapter mv
= new TypeConvertingMethodAdapter(
cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
@@ -376,6 +390,37 @@
}
/**
+ * Generate a readObject/writeObject method that is hostile to serialization
+ */
+ private void generateSerializationHostileMethods() {
+ MethodVisitor mv = cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
+ NAME_METHOD_WRITE_OBJECT, DESCR_METHOD_WRITE_OBJECT,
+ null, SER_HOSTILE_EXCEPTIONS);
+ mv.visitCode();
+ mv.visitTypeInsn(NEW, NAME_NOT_SERIALIZABLE_EXCEPTION);
+ mv.visitInsn(DUP);
+ mv.visitLdcInsn("Non-serializable lambda");
+ mv.visitMethodInsn(INVOKESPECIAL, NAME_NOT_SERIALIZABLE_EXCEPTION, NAME_CTOR,
+ DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION);
+ mv.visitInsn(ATHROW);
+ mv.visitMaxs(-1, -1);
+ mv.visitEnd();
+
+ mv = cw.visitMethod(ACC_PRIVATE + ACC_FINAL,
+ NAME_METHOD_READ_OBJECT, DESCR_METHOD_READ_OBJECT,
+ null, SER_HOSTILE_EXCEPTIONS);
+ mv.visitCode();
+ mv.visitTypeInsn(NEW, NAME_NOT_SERIALIZABLE_EXCEPTION);
+ mv.visitInsn(DUP);
+ mv.visitLdcInsn("Non-serializable lambda");
+ mv.visitMethodInsn(INVOKESPECIAL, NAME_NOT_SERIALIZABLE_EXCEPTION, NAME_CTOR,
+ DESCR_CTOR_NOT_SERIALIZABLE_EXCEPTION);
+ mv.visitInsn(ATHROW);
+ mv.visitMaxs(-1, -1);
+ mv.visitEnd();
+ }
+
+ /**
* This class generates a method body which calls the lambda implementation
* method, converting arguments, as needed.
*/
--- a/jdk/test/java/util/stream/test/org/openjdk/tests/java/lang/invoke/SerializedLambdaTest.java Wed Oct 30 13:51:07 2013 -0700
+++ b/jdk/test/java/util/stream/test/org/openjdk/tests/java/lang/invoke/SerializedLambdaTest.java Tue Oct 29 12:31:27 2013 -0400
@@ -22,9 +22,18 @@
*/
package org.openjdk.tests.java.lang.invoke;
-import org.testng.annotations.Test;
-
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.NotSerializableException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.lang.invoke.CallSite;
+import java.lang.invoke.LambdaMetafactory;
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
@@ -34,6 +43,8 @@
import java.util.function.Predicate;
import java.util.function.Supplier;
+import org.testng.annotations.Test;
+
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
@@ -277,4 +288,66 @@
LongConsumer lc = (LongConsumer & Serializable) a::addAndGet;
assertSerial(lc, plc -> { plc.accept(3); });
}
+
+ // Tests of direct use of metafactories
+
+ private static boolean foo(Object s) { return s != null && ((String) s).length() > 0; }
+ private static final MethodType predicateMT = MethodType.methodType(boolean.class, Object.class);
+ private static final MethodType stringPredicateMT = MethodType.methodType(boolean.class, String.class);
+ private static final Consumer<Predicate<String>> fooAsserter = x -> {
+ assertTrue(x.test("foo"));
+ assertFalse(x.test(""));
+ assertFalse(x.test(null));
+ };
+
+ // standard MF: nonserializable supertype
+ public void testDirectStdNonser() throws Throwable {
+ MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+ // Standard metafactory, non-serializable target: not serializable
+ CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
+ "test", MethodType.methodType(Predicate.class),
+ predicateMT, fooMH, stringPredicateMT);
+ Predicate<String> p = (Predicate<String>) cs.getTarget().invokeExact();
+ assertNotSerial(p, fooAsserter);
+ }
+
+ // standard MF: serializable supertype
+ public void testDirectStdSer() throws Throwable {
+ MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+ // Standard metafactory, serializable target: not serializable
+ CallSite cs = LambdaMetafactory.metafactory(MethodHandles.lookup(),
+ "test", MethodType.methodType(SerPredicate.class),
+ predicateMT, fooMH, stringPredicateMT);
+ assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+ }
+
+ // alt MF: nonserializable supertype
+ public void testAltStdNonser() throws Throwable {
+ MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+ // Alt metafactory, non-serializable target: not serializable
+ CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+ "test", MethodType.methodType(Predicate.class),
+ predicateMT, fooMH, stringPredicateMT, 0);
+ assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+ }
+
+ // alt MF: serializable supertype
+ public void testAltStdSer() throws Throwable {
+ MethodHandle fooMH = MethodHandles.lookup().findStatic(SerializedLambdaTest.class, "foo", predicateMT);
+
+ // Alt metafactory, serializable target, no FLAG_SERIALIZABLE: not serializable
+ CallSite cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+ "test", MethodType.methodType(SerPredicate.class),
+ predicateMT, fooMH, stringPredicateMT, 0);
+ assertNotSerial((SerPredicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+
+ // Alt metafactory, serializable marker, no FLAG_SERIALIZABLE: not serializable
+ cs = LambdaMetafactory.altMetafactory(MethodHandles.lookup(),
+ "test", MethodType.methodType(Predicate.class),
+ predicateMT, fooMH, stringPredicateMT, LambdaMetafactory.FLAG_MARKERS, 1, Serializable.class);
+ assertNotSerial((Predicate<String>) cs.getTarget().invokeExact(), fooAsserter);
+ }
}