8001110: method handles should have a collectArguments transform, generalizing asCollector
authorjrose
Sat, 05 Oct 2013 05:30:39 -0700
changeset 20533 bee974bc42ac
parent 20532 50fba462caa5
child 20534 da86f7904e6d
8001110: method handles should have a collectArguments transform, generalizing asCollector Summary: promote an existing private method; make unit tests on all argument positions to arity 10 with mixed types Reviewed-by: twisti, vlivanov
jdk/src/share/classes/java/lang/invoke/MethodHandles.java
jdk/src/share/classes/sun/invoke/util/ValueConversions.java
jdk/test/java/lang/invoke/JavaDocExamplesTest.java
jdk/test/java/lang/invoke/MethodHandlesTest.java
--- a/jdk/src/share/classes/java/lang/invoke/MethodHandles.java	Sat Oct 05 05:30:39 2013 -0700
+++ b/jdk/src/share/classes/java/lang/invoke/MethodHandles.java	Sat Oct 05 05:30:39 2013 -0700
@@ -2216,15 +2216,120 @@
         return MethodHandleImpl.makeCollectArguments(target, filter, pos, false);
     }
 
-    // FIXME: Make this public in M1.
-    /*non-public*/ static
-    MethodHandle collectArguments(MethodHandle target, int pos, MethodHandle collector) {
+    /**
+     * Adapts a target method handle by pre-processing
+     * a sub-sequence of its arguments with a filter (another method handle).
+     * The pre-processed arguments are replaced by the result (if any) of the
+     * filter function.
+     * The target is then called on the modified (usually shortened) argument list.
+     * <p>
+     * If the filter returns a value, the target must accept that value as
+     * its argument in position {@code pos}, preceded and/or followed by
+     * any arguments not passed to the filter.
+     * If the filter returns void, the target must accept all arguments
+     * not passed to the filter.
+     * No arguments are reordered, and a result returned from the filter
+     * replaces (in order) the whole subsequence of arguments originally
+     * passed to the adapter.
+     * <p>
+     * The argument types (if any) of the filter
+     * replace zero or one argument types of the target, at position {@code pos},
+     * in the resulting adapted method handle.
+     * The return type of the filter (if any) must be identical to the
+     * argument type of the target at position {@code pos}, and that target argument
+     * is supplied by the return value of the filter.
+     * <p>
+     * In all cases, {@code pos} must be greater than or equal to zero, and
+     * {@code pos} must also be less than or equal to the target's arity.
+     * <p><b>Example:</b>
+     * <p><blockquote><pre>
+import static java.lang.invoke.MethodHandles.*;
+import static java.lang.invoke.MethodType.*;
+...
+MethodHandle deepToString = publicLookup()
+  .findStatic(Arrays.class, "deepToString", methodType(String.class, Object[].class));
+
+MethodHandle ts1 = deepToString.asCollector(String[].class, 1);
+assertEquals("[strange]", (String) ts1.invokeExact("strange"));
+
+MethodHandle ts2 = deepToString.asCollector(String[].class, 2);
+assertEquals("[up, down]", (String) ts2.invokeExact("up", "down"));
+
+MethodHandle ts3 = deepToString.asCollector(String[].class, 3);
+MethodHandle ts3_ts2 = collectArguments(ts3, 1, ts2);
+assertEquals("[top, [up, down], strange]",
+             (String) ts3_ts2.invokeExact("top", "up", "down", "strange"));
+
+MethodHandle ts3_ts2_ts1 = collectArguments(ts3_ts2, 3, ts1);
+assertEquals("[top, [up, down], [strange]]",
+             (String) ts3_ts2_ts1.invokeExact("top", "up", "down", "strange"));
+
+MethodHandle ts3_ts2_ts3 = collectArguments(ts3_ts2, 1, ts3);
+assertEquals("[top, [[up, down, strange], charm], bottom]",
+             (String) ts3_ts2_ts3.invokeExact("top", "up", "down", "strange", "charm", "bottom"));
+     * </pre></blockquote>
+     * <p> Here is pseudocode for the resulting adapter:
+     * <blockquote><pre>
+     * T target(A...,V,C...);
+     * V filter(B...);
+     * T adapter(A... a,B... b,C... c) {
+     *   V v = filter(b...);
+     *   return target(a...,v,c...);
+     * }
+     * // and if the filter has no arguments:
+     * T target2(A...,V,C...);
+     * V filter2();
+     * T adapter2(A... a,C... c) {
+     *   V v = filter2();
+     *   return target2(a...,v,c...);
+     * }
+     * // and if the filter has a void return:
+     * T target3(A...,C...);
+     * void filter3(B...);
+     * void adapter3(A... a,B... b,C... c) {
+     *   filter3(b...);
+     *   return target3(a...,c...);
+     * }
+     * </pre></blockquote>
+     * <p>
+     * A collection adapter {@code collectArguments(mh, 0, coll)} is equivalent to
+     * one which first "folds" the affected arguments, and then drops them, in separate
+     * steps as follows:
+     * <blockquote><pre>{@code
+     * mh = MethodHandles.dropArguments(mh, 1, coll.type().parameterList()); //step 2
+     * mh = MethodHandles.foldArguments(mh, coll); //step 1
+     * }</pre></blockquote>
+     * If the target method handle consumes no arguments besides than the result
+     * (if any) of the filter {@code coll}, then {@code collectArguments(mh, 0, coll)}
+     * is equivalent to {@code filterReturnValue(coll, mh)}.
+     * If the filter method handle {@code coll} consumes one argument and produces
+     * a non-void result, then {@code collectArguments(mh, N, coll)}
+     * is equivalent to {@code filterArguments(mh, N, coll)}.
+     * Other equivalences are possible but would require argument permutation.
+     *
+     * @param target the method handle to invoke after filtering the subsequence of arguments
+     * @param pos the position of the first adapter argument to pass to the filter,
+     *            and/or the target argument which receives the result of the filter
+     * @param filter method handle to call on the subsequence of arguments
+     * @return method handle which incorporates the specified argument subsequence filtering logic
+     * @throws NullPointerException if either argument is null
+     * @throws IllegalArgumentException if the return type of {@code filter}
+     *          is non-void and is not the same as the {@code pos} argument of the target,
+     *          or if {@code pos} is not between 0 and the target's arity, inclusive,
+     *          or if the resulting method handle's type would have
+     *          <a href="MethodHandle.html#maxarity">too many parameters</a>
+     * @see MethodHandles#foldArguments
+     * @see MethodHandles#filterArguments
+     * @see MethodHandles#filterReturnValue
+     */
+    public static
+    MethodHandle collectArguments(MethodHandle target, int pos, MethodHandle filter) {
         MethodType targetType = target.type();
-        MethodType filterType = collector.type();
+        MethodType filterType = filter.type();
         if (filterType.returnType() != void.class &&
             filterType.returnType() != targetType.parameterType(pos))
             throw newIllegalArgumentException("target and filter types do not match", targetType, filterType);
-        return MethodHandleImpl.makeCollectArguments(target, collector, pos, false);
+        return MethodHandleImpl.makeCollectArguments(target, filter, pos, false);
     }
 
     /**
--- a/jdk/src/share/classes/sun/invoke/util/ValueConversions.java	Sat Oct 05 05:30:39 2013 -0700
+++ b/jdk/src/share/classes/sun/invoke/util/ValueConversions.java	Sat Oct 05 05:30:39 2013 -0700
@@ -502,51 +502,6 @@
         }
     }
 
-    static MethodHandle collectArguments(MethodHandle mh, int pos, MethodHandle collector) {
-        // FIXME: API needs public MHs.collectArguments.
-        // Should be:
-        //   return MethodHandles.collectArguments(mh, 0, collector);
-        // The rest of this code is a workaround for not having that API.
-        if (COLLECT_ARGUMENTS != null) {
-            try {
-                return (MethodHandle)
-                    COLLECT_ARGUMENTS.invokeExact(mh, pos, collector);
-            } catch (Throwable ex) {
-                if (ex instanceof RuntimeException)
-                    throw (RuntimeException) ex;
-                if (ex instanceof Error)
-                    throw (Error) ex;
-                throw new Error(ex.getMessage(), ex);
-            }
-        }
-        // Emulate MHs.collectArguments using fold + drop.
-        // This is slightly inefficient.
-        // More seriously, it can put a MH over the 255-argument limit.
-        mh = MethodHandles.dropArguments(mh, 1, collector.type().parameterList());
-        mh = MethodHandles.foldArguments(mh, collector);
-        return mh;
-    }
-    private static final MethodHandle COLLECT_ARGUMENTS;
-    static {
-        MethodHandle mh = null;
-        try {
-            final java.lang.reflect.Method m = MethodHandles.class
-                .getDeclaredMethod("collectArguments",
-                    MethodHandle.class, int.class, MethodHandle.class);
-            AccessController.doPrivileged(new PrivilegedAction<Void>() {
-                    @Override
-                    public Void run() {
-                        m.setAccessible(true);
-                        return null;
-                    }
-                });
-            mh = IMPL_LOOKUP.unreflect(m);
-        } catch (ReflectiveOperationException ex) {
-            throw newInternalError(ex);
-        }
-        COLLECT_ARGUMENTS = mh;
-    }
-
     private static final EnumMap<Wrapper, MethodHandle>[] WRAPPER_CASTS
             = newWrapperCaches(1);
 
@@ -1050,12 +1005,12 @@
             if (mh == ARRAY_IDENTITY)
                 mh = rightFiller;
             else
-                mh = collectArguments(mh, 0, rightFiller);
+                mh = MethodHandles.collectArguments(mh, 0, rightFiller);
         }
         if (mh == ARRAY_IDENTITY)
             mh = leftCollector;
         else
-            mh = collectArguments(mh, 0, leftCollector);
+            mh = MethodHandles.collectArguments(mh, 0, leftCollector);
         return mh;
     }
 
@@ -1101,7 +1056,7 @@
         if (midLen == LEFT_ARGS)
             return rightFill;
         else
-            return collectArguments(rightFill, 0, midFill);
+            return MethodHandles.collectArguments(rightFill, 0, midFill);
     }
 
     // Type-polymorphic version of varargs maker.
--- a/jdk/test/java/lang/invoke/JavaDocExamplesTest.java	Sat Oct 05 05:30:39 2013 -0700
+++ b/jdk/test/java/lang/invoke/JavaDocExamplesTest.java	Sat Oct 05 05:30:39 2013 -0700
@@ -281,6 +281,28 @@
             }}
     }
 
+    @Test public void testCollectArguments() throws Throwable {
+        {{
+{} /// JAVADOC
+MethodHandle deepToString = publicLookup()
+  .findStatic(Arrays.class, "deepToString", methodType(String.class, Object[].class));
+MethodHandle ts1 = deepToString.asCollector(String[].class, 1);
+assertEquals("[strange]", (String) ts1.invokeExact("strange"));
+MethodHandle ts2 = deepToString.asCollector(String[].class, 2);
+assertEquals("[up, down]", (String) ts2.invokeExact("up", "down"));
+MethodHandle ts3 = deepToString.asCollector(String[].class, 3);
+MethodHandle ts3_ts2 = collectArguments(ts3, 1, ts2);
+assertEquals("[top, [up, down], strange]",
+             (String) ts3_ts2.invokeExact("top", "up", "down", "strange"));
+MethodHandle ts3_ts2_ts1 = collectArguments(ts3_ts2, 3, ts1);
+assertEquals("[top, [up, down], [strange]]",
+             (String) ts3_ts2_ts1.invokeExact("top", "up", "down", "strange"));
+MethodHandle ts3_ts2_ts3 = collectArguments(ts3_ts2, 1, ts3);
+assertEquals("[top, [[up, down, strange], charm], bottom]",
+             (String) ts3_ts2_ts3.invokeExact("top", "up", "down", "strange", "charm", "bottom"));
+            }}
+    }
+
     @Test public void testFoldArguments() throws Throwable {
         {{
 {} /// JAVADOC
--- a/jdk/test/java/lang/invoke/MethodHandlesTest.java	Sat Oct 05 05:30:39 2013 -0700
+++ b/jdk/test/java/lang/invoke/MethodHandlesTest.java	Sat Oct 05 05:30:39 2013 -0700
@@ -277,6 +277,9 @@
             args[i] = randomArg(param);
         return args;
     }
+    static Object[] randomArgs(List<Class<?>> params) {
+        return randomArgs(params.toArray(new Class<?>[params.size()]));
+    }
 
     @SafeVarargs @SuppressWarnings("varargs")
     static <T, E extends T> T[] array(Class<T[]> atype, E... a) {
@@ -347,6 +350,11 @@
         }
         return list.asType(listType);
     }
+    /** Variation of varargsList, but with the given ptypes and rtype. */
+    static MethodHandle varargsList(List<Class<?>> ptypes, Class<?> rtype) {
+        MethodHandle list = varargsList(ptypes.size(), rtype);
+        return list.asType(MethodType.methodType(rtype, ptypes));
+    }
     private static MethodHandle LIST_TO_STRING, LIST_TO_INT;
     private static String listToString(List<?> x) { return x.toString(); }
     private static int listToInt(List<?> x) { return x.toString().hashCode(); }
@@ -1833,24 +1841,24 @@
     }
 
     @Test  // SLOW
