8165492: Reduce number of lambda forms generated by MethodHandleInlineCopyStrategy
authorredestad
Mon, 12 Sep 2016 13:23:07 +0200
changeset 40810 b88d5910ea1e
parent 40809 7d4acf358119
child 40811 1fb2b94fa1d0
8165492: Reduce number of lambda forms generated by MethodHandleInlineCopyStrategy Reviewed-by: mhaupt, vlivanov, psandoz, shade
jdk/src/java.base/share/classes/java/lang/StringConcatHelper.java
jdk/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java
jdk/src/java.base/share/classes/java/lang/invoke/MethodHandles.java
jdk/src/java.base/share/classes/java/lang/invoke/StringConcatFactory.java
--- a/jdk/src/java.base/share/classes/java/lang/StringConcatHelper.java	Sun Sep 11 13:23:14 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/lang/StringConcatHelper.java	Mon Sep 12 13:23:07 2016 +0200
@@ -334,11 +334,15 @@
     /**
      * Instantiates the String with given buffer and coder
      * @param buf     buffer to use
+     * @param index   remaining index
      * @param coder   coder to use
      * @return String resulting string
      */
-    static String newString(byte[] buf, byte coder) {
+    static String newString(byte[] buf, int index, byte coder) {
         // Use the private, non-copying constructor (unsafe!)
+        if (index != 0) {
+            throw new InternalError("Storage is not completely initialized, " + index + " bytes left");
+        }
         return new String(buf, coder);
     }
 
--- a/jdk/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java	Sun Sep 11 13:23:14 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java	Mon Sep 12 13:23:07 2016 +0200
@@ -83,7 +83,9 @@
                 FOLD_ARGS = 11,
                 FOLD_ARGS_TO_VOID = 12,
                 PERMUTE_ARGS = 13,
-                LOCAL_TYPES = 14;
+                LOCAL_TYPES = 14,
+                FOLD_SELECT_ARGS = 15,
+                FOLD_SELECT_ARGS_TO_VOID = 16;
 
         private static final boolean STRESS_TEST = false; // turn on to disable most packing
         private static final int
@@ -695,6 +697,72 @@
         return buf.endEdit();
     }
 
+    private LambdaForm makeArgumentCombinationForm(int pos,
+                                                   MethodType combinerType,
+                                                   int[] argPositions,
+                                                   boolean keepArguments,
+                                                   boolean dropResult) {
+        LambdaFormBuffer buf = buffer();
+        buf.startEdit();
+        int combinerArity = combinerType.parameterCount();
+        assert(combinerArity == argPositions.length);
+
+        int resultArity = (dropResult ? 0 : 1);
+
+        assert(pos <= lambdaForm.arity);
+        assert(pos > 0);  // cannot filter the MH arg itself
+        assert(combinerType == combinerType.basicType());
+        assert(combinerType.returnType() != void.class || dropResult);
+
+        BoundMethodHandle.SpeciesData oldData = oldSpeciesData();
+        BoundMethodHandle.SpeciesData newData = newSpeciesData(L_TYPE);
+
+        // The newly created LF will run with a different BMH.
+        // Switch over any pre-existing BMH field references to the new BMH class.
+        Name oldBaseAddress = lambdaForm.parameter(0);  // BMH holding the values
+        buf.replaceFunctions(oldData.getterFunctions(), newData.getterFunctions(), oldBaseAddress);
+        Name newBaseAddress = oldBaseAddress.withConstraint(newData);
+        buf.renameParameter(0, newBaseAddress);
+
+        Name getCombiner = new Name(newData.getterFunction(oldData.fieldCount()), newBaseAddress);
+        Object[] combinerArgs = new Object[1 + combinerArity];
+        combinerArgs[0] = getCombiner;
+        Name[] newParams;
+        if (keepArguments) {
+            newParams = new Name[0];
+            for (int i = 0; i < combinerArity; i++) {
+                combinerArgs[i + 1] = lambdaForm.parameter(1 + argPositions[i]);
+                assert (basicType(combinerType.parameterType(i)) == lambdaForm.parameterType(1 + argPositions[i]));
+            }
+        } else {
+            newParams = new Name[combinerArity];
+            for (int i = 0; i < newParams.length; i++) {
+                newParams[i] = lambdaForm.parameter(1 + argPositions[i]);
+                assert (basicType(combinerType.parameterType(i)) == lambdaForm.parameterType(1 + argPositions[i]));
+            }
+            System.arraycopy(newParams, 0,
+                             combinerArgs, 1, combinerArity);
+        }
+        Name callCombiner = new Name(combinerType, combinerArgs);
+
+        // insert the two new expressions
+        int exprPos = lambdaForm.arity();
+        buf.insertExpression(exprPos+0, getCombiner);
+        buf.insertExpression(exprPos+1, callCombiner);
+
+        // insert new arguments, if needed
+        int argPos = pos + resultArity;  // skip result parameter
+        for (Name newParam : newParams) {
+            buf.insertParameter(argPos++, newParam);
+        }
+        assert(buf.lastIndexOf(callCombiner) == exprPos+1+newParams.length);
+        if (!dropResult) {
+            buf.replaceParameterByCopy(pos, exprPos+1+newParams.length);
+        }
+
+        return buf.endEdit();
+    }
+
     LambdaForm filterReturnForm(BasicType newType, boolean constantZero) {
         byte kind = (constantZero ? Transform.FILTER_RETURN_TO_ZERO : Transform.FILTER_RETURN);
         Transform key = Transform.of(kind, newType.ordinal());
@@ -759,6 +827,21 @@
         return putInCache(key, form);
     }
 
