# HG changeset patch # User rriggs # Date 1486149033 18000 # Node ID b887cb49e622374b7697653359c4b99f7cd08a6b # Parent e16b506a60b2d1e356fdbe8466b36fdbe8686e8f 8172299: Improve class processing Reviewed-by: coffeys, chegar, ahgross, skoivu, rhalade diff -r e16b506a60b2 -r b887cb49e622 jdk/src/java.base/share/classes/java/io/ObjectInputStream.java --- a/jdk/src/java.base/share/classes/java/io/ObjectInputStream.java Tue Dec 20 18:02:26 2016 +0000 +++ b/jdk/src/java.base/share/classes/java/io/ObjectInputStream.java Fri Feb 03 14:10:33 2017 -0500 @@ -1,5 +1,5 @@ /* - * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -1797,12 +1797,19 @@ } catch (ClassNotFoundException ex) { resolveEx = ex; } + + // Call filterCheck on the class before reading anything else + filterCheck(cl, -1); + skipCustomData(); - desc.initProxy(cl, resolveEx, readClassDesc(false)); - - // Call filterCheck on the definition - filterCheck(desc.forClass(), -1); + try { + totalObjectRefs++; + depth++; + desc.initProxy(cl, resolveEx, readClassDesc(false)); + } finally { + depth--; + } handles.finish(descHandle); passHandle = descHandle; @@ -1847,12 +1854,19 @@ } catch (ClassNotFoundException ex) { resolveEx = ex; } + + // Call filterCheck on the class before reading anything else + filterCheck(cl, -1); + skipCustomData(); - desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false)); - - // Call filterCheck on the definition - filterCheck(desc.forClass(), -1); + try { + totalObjectRefs++; + depth++; + desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false)); + } finally { + depth--; + } handles.finish(descHandle); passHandle = descHandle; diff -r e16b506a60b2 -r b887cb49e622 jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java --- a/jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java Tue Dec 20 18:02:26 2016 +0000 +++ b/jdk/test/java/io/Serializable/serialFilter/SerialFilterTest.java Fri Feb 03 14:10:33 2017 -0500 @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,18 +26,19 @@ import java.io.EOFException; import java.io.IOException; import java.io.InvalidClassException; +import java.io.ObjectInputFilter; import java.io.ObjectInputStream; -import java.io.ObjectInputFilter; import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.invoke.SerializedLambda; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Proxy; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.Hashtable; -import java.util.Set; +import java.util.List; import java.util.concurrent.atomic.LongAdder; import javax.net.ssl.SSLEngineResult; @@ -166,26 +167,33 @@ Runnable runnable = (Runnable & Serializable) SerialFilterTest::noop; Object[][] objects = { { null, 0, -1, 0, 0, 0, - new HashSet<>()}, // no callback, no values - { objArray, 3, 7, 8, 2, 55, - new HashSet<>(Arrays.asList(objArray.getClass()))}, - { Object[].class, 1, -1, 1, 1, 40, - new HashSet<>(Arrays.asList(Object[].class))}, - { new SerialFilterTest(), 1, -1, 1, 1, 37, - new HashSet<>(Arrays.asList(SerialFilterTest.class))}, - { new LongAdder(), 2, -1, 1, 1, 93, - new HashSet<>(Arrays.asList(LongAdder.class, serClass))}, - { new byte[14], 2, 14, 1, 1, 27, - new HashSet<>(Arrays.asList(byteArray.getClass()))}, - { runnable, 13, 0, 10, 2, 514, - new HashSet<>(Arrays.asList(java.lang.invoke.SerializedLambda.class, + Arrays.asList()}, // no callback, no values + { objArray, 3, 7, 9, 2, 55, + Arrays.asList(objArray.getClass(), objArray.getClass())}, + { Object[].class, 1, -1, 1, 1, 38, + Arrays.asList(Object[].class)}, + { new SerialFilterTest(), 1, -1, 1, 1, 35, + Arrays.asList(SerialFilterTest.class)}, + { new LongAdder(), 2, -1, 2, 1, 93, + Arrays.asList(serClass, LongAdder.class)}, + { new byte[14], 2, 14, 2, 1, 27, + Arrays.asList(byteArray.getClass(), byteArray.getClass())}, + { runnable, 13, 0, 13, 2, 514, + Arrays.asList(java.lang.invoke.SerializedLambda.class, + objArray.getClass(), + objArray.getClass(), SerialFilterTest.class, - objArray.getClass()))}, - { deepHashSet(10), 48, -1, 49, 11, 619, - new HashSet<>(Arrays.asList(HashSet.class))}, - { proxy.getClass(), 3, -1, 1, 1, 114, - new HashSet<>(Arrays.asList(Runnable.class, - java.lang.reflect.Proxy.class))}, + java.lang.invoke.SerializedLambda.class)}, + { deepHashSet(10), 48, -1, 50, 11, 619, + Arrays.asList(HashSet.class)}, + { proxy.getClass(), 3, -1, 2, 2, 112, + Arrays.asList(Runnable.class, + java.lang.reflect.Proxy.class, + java.lang.reflect.Proxy.class)}, + { new F(), 6, -1, 6, 6, 202, + Arrays.asList(F.class, E.class, D.class, + C.class, B.class, A.class)}, + }; return objects; } @@ -224,11 +232,12 @@ @Test(dataProvider="Objects") public static void t1(Object object, long count, long maxArray, long maxRefs, long maxDepth, long maxBytes, - Set> classes) throws IOException { + List> classes) throws IOException { byte[] bytes = writeObjects(object); Validator validator = new Validator(); validate(bytes, validator); System.out.printf("v: %s%n", validator); + Assert.assertEquals(validator.count, count, "callback count wrong"); Assert.assertEquals(validator.classes, classes, "classes mismatch"); Assert.assertEquals(validator.maxArray, maxArray, "maxArray mismatch"); @@ -438,7 +447,7 @@ */ static class Validator implements ObjectInputFilter { long count; // Count of calls to checkInput - HashSet> classes = new HashSet<>(); + List> classes = new ArrayList<>(); long maxArray = -1; long maxRefs; long maxDepth; @@ -449,16 +458,20 @@ @Override public ObjectInputFilter.Status checkInput(FilterInfo filter) { + Class serialClass = filter.serialClass(); + System.out.printf(" checkInput: class: %s, arrayLen: %d, refs: %d, depth: %d, bytes; %d%n", + serialClass, filter.arrayLength(), filter.references(), + filter.depth(), filter.streamBytes()); count++; - if (filter.serialClass() != null) { - if (filter.serialClass().getName().contains("$$Lambda$")) { + if (serialClass != null) { + if (serialClass.getName().contains("$$Lambda$")) { // TBD: proper identification of serialized Lambdas? // Fold the serialized Lambda into the SerializedLambda type classes.add(SerializedLambda.class); - } else if (Proxy.isProxyClass(filter.serialClass())) { + } else if (Proxy.isProxyClass(serialClass)) { classes.add(Proxy.class); } else { - classes.add(filter.serialClass()); + classes.add(serialClass); } } @@ -626,7 +639,8 @@ // a stream of exactly the size requested. return genMaxBytesObject(allowed, value); } else if (pattern.startsWith("maxrefs=")) { - Object[] array = new Object[allowed ? (int)value - 1 : (int)value]; + // 4 references to classes in addition to the array contents + Object[] array = new Object[allowed ? (int)value - 4 : (int)value - 3]; for (int i = 0; i < array.length; i++) { array[i] = otherObject; } @@ -775,4 +789,25 @@ return streamBytes; } } + + // Deeper superclass hierarchy + static class A implements Serializable { + private static final long serialVersionUID = 1L; + }; + static class B extends A { + private static final long serialVersionUID = 2L; + } + static class C extends B { + private static final long serialVersionUID = 3L; + } + static class D extends C { + private static final long serialVersionUID = 4L; + } + static class E extends D { + private static final long serialVersionUID = 5L; + } + static class F extends E { + private static final long serialVersionUID = 6L; + } + }