8172299: Improve class processing
authorrriggs
Fri, 03 Feb 2017 14:10:33 -0500
changeset 44756 b887cb49e622
parent 44755 e16b506a60b2
child 44757 c5f03e77cd67
8172299: Improve class processing Reviewed-by: coffeys, chegar, ahgross, skoivu, rhalade
jdk/src/java.base/share/classes/java/io/ObjectInputStream.java
jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java
--- a/jdk/src/java.base/share/classes/java/io/ObjectInputStream.java	Tue Dec 20 18:02:26 2016 +0000
+++ b/jdk/src/java.base/share/classes/java/io/ObjectInputStream.java	Fri Feb 03 14:10:33 2017 -0500
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -1797,12 +1797,19 @@
         } catch (ClassNotFoundException ex) {
             resolveEx = ex;
         }
+
+        // Call filterCheck on the class before reading anything else
+        filterCheck(cl, -1);
+
         skipCustomData();
 
-        desc.initProxy(cl, resolveEx, readClassDesc(false));
-
-        // Call filterCheck on the definition
-        filterCheck(desc.forClass(), -1);
+        try {
+            totalObjectRefs++;
+            depth++;
+            desc.initProxy(cl, resolveEx, readClassDesc(false));
+        } finally {
+            depth--;
+        }
 
         handles.finish(descHandle);
         passHandle = descHandle;
@@ -1847,12 +1854,19 @@
         } catch (ClassNotFoundException ex) {
             resolveEx = ex;
         }
+
+        // Call filterCheck on the class before reading anything else
+        filterCheck(cl, -1);
+
         skipCustomData();
 
-        desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false));
-
-        // Call filterCheck on the definition
-        filterCheck(desc.forClass(), -1);
+        try {
+            totalObjectRefs++;
+            depth++;
+            desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false));
+        } finally {
+            depth--;
+        }
 
         handles.finish(descHandle);
         passHandle = descHandle;
--- a/jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java	Tue Dec 20 18:02:26 2016 +0000
+++ b/jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java	Fri Feb 03 14:10:33 2017 -0500
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -26,18 +26,19 @@
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.InvalidClassException;
+import java.io.ObjectInputFilter;
 import java.io.ObjectInputStream;
-import java.io.ObjectInputFilter;
 import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.lang.invoke.SerializedLambda;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Proxy;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Hashtable;
-import java.util.Set;
+import java.util.List;
 import java.util.concurrent.atomic.LongAdder;
 
 import javax.net.ssl.SSLEngineResult;
@@ -166,26 +167,33 @@
         Runnable runnable = (Runnable & Serializable) SerialFilterTest::noop;
         Object[][] objects = {
                 { null, 0, -1, 0, 0, 0,
-                        new HashSet<>()},        // no callback, no values
-                { objArray, 3, 7, 8, 2, 55,
-                        new HashSet<>(Arrays.asList(objArray.getClass()))},
-                { Object[].class, 1, -1, 1, 1, 40,
-                        new HashSet<>(Arrays.asList(Object[].class))},
-                { new SerialFilterTest(), 1, -1, 1, 1, 37,
-                        new HashSet<>(Arrays.asList(SerialFilterTest.class))},
-                { new LongAdder(), 2, -1, 1, 1, 93,
-                        new HashSet<>(Arrays.asList(LongAdder.class, serClass))},
-                { new byte[14], 2, 14, 1, 1, 27,
-                        new HashSet<>(Arrays.asList(byteArray.getClass()))},
-                { runnable, 13, 0, 10, 2, 514,
-                        new HashSet<>(Arrays.asList(java.lang.invoke.SerializedLambda.class,
+                        Arrays.asList()},        // no callback, no values
+                { objArray, 3, 7, 9, 2, 55,
+                        Arrays.asList(objArray.getClass(), objArray.getClass())},
+                { Object[].class, 1, -1, 1, 1, 38,
+                        Arrays.asList(Object[].class)},
+                { new SerialFilterTest(), 1, -1, 1, 1, 35,
+                        Arrays.asList(SerialFilterTest.class)},
+                { new LongAdder(), 2, -1, 2, 1, 93,
+                        Arrays.asList(serClass, LongAdder.class)},
+                { new byte[14], 2, 14, 2, 1, 27,
+                        Arrays.asList(byteArray.getClass(), byteArray.getClass())},
+                { runnable, 13, 0, 13, 2, 514,
+                        Arrays.asList(java.lang.invoke.SerializedLambda.class,
+                                objArray.getClass(),
+                                objArray.getClass(),
                                 SerialFilterTest.class,
-                                objArray.getClass()))},
-                { deepHashSet(10), 48, -1, 49, 11, 619,
-                        new HashSet<>(Arrays.asList(HashSet.class))},
-                { proxy.getClass(), 3, -1, 1, 1, 114,
-                        new HashSet<>(Arrays.asList(Runnable.class,
-                                java.lang.reflect.Proxy.class))},
+                                java.lang.invoke.SerializedLambda.class)},
+                { deepHashSet(10), 48, -1, 50, 11, 619,
+                        Arrays.asList(HashSet.class)},
+                { proxy.getClass(), 3, -1, 2, 2, 112,
+                        Arrays.asList(Runnable.class,
+                                java.lang.reflect.Proxy.class,
+                                java.lang.reflect.Proxy.class)},
+                { new F(), 6, -1, 6, 6, 202,
+                        Arrays.asList(F.class, E.class, D.class,
+                                C.class, B.class, A.class)},
+
         };
         return objects;
     }
