8213478: Reduce rebinds when applying repeated filters and conversions
authorredestad
Sun, 11 Nov 2018 21:24:46 +0100
changeset 52486 6f5948597697
parent 52485 e5534cc91a10
child 52487 5d1d07b72f15
child 52551 339963bcff24
8213478: Reduce rebinds when applying repeated filters and conversions Reviewed-by: vlivanov, jrose
src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java
src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java
src/java.base/share/classes/java/lang/invoke/MethodHandles.java
--- a/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java	Sat Nov 10 20:47:28 2018 +0100
+++ b/src/java.base/share/classes/java/lang/invoke/LambdaFormEditor.java	Sun Nov 11 21:24:46 2018 +0100
@@ -30,6 +30,8 @@
 import java.lang.ref.SoftReference;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
+import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
 
 import static java.lang.invoke.LambdaForm.*;
@@ -86,7 +88,8 @@
                 LOCAL_TYPES = 14,
                 FOLD_SELECT_ARGS = 15,
                 FOLD_SELECT_ARGS_TO_VOID = 16,
-                FILTER_SELECT_ARGS = 17;
+                FILTER_SELECT_ARGS = 17,
+                REPEAT_FILTER_ARGS = 18;
 
         private static final boolean STRESS_TEST = false; // turn on to disable most packing
         private static final int
@@ -641,6 +644,104 @@
         return putInCache(key, form);
     }
 