+    LambdaForm foldArgumentsForm(int foldPos, boolean dropResult, MethodType combinerType, int ... argPositions) {
+        byte kind = (dropResult ? Transform.FOLD_SELECT_ARGS_TO_VOID
+                                : Transform.FOLD_SELECT_ARGS);
+        int[] keyArgs = Arrays.copyOf(argPositions, argPositions.length + 1);
+        keyArgs[argPositions.length] = foldPos;
+        Transform key = Transform.of(kind, keyArgs);
+        LambdaForm form = getInCache(key);
+        if (form != null) {
+            assert(form.arity == lambdaForm.arity - (kind == Transform.FOLD_SELECT_ARGS ? 1 : 0));
+            return form;
+        }
+        form = makeArgumentCombinationForm(foldPos, combinerType, argPositions, true, dropResult);
+        return putInCache(key, form);
+    }
+
     LambdaForm permuteArgumentsForm(int skip, int[] reorder) {
         assert(skip == 1);  // skip only the leading MH argument, names[0]
         int length = lambdaForm.names.length;
--- a/jdk/src/java.base/share/classes/java/lang/invoke/MethodHandles.java	Sun Sep 11 13:23:14 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/lang/invoke/MethodHandles.java	Mon Sep 12 13:23:07 2016 +0200
@@ -3943,6 +3943,33 @@
         return rtype;
     }
 