@@ -224,11 +232,12 @@
     @Test(dataProvider="Objects")
     public static void t1(Object object,
                           long count, long maxArray, long maxRefs, long maxDepth, long maxBytes,
-                          Set<Class<?>> classes) throws IOException {
+                          List<Class<?>> classes) throws IOException {
         byte[] bytes = writeObjects(object);
         Validator validator = new Validator();
         validate(bytes, validator);
         System.out.printf("v: %s%n", validator);
+
         Assert.assertEquals(validator.count, count, "callback count wrong");
         Assert.assertEquals(validator.classes, classes, "classes mismatch");
         Assert.assertEquals(validator.maxArray, maxArray, "maxArray mismatch");
@@ -438,7 +447,7 @@
      */
     static class Validator implements ObjectInputFilter {
         long count;          // Count of calls to checkInput
-        HashSet<Class<?>> classes = new HashSet<>();
+        List<Class<?>> classes = new ArrayList<>();
         long maxArray = -1;
         long maxRefs;
         long maxDepth;
@@ -449,16 +458,20 @@
 
         @Override
         public ObjectInputFilter.Status checkInput(FilterInfo filter) {
+            Class<?> serialClass = filter.serialClass();
+            System.out.printf("     checkInput: class: %s, arrayLen: %d, refs: %d, depth: %d, bytes; %d%n",
+                    serialClass, filter.arrayLength(), filter.references(),
+                    filter.depth(), filter.streamBytes());
             count++;
-            if (filter.serialClass() != null) {
-                if (filter.serialClass().getName().contains("$$Lambda$")) {
+            if (serialClass != null) {
+                if (serialClass.getName().contains("$$Lambda$")) {
                     // TBD: proper identification of serialized Lambdas?
                     // Fold the serialized Lambda into the SerializedLambda type
                     classes.add(SerializedLambda.class);
-                } else if (Proxy.isProxyClass(filter.serialClass())) {
+                } else if (Proxy.isProxyClass(serialClass)) {
                     classes.add(Proxy.class);
                 } else {
-                    classes.add(filter.serialClass());
+                    classes.add(serialClass);
                 }
 
             }
@@ -626,7 +639,8 @@
             // a stream of exactly the size requested.
             return genMaxBytesObject(allowed, value);
         } else if (pattern.startsWith("maxrefs=")) {
-            Object[] array = new Object[allowed ? (int)value - 1 : (int)value];
+            // 4 references to classes in addition to the array contents
+            Object[] array = new Object[allowed ? (int)value - 4 : (int)value - 3];
             for (int i = 0; i < array.length; i++) {
                 array[i] = otherObject;
             }
@@ -775,4 +789,25 @@
             return streamBytes;
         }
     }
+
+    // Deeper superclass hierarchy
+    static class A implements Serializable {
+        private static final long serialVersionUID = 1L;
+    };
+    static class B extends A {
+        private static final long serialVersionUID = 2L;
+    }
+    static class C extends B {
+        private static final long serialVersionUID = 3L;
+    }
+    static class D extends C {
+        private static final long serialVersionUID = 4L;
+    }
+    static class E extends D {
+        private static final long serialVersionUID = 5L;
+    }
+    static class F extends E {
+        private static final long serialVersionUID = 6L;
+    }
+
 }