src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/spi/SwitchFoldable.java
changeset 57537 ecc6e394475f
equal deleted inserted replaced
57536:67cce1b84a9a 57537:ecc6e394475f
       
     1 /*
       
     2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
       
     3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
       
     4  *
       
     5  * This code is free software; you can redistribute it and/or modify it
       
     6  * under the terms of the GNU General Public License version 2 only, as
       
     7  * published by the Free Software Foundation.
       
     8  *
       
     9  * This code is distributed in the hope that it will be useful, but WITHOUT
       
    10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
       
    11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
       
    12  * version 2 for more details (a copy is included in the LICENSE file that
       
    13  * accompanied this code).
       
    14  *
       
    15  * You should have received a copy of the GNU General Public License version
       
    16  * 2 along with this work; if not, write to the Free Software Foundation,
       
    17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
       
    18  *
       
    19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
       
    20  * or visit www.oracle.com if you need additional information or have any
       
    21  * questions.
       
    22  */
       
    23 
       
    24 
       
    25 
       
    26 package org.graalvm.compiler.nodes.spi;
       
    27 
       
    28 import java.util.ArrayList;
       
    29 import java.util.Comparator;
       
    30 import java.util.List;
       
    31 
       
    32 import jdk.internal.vm.compiler.collections.EconomicMap;
       
    33 import jdk.internal.vm.compiler.collections.Equivalence;
       
    34 import org.graalvm.compiler.core.common.SuppressFBWarnings;
       
    35 import org.graalvm.compiler.core.common.type.IntegerStamp;
       
    36 import org.graalvm.compiler.core.common.type.PrimitiveStamp;
       
    37 import org.graalvm.compiler.core.common.type.Stamp;
       
    38 import org.graalvm.compiler.graph.Node;
       
    39 import org.graalvm.compiler.graph.spi.SimplifierTool;
       
    40 import org.graalvm.compiler.nodes.AbstractBeginNode;
       
    41 import org.graalvm.compiler.nodes.BeginNode;
       
    42 import org.graalvm.compiler.nodes.FixedNode;
       
    43 import org.graalvm.compiler.nodes.LogicNode;
       
    44 import org.graalvm.compiler.nodes.NodeView;
       
    45 import org.graalvm.compiler.nodes.StructuredGraph;
       
    46 import org.graalvm.compiler.nodes.ValueNode;
       
    47 import org.graalvm.compiler.nodes.ValueNodeInterface;
       
    48 import org.graalvm.compiler.nodes.calc.IntegerEqualsNode;
       
    49 import org.graalvm.compiler.nodes.calc.SignExtendNode;
       
    50 import org.graalvm.compiler.nodes.extended.IntegerSwitchNode;
       
    51 import org.graalvm.compiler.nodes.util.GraphUtil;
       
    52 
       
    53 /**
       
    54  * Nodes that implement this interface can be collapsed to a single IntegerSwitch when they are seen
       
    55  * in a cascade.
       
    56  */
       
    57 @SuppressFBWarnings(value = {"UCF"}, justification = "javac spawns useless control flow in static initializer when using assert(asNode().isAlive())")
       
    58 public interface SwitchFoldable extends ValueNodeInterface {
       
    59     Comparator<KeyData> SORTER = Comparator.comparingInt((KeyData k) -> k.key);
       
    60 
       
     61     /**
     62      * Returns the direct successor in the branch to check for SwitchFoldability.
             * <p>
             * When walking down a cascade, the folding logic starts at the returned node, skips over
             * {@code BeginNode}s that have no usages, and treats the node it reaches as the next member
             * of the cascade if it is a {@code SwitchFoldable} for which {@link #isInSwitch(ValueNode)}
             * returns true.
     63      */
     64     Node getNextSwitchFoldableBranch();
       
    65 
       
     66     /**
     67      * Returns the value that will be used as the switch input. This value should be an int.
             * <p>
             * {@link #switchTransformationOptimization(SimplifierTool)} gives up when this returns
             * {@code null}, when the stamp of the value is not an {@code IntegerStamp}, or when that
             * stamp is wider than 32 bits. Values narrower than 32 bits are sign-extended to 32 bits
             * before being fed to the generated switch.
     68      */
     69     ValueNode switchValue();
       
    70 
       
     71     /**
     72      * Returns the branch that will close this switch folding, assuming this is called on the lowest
     73      * node of the cascade.
             * <p>
             * The folding logic calls this on the lowest cascade node to obtain the default successor of
             * the generated switch. The default implementation of
             * {@link #isDefaultSuccessor(AbstractBeginNode)} also calls it.
     74      */
     75     AbstractBeginNode getDefault();
       
    76 
       
     77     /**
     78      * Determines whether the node should be folded in the current folding attempt.
             * <p>
             * It is queried on the node that starts the folding attempt, and on every candidate
             * predecessor or successor found while walking up and down the cascade.
     79      *
     80      * @param switchValue the value of the switch that will spawn through this folding attempt.
     81      * @return true if this node should be folded in the current folding attempt, false otherwise.
     82      * @see SwitchFoldable#maybeIsInSwitch(LogicNode)
     83      * @see SwitchFoldable#sameSwitchValue(LogicNode, ValueNode)
     84      */
     85     boolean isInSwitch(ValueNode switchValue);
       
    86 
       
     87     /**
     88      * Removes the successors of this node, while keeping it linked to the rest of the cascade.
             * <p>
             * During folding this is invoked on every cascade node except the lowest one, which is
             * handled by {@link #cutOffLowestCascadeNode()}.
     89      */
     90     void cutOffCascadeNode();
       
    91 
       
     92     /**
     93      * Completely removes all successors from this node.
             * <p>
             * During folding this is invoked exactly once, on the lowest node of the cascade, before
             * the other cascade nodes are cut off.
     94      */
     95     void cutOffLowestCascadeNode();
       
    96 
       
     97     /**
     98      * Returns the value of the i-th key of this node.
             * <p>
             * The folding logic calls this with {@code i} ranging from 0 (inclusive) to
             * {@link #keyCount()} (exclusive).
     99      */
    100     int intKeyAt(int i);
       
   101 
       
    102     /**
    103      * Returns the probability of seeing the i-th key of this node.
             * <p>
             * The folding logic multiplies this value by the accumulated product of the
             * {@link #defaultProbability()} of the cascade nodes above this one.
    104      */
    105     double keyProbability(int i);
       
   106 
       
    107     /**
    108      * Returns the branch to follow when seeing the i-th key of this node.
             * <p>
             * The folding logic calls this with {@code i} ranging from 0 (inclusive) to
             * {@link #keyCount()} (exclusive).
    109      */
    110     AbstractBeginNode keySuccessor(int i);
       
   111 
       
    112     /**
    113      * Returns the probability of going to the default branch.
             * <p>
             * While walking down a cascade, the folding logic multiplies the values returned by the
             * visited nodes together to obtain the probability of reaching the next cascade node.
    114      */
    115     double defaultProbability();
       
   116 
       
   117     /**
       
   118      * @return The number of keys the SwitchFoldable node will try to add.
       
   119      */
       
   120     default int keyCount() {
       
   121         return 1;
       
   122     }
       
   123 
       
   124     /**
       
   125      * Should be overridden if getDefault() has side effects.
       
   126      */
       
   127     default boolean isDefaultSuccessor(AbstractBeginNode successor) {
       
   128         return successor == getDefault();
       
   129     }
       
   130 
       
   131     /**
       
   132      * Heuristics that tries to determine whether or not a foldable node was profiled.
       
   133      */
       
   134     default boolean isNonInitializedProfile() {
       
   135         return false;
       
   136     }
       
   137 
       
   138     static boolean maybeIsInSwitch(LogicNode condition) {
       
   139         return condition instanceof IntegerEqualsNode && ((IntegerEqualsNode) condition).getY().isJavaConstant();
       
   140     }
       
   141 
       
   142     static boolean sameSwitchValue(LogicNode condition, ValueNode switchValue) {
       
   143         return ((IntegerEqualsNode) condition).getX() == switchValue;
       
   144     }
       
   145 
       
   146     // Helper data structures
       
   147 
       
   148     class Helper {
       
   149         private Helper() {
       
   150         }
       
   151 
       
   152         private static boolean isDuplicateKey(int key, QuickQueryKeyData keyData) {
       
   153             return keyData.contains(key);
       
   154         }
       
   155 
       
   156         private static int duplicateIndex(AbstractBeginNode begin, QuickQueryList<AbstractBeginNode> successors) {
       
   157             return successors.indexOf(begin);
       
   158         }
       
   159 
       
   160         private static Node skipUpBegins(Node node) {
       
   161             Node result = node;
       
   162             while (result instanceof BeginNode && result.hasNoUsages()) {
       
   163                 result = result.predecessor();
       
   164             }
       
   165             return result;
       
   166         }
       
   167 
       
   168         private static Node skipDownBegins(Node node) {
       
   169             Node result = node;
       
   170             while (result instanceof BeginNode && result.hasNoUsages()) {
       
   171                 result = ((BeginNode) result).next();
       
   172             }
       
   173             return result;
       
   174         }
       
   175 
       
   176         private static SwitchFoldable getParentSwitchNode(SwitchFoldable node, ValueNode switchValue) {
       
   177             Node result = skipUpBegins(node.asNode().predecessor());
       
   178             if (result instanceof SwitchFoldable && ((SwitchFoldable) result).isInSwitch(switchValue)) {
       
   179                 return (SwitchFoldable) result;
       
   180             }
       
   181             return null;
       
   182         }
       
   183 
       
   184         private static SwitchFoldable getChildSwitchNode(SwitchFoldable node, ValueNode switchValue) {
       
   185             Node result = skipDownBegins(node.getNextSwitchFoldableBranch());
       
   186             if (result instanceof SwitchFoldable && ((SwitchFoldable) result).isInSwitch(switchValue)) {
       
   187                 return (SwitchFoldable) result;
       
   188             }
       
   189             return null;
       
   190         }
       
   191 
       
   192         private static int addDefault(SwitchFoldable node, QuickQueryList<AbstractBeginNode> successors) {
       
   193             AbstractBeginNode defaultBranch = node.getDefault();
       
   194             int index = successors.indexOf(defaultBranch);
       
   195             if (index == -1) {
       
   196                 index = successors.size();
       
   197                 successors.add(defaultBranch);
       
   198             }
       
   199             return index;
       
   200         }
       
   201 
       
   202         private static int countNonDeoptSuccessors(QuickQueryKeyData keyData) {
       
   203             int result = 0;
       
   204             for (KeyData key : keyData.list) {
       
   205                 if (key.keyProbability > 0.0d) {
       
   206                     result++;
       
   207                 }
       
   208             }
       
   209             return result;
       
   210         }
       
   211 
       
   212         /**
       
   213          * Updates the current state of the IntegerSwitch that will be spawned. That means:
       
   214          * <p>
       
   215          * - Checking for duplicate keys: add the duplicate key's branch to duplicates
       
   216          * <p>
       
   217          * - For branches of non-duplicate keys: add them to successors and update the keyData
       
   218          * accordingly
       
   219          * <p>
       
   220          * - Update the value of the cumulative probability, ie, multiply it by the probability of
       
   221          * taking the next branch (according to {@link SwitchFoldable#getNextSwitchFoldableBranch})
       
   222          * <p>
       
   223          * </p>
       
   224          *
       
   225          * @see QuickQueryList
       
   226          * @see QuickQueryKeyData
       
   227          */
       
   228         private static void updateSwitchData(SwitchFoldable node, QuickQueryKeyData keyData, QuickQueryList<AbstractBeginNode> newSuccessors, double[] cumulative, double[] totalProbabilities,
       
   229                         QuickQueryList<AbstractBeginNode> duplicates) {
       
   230             for (int i = 0; i < node.keyCount(); i++) {
       
   231                 int key = node.intKeyAt(i);
       
   232                 double keyProbability = cumulative[0] * node.keyProbability(i);
       
   233                 KeyData data;
       
   234                 AbstractBeginNode keySuccessor = node.keySuccessor(i);
       
   235                 if (isDuplicateKey(key, keyData)) {
       
   236                     // Key was already seen
       
   237                     data = keyData.fromKey(key);
       
   238                     if (data.keySuccessor != KeyData.KEY_UNKNOWN) {
       
   239                         // Unreachable key: kill it manually at the end
       
   240                         if (!newSuccessors.contains(keySuccessor) && !duplicates.contains(keySuccessor) && keySuccessor.isAlive()) {
       
   241                             // This might be a false alert, if one of the next keys points to it.
       
   242                             duplicates.add(keySuccessor);
       
   243                         }
       
   244                         continue;
       
   245                     }
       
   246                     /*
       
   247                      * A key might not be able to immediately link to its target, if it is shared
       
   248                      * with the default target. In that case, we will need to resolve the target at
       
   249                      * a later time, either by seeing this key going to a known target in later
       
   250                      * cascade nodes, or by linking it to the overall default target at the very end
       
   251                      * of the folding.
       
   252                      */
       
   253                 } else {
       
   254                     data = new KeyData(key, keyProbability, KeyData.KEY_UNKNOWN);
       
   255                     totalProbabilities[0] += keyProbability;
       
   256                     keyData.add(data);
       
   257                 }
       
   258                 if (keySuccessor.isUnregistered()) {
       
   259                     // Shortcut map check if uninitialized node.
       
   260                     data.keySuccessor = newSuccessors.size();
       
   261                     newSuccessors.addUnique(keySuccessor);
       
   262                 } else {
       
   263                     int pos = duplicateIndex(keySuccessor, newSuccessors);
       
   264                     if (pos != -1) {
       
   265                         // Target is already known
       
   266                         data.keySuccessor = pos;
       
   267                     } else if (!node.isDefaultSuccessor(keySuccessor)) {
       
   268                         data.keySuccessor = newSuccessors.size();
       
   269                         newSuccessors.add(keySuccessor);
       
   270                     }
       
   271                 }
       
   272             }
       
   273             cumulative[0] *= node.defaultProbability();
       
   274         }
       
   275     }
       
   276 
       
   277     final class KeyData {
       
   278         private static final int KEY_UNKNOWN = -2;
       
   279 
       
   280         private final int key;
       
   281         private final double keyProbability;
       
   282         private int keySuccessor;
       
   283 
       
   284         KeyData(int key, double keyProbability, int keySuccessor) {
       
   285             this.key = key;
       
   286             this.keyProbability = keyProbability;
       
   287             this.keySuccessor = keySuccessor;
       
   288         }
       
   289     }
       
   290 
       
   291     /**
       
   292      * Supports O(1) addition to the list, fast {@code contains} and {@code indexOf} queries
       
   293      * (usually O(1), worst case O(n)), and O(1) random access.
       
   294      */
       
   295     final class QuickQueryList<T> {
       
   296         private final List<T> list = new ArrayList<>();
       
   297         private final EconomicMap<T, Integer> map = EconomicMap.create(Equivalence.IDENTITY);
       
   298 
       
   299         private int indexOf(T begin) {
       
   300             return map.get(begin, -1);
       
   301         }
       
   302 
       
   303         private boolean contains(T o) {
       
   304             return map.containsKey(o);
       
   305         }
       
   306 
       
   307         @SuppressWarnings("unused")
       
   308         private T get(int index) {
       
   309             return list.get(index);
       
   310         }
       
   311 
       
   312         private boolean add(T item) {
       
   313             map.put(item, list.size());
       
   314             return list.add(item);
       
   315         }
       
   316 
       
   317         /**
       
   318          * Adds an object, known to be unique beforehand.
       
   319          */
       
   320         private void addUnique(T item) {
       
   321             list.add(item);
       
   322         }
       
   323 
       
   324         private int size() {
       
   325             return list.size();
       
   326         }
       
   327     }
       
   328 
       
   329     final class QuickQueryKeyData {
       
   330         private final List<KeyData> list = new ArrayList<>();
       
   331         private final EconomicMap<Integer, KeyData> map = EconomicMap.create();
       
   332 
       
   333         private void add(KeyData key) {
       
   334             assert !map.containsKey(key.key);
       
   335             list.add(key);
       
   336             map.put(key.key, key);
       
   337         }
       
   338 
       
   339         private boolean contains(int key) {
       
   340             return map.containsKey(key);
       
   341         }
       
   342 
       
   343         private KeyData get(int index) {
       
   344             return list.get(index);
       
   345         }
       
   346 
       
   347         private int size() {
       
   348             return list.size();
       
   349         }
       
   350 
       
   351         private KeyData fromKey(int key) {
       
   352             assert contains(key);
       
   353             return map.get(key);
       
   354         }
       
   355 
       
   356         private void sort() {
       
   357             list.sort(SORTER);
       
   358         }
       
   359 
       
   360     }
       
   361 
       
    362     /**
    363      * Collapses a cascade of foldables (IfNode, FixedGuard and IntegerSwitch) into a single switch.
             * <p>
             * The method locates the top-most foldable of the cascade this node belongs to, walks down
             * the cascade collecting keys, probabilities and successors, and, if at least four distinct
             * keys were collected from more than one node, replaces the whole cascade with a newly
             * created {@link IntegerSwitchNode}.
             *
             * @param tool the simplifier tool; the new switch node is added to its work list
             * @return true if the cascade was replaced by a switch; false if the attempt was abandoned
             *         before this method started modifying the graph
    364      */
    365     default boolean switchTransformationOptimization(SimplifierTool tool) {
    366         ValueNode switchValue = switchValue();
    367         assert asNode().isAlive();
                // A cascade needs a switch value, this node taking part in it, and at least one
                // foldable neighbour (above or below).
    368         if (switchValue == null || !isInSwitch(switchValue) || (Helper.getParentSwitchNode(this, switchValue) == null && Helper.getChildSwitchNode(this, switchValue) == null)) {
    369             // Don't bother trying if there is nothing to do.
    370             return false;
    371         }
    372         Stamp switchStamp = switchValue.stamp(NodeView.DEFAULT);
    373 
    374         // Abort if we do not have an int
    375         if (!(switchStamp instanceof IntegerStamp)) {
    376             return false;
    377         }
                // Values wider than 32 bits are not handled.
    378         if (PrimitiveStamp.getBits(switchStamp) > 32) {
    379             return false;
    380         }
    381 
    382         // PlaceHolder for cascade traversal.
    383         SwitchFoldable iteratingNode = this;
    384         SwitchFoldable topMostSwitchNode = this;
    385 
    386         // Find top-most foldable.
    387         while (iteratingNode != null) {
    388             topMostSwitchNode = iteratingNode;
    389             iteratingNode = Helper.getParentSwitchNode(iteratingNode, switchValue);
    390         }
                // State accumulated while walking down the cascade. The one-element arrays are mutable
                // holders updated by Helper.updateSwitchData: cumulative[0] is the probability of
                // reaching the node being visited, totalProbability[0] the sum of the recorded key
                // probabilities.
    391         QuickQueryKeyData keyData = new QuickQueryKeyData();
    392         QuickQueryList<AbstractBeginNode> successors = new QuickQueryList<>();
    393         QuickQueryList<AbstractBeginNode> potentiallyUnreachable = new QuickQueryList<>();
    394         double[] cumulative = {1.0d};
    395         double[] totalProbability = {0.0d};
    396 
    397         iteratingNode = topMostSwitchNode;
    398         SwitchFoldable lowestSwitchNode = topMostSwitchNode;
    399 
    400         // If this stays true, we will need to spawn an uniform distribution.
    401         boolean uninitializedProfiles = true;
    402 
    403         // Go down the if cascade, collecting necessary data
    404         while (iteratingNode != null) {
    405             lowestSwitchNode = iteratingNode;
    406             Helper.updateSwitchData(iteratingNode, keyData, successors, cumulative, totalProbability, potentiallyUnreachable);
    407             if (!iteratingNode.isNonInitializedProfile()) {
    408                 uninitializedProfiles = false;
    409             }
    410             iteratingNode = Helper.getChildSwitchNode(iteratingNode, switchValue);
    411         }
    412 
                // Require at least four distinct keys and a cascade made of more than one node.
    413         if (keyData.size() < 4 || lowestSwitchNode == topMostSwitchNode) {
    414             // Abort if it's not worth the hassle
    415             return false;
    416         }
    417 
    418         // At that point, we will commit the optimization.
    419         StructuredGraph graph = asNode().graph();
    420 
    421         // Sort the keys
    422         keyData.sort();
    423 
    424         /*
    425          * The total probability might be different than 1 if there was a duplicate key which was
    426          * erased by another branch whose probability was different (/ex: in the case where a method
    427          * constituted of only a switch is inlined after a guard for a particular value of that
    428          * switch). In that case, we need to re-normalize the probabilities. A more "correct" way
    429          * would be to only re-normalize the probabilities of the switch after the guard, but this
    430          * cannot be done without an additional overhead.
    431          */
                // cumulative[0] now holds the probability of falling through the whole cascade, i.e.
                // of reaching the default branch.
    432         totalProbability[0] += cumulative[0];
    433         assert totalProbability[0] > 0.0d;
    434         double normalizationFactor = 1 / totalProbability[0];
    435 
    436         // Spawn the required data structures
                // keyProbabilities and keySuccessors have one extra trailing slot for the default branch.
    437         int newKeyCount = keyData.list.size();
    438         int[] keys = new int[newKeyCount];
    439         double[] keyProbabilities = new double[newKeyCount + 1];
    440         int[] keySuccessors = new int[newKeyCount + 1];
                // Number of entries (keys plus default) with a strictly positive probability; they
                // share the uniform distribution used when no visited node reported a profile.
    441         int nonDeoptSuccessorCount = Helper.countNonDeoptSuccessors(keyData) + (cumulative[0] > 0.0d ? 1 : 0);
    442         double uniform = (uninitializedProfiles && nonDeoptSuccessorCount > 0 ? 1 / (double) nonDeoptSuccessorCount : 1.0d);
    443 
    444         // Add default
    445         keyProbabilities[newKeyCount] = uninitializedProfiles && cumulative[0] > 0.0d ? uniform : normalizationFactor * cumulative[0];
    446         keySuccessors[newKeyCount] = Helper.addDefault(lowestSwitchNode, successors);
    447 
    448         // Add branches.
    449         for (int i = 0; i < newKeyCount; i++) {
    450             SwitchFoldable.KeyData data = keyData.get(i);
    451             keys[i] = data.key;
    452             keyProbabilities[i] = uninitializedProfiles && data.keyProbability > 0.0d ? uniform : normalizationFactor * data.keyProbability;
                    // Keys whose target was never resolved are routed to the default successor.
    453             keySuccessors[i] = data.keySuccessor != KeyData.KEY_UNKNOWN ? data.keySuccessor : keySuccessors[newKeyCount];
    454         }
    455 
    456         // Spin an adapter if the value is narrower than an int
    457         ValueNode adapter = null;
    458         if (((IntegerStamp) switchStamp).getBits() < 32) {
    459             adapter = graph.addOrUnique(new SignExtendNode(switchValue, 32));
    460         } else {
    461             adapter = switchValue;
    462         }
    463 
    464         // Spawn the switch node
    465         IntegerSwitchNode toInsert = new IntegerSwitchNode(adapter, successors.size(), keys, keyProbabilities, keySuccessors);
    466         graph.add(toInsert);
    467 
    468         // Detach the cascade from the graph
                // The lowest node is cut off first, then every node above it, walking bottom-up.
    469         lowestSwitchNode.cutOffLowestCascadeNode();
    470         iteratingNode = lowestSwitchNode;
    471         while (iteratingNode != null) {
    472             if (iteratingNode != lowestSwitchNode) {
    473                 iteratingNode.cutOffCascadeNode();
    474             }
    475             iteratingNode = Helper.getParentSwitchNode(iteratingNode, switchValue);
    476         }
    477 
    478         // Place the new Switch node
    479         topMostSwitchNode.asNode().replaceAtPredecessor(toInsert);
    480         topMostSwitchNode.asNode().replaceAtUsages(toInsert);
    481 
    482         // Attach the branches to the switch.
    483         int pos = 0;
    484         for (AbstractBeginNode begin : successors.list) {
                    // Begins that are not yet registered are added to the graph together with their
                    // next node before being attached.
    485             if (begin.isUnregistered()) {
    486                 graph.add(begin.next());
    487                 graph.add(begin);
    488                 begin.setNext(begin.next());
    489             }
    490             toInsert.setBlockSuccessor(pos++, begin);
    491         }
    492 
    493         // Remove the cascade and unreachable code
    494         GraphUtil.killCFG((FixedNode) topMostSwitchNode);
                // Branches of duplicate keys are only killed if nothing ended up pointing to them.
    495         for (AbstractBeginNode duplicate : potentiallyUnreachable.list) {
    496             if (duplicate.predecessor() == null) {
    497                 // Make sure the duplicate is not reachable.
    498                 assert duplicate.isAlive();
    499                 GraphUtil.killCFG(duplicate);
    500             }
    501         }
    502 
                // Have the simplifier revisit the freshly created switch.
    503         tool.addToWorkList(toInsert);
    504 
    505         return true;
    506     }
       
   507 }