-    public void testCollectArguments() throws Throwable {
+    public void testAsCollector() throws Throwable {
         if (CAN_SKIP_WORKING)  return;
-        startTest("collectArguments");
+        startTest("asCollector");
         for (Class<?> argType : new Class<?>[]{Object.class, Integer.class, int.class}) {
             if (verbosity >= 3)
-                System.out.println("collectArguments "+argType);
+                System.out.println("asCollector "+argType);
             for (int nargs = 0; nargs < 50; nargs++) {
                 if (CAN_TEST_LIGHTLY && nargs > 11)  break;
                 for (int pos = 0; pos <= nargs; pos++) {
                     if (CAN_TEST_LIGHTLY && pos > 2 && pos < nargs-2)  continue;
                     if (nargs > 10 && pos > 4 && pos < nargs-4 && pos % 10 != 3)
                         continue;
-                    testCollectArguments(argType, pos, nargs);
+                    testAsCollector(argType, pos, nargs);
                 }
             }
         }
     }
-    public void testCollectArguments(Class<?> argType, int pos, int nargs) throws Throwable {
+    public void testAsCollector(Class<?> argType, int pos, int nargs) throws Throwable {
         countTest();
         // fake up a MH with the same type as the desired adapter:
         MethodHandle fake = varargsArray(nargs);
@@ -1997,37 +2005,108 @@
     }
 
     @Test
+    public void testCollectArguments() throws Throwable {
+        if (CAN_SKIP_WORKING)  return;
+        startTest("collectArguments");
+        testFoldOrCollectArguments(true);
+    }
+
+    @Test
     public void testFoldArguments() throws Throwable {
         if (CAN_SKIP_WORKING)  return;
         startTest("foldArguments");
-        for (int nargs = 0; nargs <= 4; nargs++) {
-            for (int fold = 0; fold <= nargs; fold++) {
-                for (int pos = 0; pos <= nargs; pos++) {
-                    testFoldArguments(nargs, pos, fold);
+        testFoldOrCollectArguments(false);
+    }
+
+    void testFoldOrCollectArguments(boolean isCollect) throws Throwable {
+        for (Class<?> lastType : new Class<?>[]{ Object.class, String.class, int.class }) {
+            for (Class<?> collectType : new Class<?>[]{ Object.class, String.class, int.class, void.class }) {
+                int maxArity = 10;
+                if (collectType != String.class)  maxArity = 5;
+                if (lastType != Object.class)  maxArity = 4;
+                for (int nargs = 0; nargs <= maxArity; nargs++) {
+                    ArrayList<Class<?>> argTypes = new ArrayList<>(Collections.nCopies(nargs, Object.class));
+                    int maxMix = 20;
+                    if (collectType != Object.class)  maxMix = 0;
+                    Map<Object,Integer> argTypesSeen = new HashMap<>();
+                    for (int mix = 0; mix <= maxMix; mix++) {
+                        if (!mixArgs(argTypes, mix, argTypesSeen))  continue;
+                        for (int collect = 0; collect <= nargs; collect++) {
+                            for (int pos = 0; pos <= nargs - collect; pos++) {
+                                testFoldOrCollectArguments(argTypes, pos, collect, collectType, lastType, isCollect);
+                            }
+                        }
+                    }
                 }
             }
         }
     }
 
-    void testFoldArguments(int nargs, int pos, int fold) throws Throwable {
-        if (pos != 0)  return;  // can fold only at pos=0 for now
+    boolean mixArgs(List<Class<?>> argTypes, int mix, Map<Object,Integer> argTypesSeen) {
+        assert(mix >= 0);
+        if (mix == 0)  return true;  // no change
+        if ((mix >>> argTypes.size()) != 0)  return false;
+        for (int i = 0; i < argTypes.size(); i++) {
+            if (i >= 31)  break;
+            boolean bit = (mix & (1 << i)) != 0;
+            if (bit) {
+                Class<?> type = argTypes.get(i);
+                if (type == Object.class)
+                    type = String.class;
+                else if (type == String.class)
+                    type = int.class;
+                else
+                    type = Object.class;
+                argTypes.set(i, type);
+            }
+        }
+        Integer prev = argTypesSeen.put(new ArrayList<>(argTypes), mix);
+        if (prev != null) {
+            if (verbosity >= 4)  System.out.println("mix "+prev+" repeated "+mix+": "+argTypes);
+            return false;
+        }
+        if (verbosity >= 3)  System.out.println("mix "+mix+" = "+argTypes);
+        return true;
+    }
+
+    void testFoldOrCollectArguments(List<Class<?>> argTypes,  // argument types minus the inserted combineType
+                                    int pos, int fold, // position and length of the folded arguments
+                                    Class<?> combineType, // type returned from the combiner
+                                    Class<?> lastType,  // type returned from the target
+                                    boolean isCollect) throws Throwable {
+        int nargs = argTypes.size();
+        if (pos != 0 && !isCollect)  return;  // can fold only at pos=0 for now
         countTest();
-        MethodHandle target = varargsList(1 + nargs);
-        MethodHandle combine = varargsList(fold).asType(MethodType.genericMethodType(fold));
-        List<Object> argsToPass = Arrays.asList(randomArgs(nargs, Object.class));
+        List<Class<?>> combineArgTypes = argTypes.subList(pos, pos + fold);
+        List<Class<?>> targetArgTypes = new ArrayList<>(argTypes);
+        if (isCollect)  // does targret see arg[pos..pos+cc-1]?
+            targetArgTypes.subList(pos, pos + fold).clear();
+        if (combineType != void.class)
+            targetArgTypes.add(pos, combineType);
+        MethodHandle target = varargsList(targetArgTypes, lastType);
+        MethodHandle combine = varargsList(combineArgTypes, combineType);
+        List<Object> argsToPass = Arrays.asList(randomArgs(argTypes));
         if (verbosity >= 3)
-            System.out.println("fold "+target+" with "+combine);
-        MethodHandle target2 = MethodHandles.foldArguments(target, combine);
+            System.out.println((isCollect ? "collect" : "fold")+" "+target+" with "+combine);
+        MethodHandle target2;
+        if (isCollect)
+            target2 = MethodHandles.collectArguments(target, pos, combine);
+        else
+            target2 = MethodHandles.foldArguments(target, combine);
         // Simulate expected effect of combiner on arglist:
-        List<Object> expected = new ArrayList<>(argsToPass);
-        List<Object> argsToFold = expected.subList(pos, pos + fold);
+        List<Object> expectedList = new ArrayList<>(argsToPass);
+        List<Object> argsToFold = expectedList.subList(pos, pos + fold);
         if (verbosity >= 3)
-            System.out.println("fold: "+argsToFold+" into "+target2);
+            System.out.println((isCollect ? "collect" : "fold")+": "+argsToFold+" into "+target2);
         Object foldedArgs = combine.invokeWithArguments(argsToFold);
-        argsToFold.add(0, foldedArgs);
+        if (isCollect)
+            argsToFold.clear();
+        if (combineType != void.class)
+            argsToFold.add(0, foldedArgs);
         Object result = target2.invokeWithArguments(argsToPass);
         if (verbosity >= 3)
             System.out.println("result: "+result);
+        Object expected = target.invokeWithArguments(expectedList);
         if (!expected.equals(result))
             System.out.println("*** fail at n/p/f = "+nargs+"/"+pos+"/"+fold+": "+argsToPass+" => "+result+" != "+expected);
         assertEquals(expected, result);