+    private static Class<?> foldArgumentChecks(int foldPos, MethodType targetType, MethodType combinerType, int ... argPos) {
+        int foldArgs = combinerType.parameterCount();
+        if (argPos.length != foldArgs) {
+            throw newIllegalArgumentException("combiner and argument map must be equal size", combinerType, argPos.length);
+        }
+        Class<?> rtype = combinerType.returnType();
+        int foldVals = rtype == void.class ? 0 : 1;
+        boolean ok = true;
+        for (int i = 0; i < foldArgs; i++) {
+            int arg = argPos[i];
+            if (arg < 0 || arg > targetType.parameterCount()) {
+                throw newIllegalArgumentException("arg outside of target parameterRange", targetType, arg);
+            }
+            if (combinerType.parameterType(i) != targetType.parameterType(arg)) {
+                throw newIllegalArgumentException("target argument type at position " + arg
+                        + " must match combiner argument type at index " + i + ": " + targetType
+                        + " -> " + combinerType + ", map: " + Arrays.toString(argPos));
+            }
+        }
+        if (ok && foldVals != 0 && combinerType.returnType() != targetType.parameterType(foldPos)) {
+            ok = false;
+        }
+        if (!ok)
+            throw misMatchedTypes("target and combiner types", targetType, combinerType);
+        return rtype;
+    }
+
     /**
      * Makes a method handle which adapts a target method handle,
      * by guarding it with a test, a boolean-valued method handle.
@@ -4949,6 +4976,27 @@
         return result;
     }
 
+    /**
+     * As {@see foldArguments(MethodHandle, int, MethodHandle)}, but with the
+     * added capability of selecting the arguments from the targets parameters
+     * to call the combiner with. This allows us to avoid some simple cases of
+     * permutations and padding the combiner with dropArguments to select the
+     * right argument, which may ultimately produce fewer intermediaries.
+     */
+    static MethodHandle foldArguments(MethodHandle target, int pos, MethodHandle combiner, int ... argPositions) {
+        MethodType targetType = target.type();
+        MethodType combinerType = combiner.type();
+        Class<?> rtype = foldArgumentChecks(pos, targetType, combinerType, argPositions);
+        BoundMethodHandle result = target.rebind();
+        boolean dropResult = rtype == void.class;
+        LambdaForm lform = result.editor().foldArgumentsForm(1 + pos, dropResult, combinerType.basicType(), argPositions);
+        MethodType newType = targetType;
+        if (!dropResult) {
+            newType = newType.dropParameterTypes(pos, pos + 1);
+        }
+        result = result.copyWithExtendL(newType, lform, combiner);
+        return result;
+    }
 
     private static void checkLoop0(MethodHandle[][] clauses) {
         if (clauses == null || clauses.length == 0) {
--- a/jdk/src/java.base/share/classes/java/lang/invoke/StringConcatFactory.java	Sun Sep 11 13:23:14 2016 -0700
+++ b/jdk/src/java.base/share/classes/java/lang/invoke/StringConcatFactory.java	Mon Sep 12 13:23:07 2016 +0200
@@ -563,9 +563,8 @@
         }
 
         if ((lookup.lookupModes() & MethodHandles.Lookup.PRIVATE) == 0) {
-            throw new StringConcatException(String.format(
-                    "Invalid caller: %s",
-                    lookup.lookupClass().getName()));
+            throw new StringConcatException("Invalid caller: " +
+                    lookup.lookupClass().getName());
         }
 
         int cCount = 0;
@@ -1494,51 +1493,41 @@
             // Drop all remaining parameter types, leave only helper arguments:
             MethodHandle mh;
 
-            mh = MethodHandles.dropArguments(NEW_STRING, 2, ptypes);
-            mh = MethodHandles.dropArguments(mh, 0, int.class);
+            mh = MethodHandles.dropArguments(NEW_STRING, 3, ptypes);
 
-            // Safety: check that remaining index is zero -- that would mean the storage is completely
-            // overwritten, and no leakage of uninitialized data occurred.
-            mh = MethodHandles.filterArgument(mh, 0, CHECK_INDEX);
-
-            // Mix in prependers. This happens when (int, byte[], byte) = (index, storage, coder) is already
+            // Mix in prependers. This happens when (byte[], int, byte) = (storage, index, coder) is already
             // known from the combinators below. We are assembling the string backwards, so "index" is the
             // *ending* index.
             for (RecipeElement el : recipe.getElements()) {
-                MethodHandle prepender;
+                // Do the prepend, and put "new" index at index 1
+                mh = MethodHandles.dropArguments(mh, 2, int.class);
                 switch (el.getTag()) {
-                    case TAG_CONST:
+                    case TAG_CONST: {
                         Object cnst = el.getValue();
-                        prepender = MethodHandles.insertArguments(prepender(cnst.getClass()), 3, cnst);
+                        MethodHandle prepender = MethodHandles.insertArguments(prepender(cnst.getClass()), 3, cnst);
+                        mh = MethodHandles.foldArguments(mh, 1, prepender,
+                                2, 0, 3 // index, storage, coder
+                        );
                         break;
-                    case TAG_ARG:
+                    }
+                    case TAG_ARG: {
                         int pos = el.getArgPos();
-                        prepender = selectArgument(prepender(ptypes[pos]), 3, ptypes, pos);
+                        MethodHandle prepender = prepender(ptypes[pos]);
+                        mh = MethodHandles.foldArguments(mh, 1, prepender,
+                                2, 0, 3, // index, storage, coder
+                                4 + pos  // selected argument
+                        );
                         break;
+                    }
                     default:
                         throw new StringConcatException("Unhandled tag: " + el.getTag());
                 }
-
-                // Remove "old" index from arguments
-                mh = MethodHandles.dropArguments(mh, 1, int.class);
-
-                // Do the prepend, and put "new" index at index 0
-                mh = MethodHandles.foldArguments(mh, prepender);
             }
 
-            // Prepare the argument list for prepending. The tree below would instantiate
-            // the storage byte[] into argument 0, so we need to swap "storage" and "index".
-            // The index at this point equals to "size", and resides at argument 1.
-            {
-                MethodType nmt = mh.type()
-                        .changeParameterType(0, byte[].class)
-                        .changeParameterType(1, int.class);
-                mh = MethodHandles.permuteArguments(mh, nmt, swap10(nmt.parameterCount()));
-            }
-
-            // Fold in byte[] instantiation at argument 0.
-            MethodHandle combiner = MethodHandles.dropArguments(NEW_ARRAY, 2, ptypes);
-            mh = MethodHandles.foldArguments(mh, combiner);
+            // Fold in byte[] instantiation at argument 0
+            mh = MethodHandles.foldArguments(mh, 0, NEW_ARRAY,
+                    1, 2 // index, coder
+            );
 
             // Start combining length and coder mixers.
             //
@@ -1567,12 +1556,8 @@
                         int ac = el.getArgPos();
 
                         Class<?> argClass = ptypes[ac];
-                        MethodHandle lm = selectArgument(lengthMixer(argClass), 1, ptypes, ac);
-                        lm = MethodHandles.dropArguments(lm, 0, byte.class); // (*)
-                        lm = MethodHandles.dropArguments(lm, 2, byte.class);
-
-                        MethodHandle cm = selectArgument(coderMixer(argClass),  1, ptypes, ac);
-                        cm = MethodHandles.dropArguments(cm, 0, int.class);  // (**)
+                        MethodHandle lm = lengthMixer(argClass);
+                        MethodHandle cm = coderMixer(argClass);
 
                         // Read this bottom up:
 
@@ -1580,12 +1565,18 @@
                         mh = MethodHandles.dropArguments(mh, 2, int.class, byte.class);
 
                         // 3. Compute "new-index", producing ("new-index", "new-coder", "old-index", "old-coder", <args>)
-                        //    Length mixer ignores both "new-coder" and "old-coder" due to dropArguments above (*)
-                        mh = MethodHandles.foldArguments(mh, lm);
+                        //    Length mixer needs old index, plus the appropriate argument
+                        mh = MethodHandles.foldArguments(mh, 0, lm,
+                                2, // old-index
+                                4 + ac // selected argument
+                        );
 
                         // 2. Compute "new-coder", producing ("new-coder", "old-index", "old-coder", <args>)
-                        //    Coder mixer ignores the "old-index" arg due to dropArguments above (**)
-                        mh = MethodHandles.foldArguments(mh, cm);
+                        //    Coder mixer needs old coder, plus the appropriate argument.
+                        mh = MethodHandles.foldArguments(mh, 0, cm,
+                                2, // old-coder
+                                3 + ac // selected argument
+                        );
 
                         // 1. The mh shape here is ("old-index", "old-coder", <args>)
                         break;
@@ -1606,41 +1597,11 @@
             return mh;
         }
 