+    /**
+     * This creates a LF that will repeatedly invoke some unary filter function
+     * at each of the given positions. This allows fewer LFs and BMH species
+     * classes to be generated in typical cases compared to building up the form
+     * by reapplying of {@code filterArgumentForm(int,BasicType)}, and should do
+     * no worse in the worst case.
+     */
+    LambdaForm filterRepeatedArgumentForm(BasicType newType, int... argPositions) {
+        assert (argPositions.length > 1);
+        byte[] keyArgs = new byte[argPositions.length + 2];
+        keyArgs[0] = Transform.REPEAT_FILTER_ARGS;
+        keyArgs[argPositions.length + 1] = (byte)newType.ordinal();
+        for (int i = 0; i < argPositions.length; i++) {
+            keyArgs[i + 1] = (byte)argPositions[i];
+        }
+        Transform key = new Transform(keyArgs);
+        LambdaForm form = getInCache(key);
+        if (form != null) {
+            assert(form.arity == lambdaForm.arity &&
+                    formParametersMatch(form, newType, argPositions));
+            return form;
+        }
+        BasicType oldType = lambdaForm.parameterType(argPositions[0]);
+        MethodType filterType = MethodType.methodType(oldType.basicTypeClass(),
+                newType.basicTypeClass());
+        form = makeRepeatedFilterForm(filterType, argPositions);
+        assert (formParametersMatch(form, newType, argPositions));
+        return putInCache(key, form);
+    }
+
+    private boolean formParametersMatch(LambdaForm form, BasicType newType, int... argPositions) {
+        for (int i : argPositions) {
+            if (form.parameterType(i) != newType) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    private LambdaForm makeRepeatedFilterForm(MethodType combinerType, int... positions) {
+        assert (combinerType.parameterCount() == 1 &&
+                combinerType == combinerType.basicType() &&
+                combinerType.returnType() != void.class);
+        LambdaFormBuffer buf = buffer();
+        buf.startEdit();
+
+        BoundMethodHandle.SpeciesData oldData = oldSpeciesData();
+        BoundMethodHandle.SpeciesData newData = newSpeciesData(L_TYPE);
+
+        // The newly created LF will run with a different BMH.
+        // Switch over any pre-existing BMH field references to the new BMH class.
+        Name oldBaseAddress = lambdaForm.parameter(0);  // BMH holding the values
+        buf.replaceFunctions(oldData.getterFunctions(), newData.getterFunctions(), oldBaseAddress);
+        Name newBaseAddress = oldBaseAddress.withConstraint(newData);
+        buf.renameParameter(0, newBaseAddress);
+
+        // Insert the new expressions at the end
+        int exprPos = lambdaForm.arity();
+        Name getCombiner = new Name(newData.getterFunction(oldData.fieldCount()), newBaseAddress);
+        buf.insertExpression(exprPos++, getCombiner);
+
+        // After inserting expressions, we insert parameters in order
+        // from lowest to highest, simplifying the calculation of where parameters
+        // and expressions are
+        var newParameters = new TreeMap<Name, Integer>(new Comparator<>() {
+            public int compare(Name n1, Name n2) {
+                return n1.index - n2.index;
+            }
+        });
+
+        // Insert combiner expressions in reverse order so that the invocation of
+        // the resulting form will invoke the combiners in left-to-right order
+        for (int i = positions.length - 1; i >= 0; --i) {
+            int pos = positions[i];
+            assert (pos > 0 && pos <= MethodType.MAX_JVM_ARITY && pos < lambdaForm.arity);
+
+            Name newParameter = new Name(pos, basicType(combinerType.parameterType(0)));
+            Object[] combinerArgs = {getCombiner, newParameter};
+
+            Name callCombiner = new Name(combinerType, combinerArgs);
+            buf.insertExpression(exprPos++, callCombiner);
+            newParameters.put(newParameter, exprPos);
+        }
+
+        // Mix in new parameters from left to right in the buffer (this doesn't change
+        // execution order
+        int offset = 0;
+        for (var entry : newParameters.entrySet()) {
+            Name newParameter = entry.getKey();
+            int from = entry.getValue();
+            buf.insertParameter(newParameter.index() + 1 + offset, newParameter);
+            buf.replaceParameterByCopy(newParameter.index() + offset, from + offset);
+            offset++;
+        }
+        return buf.endEdit();
+    }
+
+
     private LambdaForm makeArgumentCombinationForm(int pos,
                                                    MethodType combinerType,
                                                    boolean keepArguments, boolean dropResult) {
--- a/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java	Sat Nov 10 20:47:28 2018 +0100
+++ b/src/java.base/share/classes/java/lang/invoke/MethodHandleImpl.java	Sun Nov 11 21:24:46 2018 +0100
@@ -42,6 +42,7 @@
 import java.lang.reflect.Array;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -257,14 +258,19 @@
 
     private static int countNonNull(Object[] array) {
         int count = 0;
-        for (Object x : array) {
-            if (x != null)  ++count;
+        if (array != null) {
+            for (Object x : array) {
+                if (x != null) ++count;
+            }
         }
         return count;
     }
 
     static MethodHandle makePairwiseConvertByEditor(MethodHandle target, MethodType srcType,
                                                     boolean strict, boolean monobox) {
+        // In method types arguments start at index 0, while the LF
+        // editor have the MH receiver at position 0 - adjust appropriately.
+        final int MH_RECEIVER_OFFSET = 1;
         Object[] convSpecs = computeValueConversions(srcType, target.type(), strict, monobox);
         int convCount = countNonNull(convSpecs);
         if (convCount == 0)
@@ -272,27 +278,52 @@
         MethodType basicSrcType = srcType.basicType();
         MethodType midType = target.type().basicType();
         BoundMethodHandle mh = target.rebind();
-        // FIXME: Reduce number of bindings when there is more than one Class conversion.
-        // FIXME: Reduce number of bindings when there are repeated conversions.
-        for (int i = 0; i < convSpecs.length-1; i++) {
+
+        // Match each unique conversion to the positions at which it is to be applied
+        var convSpecMap = new HashMap<Object, int[]>(((4 * convCount) / 3) + 1);
+        for (int i = 0; i < convSpecs.length - MH_RECEIVER_OFFSET; i++) {
             Object convSpec = convSpecs[i];
-            if (convSpec == null)  continue;
+            if (convSpec == null) continue;
+            int[] positions = convSpecMap.get(convSpec);
+            if (positions == null) {
+                positions = new int[] { i + MH_RECEIVER_OFFSET };
+            } else {
+                positions = Arrays.copyOf(positions, positions.length + 1);
+                positions[positions.length - 1] = i + MH_RECEIVER_OFFSET;
+            }
+            convSpecMap.put(convSpec, positions);
+        }
+        for (var entry : convSpecMap.entrySet()) {
+            Object convSpec = entry.getKey();
+
             MethodHandle fn;
             if (convSpec instanceof Class) {
                 fn = getConstantHandle(MH_cast).bindTo(convSpec);
             } else {
                 fn = (MethodHandle) convSpec;
             }
-            Class<?> newType = basicSrcType.parameterType(i);
-            if (--convCount == 0)
+            int[] positions = entry.getValue();
+            Class<?> newType = basicSrcType.parameterType(positions[0] - MH_RECEIVER_OFFSET);
+            BasicType newBasicType = BasicType.basicType(newType);
+            convCount -= positions.length;
+            if (convCount == 0) {
                 midType = srcType;
-            else
-                midType = midType.changeParameterType(i, newType);
-            LambdaForm form2 = mh.editor().filterArgumentForm(1+i, BasicType.basicType(newType));
+            } else {
+                Class<?>[] ptypes = midType.ptypes().clone();
+                for (int pos : positions) {
+                    ptypes[pos - 1] = newType;
+                }
+                midType = MethodType.makeImpl(midType.rtype(), ptypes, true);
+            }
+            LambdaForm form2;
+            if (positions.length > 1) {
+                form2 = mh.editor().filterRepeatedArgumentForm(newBasicType, positions);
+            } else {
+                form2 = mh.editor().filterArgumentForm(positions[0], newBasicType);
+            }
             mh = mh.copyWithExtendL(midType, form2, fn);
-            mh = mh.rebind();
         }
-        Object convSpec = convSpecs[convSpecs.length-1];
+        Object convSpec = convSpecs[convSpecs.length - 1];
         if (convSpec != null) {
             MethodHandle fn;
             if (convSpec instanceof Class) {
@@ -320,98 +351,18 @@
         return mh;
     }
 
-    static MethodHandle makePairwiseConvertIndirect(MethodHandle target, MethodType srcType,
-                                                    boolean strict, boolean monobox) {
-        assert(target.type().parameterCount() == srcType.parameterCount());
-        // Calculate extra arguments (temporaries) required in the names array.
-        Object[] convSpecs = computeValueConversions(srcType, target.type(), strict, monobox);
-        final int INARG_COUNT = srcType.parameterCount();
-        int convCount = countNonNull(convSpecs);
-        boolean retConv = (convSpecs[INARG_COUNT] != null);
-        boolean retVoid = srcType.returnType() == void.class;
-        if (retConv && retVoid) {
-            convCount -= 1;
-            retConv = false;
-        }
-
-        final int IN_MH         = 0;
-        final int INARG_BASE    = 1;
-        final int INARG_LIMIT   = INARG_BASE + INARG_COUNT;
-        final int NAME_LIMIT    = INARG_LIMIT + convCount + 1;
-        final int RETURN_CONV   = (!retConv ? -1         : NAME_LIMIT - 1);
-        final int OUT_CALL      = (!retConv ? NAME_LIMIT : RETURN_CONV) - 1;
-        final int RESULT        = (retVoid ? -1 : NAME_LIMIT - 1);
-
-        // Now build a LambdaForm.
-        MethodType lambdaType = srcType.basicType().invokerType();
-        Name[] names = arguments(NAME_LIMIT - INARG_LIMIT, lambdaType);
-
-        // Collect the arguments to the outgoing call, maybe with conversions:
-        final int OUTARG_BASE = 0;  // target MH is Name.function, name Name.arguments[0]
-        Object[] outArgs = new Object[OUTARG_BASE + INARG_COUNT];
-
-        int nameCursor = INARG_LIMIT;
-        for (int i = 0; i < INARG_COUNT; i++) {
-            Object convSpec = convSpecs[i];
-            if (convSpec == null) {
-                // do nothing: difference is trivial
-                outArgs[OUTARG_BASE + i] = names[INARG_BASE + i];
-                continue;
-            }
-
-            Name conv;
-            if (convSpec instanceof Class) {
-                Class<?> convClass = (Class<?>) convSpec;
-                conv = new Name(getConstantHandle(MH_cast), convClass, names[INARG_BASE + i]);
-            } else {
-                MethodHandle fn = (MethodHandle) convSpec;
-                conv = new Name(fn, names[INARG_BASE + i]);
-            }
-            assert(names[nameCursor] == null);
-            names[nameCursor++] = conv;
-            assert(outArgs[OUTARG_BASE + i] == null);
-            outArgs[OUTARG_BASE + i] = conv;
-        }
-
-        // Build argument array for the call.
-        assert(nameCursor == OUT_CALL);
-        names[OUT_CALL] = new Name(target, outArgs);
-
-        Object convSpec = convSpecs[INARG_COUNT];
-        if (!retConv) {
-            assert(OUT_CALL == names.length-1);
-        } else {
-            Name conv;
-            if (convSpec == void.class) {
-                conv = new Name(LambdaForm.constantZero(BasicType.basicType(srcType.returnType())));
-            } else if (convSpec instanceof Class) {
-                Class<?> convClass = (Class<?>) convSpec;
-                conv = new Name(getConstantHandle(MH_cast), convClass, names[OUT_CALL]);
-            } else {
-                MethodHandle fn = (MethodHandle) convSpec;
-                if (fn.type().parameterCount() == 0)
-                    conv = new Name(fn);  // don't pass retval to void conversion
-                else
-                    conv = new Name(fn, names[OUT_CALL]);
-            }
-            assert(names[RETURN_CONV] == null);
-            names[RETURN_CONV] = conv;
-            assert(RETURN_CONV == names.length-1);
-        }
-
-        LambdaForm form = new LambdaForm(lambdaType.parameterCount(), names, RESULT, Kind.CONVERT);
-        return SimpleMethodHandle.make(srcType, form);
-    }
-
     static Object[] computeValueConversions(MethodType srcType, MethodType dstType,
                                             boolean strict, boolean monobox) {
         final int INARG_COUNT = srcType.parameterCount();
-        Object[] convSpecs = new Object[INARG_COUNT+1];
+        Object[] convSpecs = null;
         for (int i = 0; i <= INARG_COUNT; i++) {
             boolean isRet = (i == INARG_COUNT);
             Class<?> src = isRet ? dstType.returnType() : srcType.parameterType(i);
             Class<?> dst = isRet ? srcType.returnType() : dstType.parameterType(i);
             if (!VerifyType.isNullConversion(src, dst, /*keepInterfaces=*/ strict)) {
+                if (convSpecs == null) {
+                    convSpecs = new Object[INARG_COUNT + 1];
+                }
                 convSpecs[i] = valueConversion(src, dst, strict, monobox);
             }
         }
--- a/src/java.base/share/classes/java/lang/invoke/MethodHandles.java	Sat Nov 10 20:47:28 2018 +0100
+++ b/src/java.base/share/classes/java/lang/invoke/MethodHandles.java	Sun Nov 11 21:24:46 2018 +0100
@@ -3864,18 +3864,63 @@
      */
     public static
     MethodHandle filterArguments(MethodHandle target, int pos, MethodHandle... filters) {
+        // In method types arguments start at index 0, while the LF
+        // editor have the MH receiver at position 0 - adjust appropriately.
+        final int MH_RECEIVER_OFFSET = 1;
         filterArgumentsCheckArity(target, pos, filters);
         MethodHandle adapter = target;
+
+        // keep track of currently matched filters, as to optimize repeated filters
+        int index = 0;
+        int[] positions = new int[filters.length];
+        MethodHandle filter = null;
+
         // process filters in reverse order so that the invocation of
         // the resulting adapter will invoke the filters in left-to-right order
         for (int i = filters.length - 1; i >= 0; --i) {
-            MethodHandle filter = filters[i];
-            if (filter == null)  continue;  // ignore null elements of filters
-            adapter = filterArgument(adapter, pos + i, filter);
+            MethodHandle newFilter = filters[i];
+            if (newFilter == null) continue;  // ignore null elements of filters
+
+            // flush changes on update
+            if (filter != newFilter) {
+                if (filter != null) {
+                    if (index > 1) {
+                        adapter = filterRepeatedArgument(adapter, filter, Arrays.copyOf(positions, index));
+                    } else {
+                        adapter = filterArgument(adapter, positions[0] - 1, filter);
+                    }
+                }
+                filter = newFilter;
+                index = 0;
+            }
+
+            filterArgumentChecks(target, pos + i, newFilter);
+            positions[index++] = pos + i + MH_RECEIVER_OFFSET;
+        }
+        if (index > 1) {
+            adapter = filterRepeatedArgument(adapter, filter, Arrays.copyOf(positions, index));
+        } else if (index == 1) {
+            adapter = filterArgument(adapter, positions[0] - 1, filter);
         }
         return adapter;
     }
 
+    private static MethodHandle filterRepeatedArgument(MethodHandle adapter, MethodHandle filter, int[] positions) {
+        MethodType targetType = adapter.type();
+        MethodType filterType = filter.type();
+        BoundMethodHandle result = adapter.rebind();
+        Class<?> newParamType = filterType.parameterType(0);
+
+        Class<?>[] ptypes = targetType.ptypes().clone();
+        for (int pos : positions) {
+            ptypes[pos - 1] = newParamType;
+        }
+        MethodType newType = MethodType.makeImpl(targetType.rtype(), ptypes, true);
+
+        LambdaForm lform = result.editor().filterRepeatedArgumentForm(BasicType.basicType(newParamType), positions);
+        return result.copyWithExtendL(newType, lform, filter);
+    }
+
     /*non-public*/ static
     MethodHandle filterArgument(MethodHandle target, int pos, MethodHandle filter) {
         filterArgumentChecks(target, pos, filter);