-        private static int[] swap10(int count) {
-            int[] perm = new int[count];
-            perm[0] = 1;
-            perm[1] = 0;
-            for (int i = 2; i < count; i++) {
-                perm[i] = i;
-            }
-            return perm;
-        }
-
-        // Adapts: (...prefix..., parameter[pos])R -> (...prefix..., ...parameters...)R
-        private static MethodHandle selectArgument(MethodHandle mh, int prefix, Class<?>[] ptypes, int pos) {
-            if (pos == 0) {
-                return MethodHandles.dropArguments(mh, prefix + 1, Arrays.copyOfRange(ptypes, 1, ptypes.length));
-            } else if (pos == ptypes.length - 1) {
-                return MethodHandles.dropArguments(mh, prefix, Arrays.copyOf(ptypes, ptypes.length - 1));
-            } else { // 0 < pos < ptypes.size() - 1
-                MethodHandle t = MethodHandles.dropArguments(mh, prefix, Arrays.copyOf(ptypes, pos));
-                return MethodHandles.dropArguments(t, prefix + 1 + pos, Arrays.copyOfRange(ptypes, pos + 1, ptypes.length));
-            }
-        }
-
         @ForceInline
         private static byte[] newArray(int length, byte coder) {
             return (byte[]) UNSAFE.allocateUninitializedArray(byte.class, length << coder);
         }
 
-        @ForceInline
-        private static int checkIndex(int index) {
-            if (index != 0) {
-                throw new IllegalStateException("Storage is not completely initialized, " + index + " bytes left");
-            }
-            return index;
-        }
-
         private static MethodHandle prepender(Class<?> cl) {
             return PREPENDERS.computeIfAbsent(cl, PREPEND);
         }
@@ -1678,7 +1639,6 @@
         };
 
         private static final MethodHandle NEW_STRING;
-        private static final MethodHandle CHECK_INDEX;
         private static final MethodHandle NEW_ARRAY;
         private static final ConcurrentMap<Class<?>, MethodHandle> PREPENDERS;
         private static final ConcurrentMap<Class<?>, MethodHandle> LENGTH_MIXERS;
@@ -1699,9 +1659,8 @@
             LENGTH_MIXERS = new ConcurrentHashMap<>();
             CODER_MIXERS = new ConcurrentHashMap<>();
 
-            NEW_STRING = lookupStatic(Lookup.IMPL_LOOKUP, STRING_HELPER, "newString", String.class, byte[].class, byte.class);
+            NEW_STRING = lookupStatic(Lookup.IMPL_LOOKUP, STRING_HELPER, "newString", String.class, byte[].class, int.class, byte.class);
             NEW_ARRAY  = lookupStatic(Lookup.IMPL_LOOKUP, MethodHandleInlineCopyStrategy.class, "newArray", byte[].class, int.class, byte.class);
-            CHECK_INDEX = lookupStatic(Lookup.IMPL_LOOKUP, MethodHandleInlineCopyStrategy.class, "checkIndex", int.class, int.class);
         }
     }