1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 
  25 package org.graalvm.compiler.phases.common;
  26 
  27 import jdk.internal.vm.compiler.collections.Pair;
  28 import org.graalvm.compiler.core.common.type.IntegerStamp;
  29 import org.graalvm.compiler.nodes.ConstantNode;
  30 import org.graalvm.compiler.nodes.FixedNode;
  31 import org.graalvm.compiler.nodes.NodeView;
  32 import org.graalvm.compiler.nodes.StructuredGraph;
  33 import org.graalvm.compiler.nodes.ValueNode;
  34 import org.graalvm.compiler.nodes.calc.BinaryArithmeticNode;
  35 import org.graalvm.compiler.nodes.calc.IntegerDivRemNode;
  36 import org.graalvm.compiler.nodes.calc.IntegerMulHighNode;
  37 import org.graalvm.compiler.nodes.calc.MulNode;
  38 import org.graalvm.compiler.nodes.calc.NarrowNode;
  39 import org.graalvm.compiler.nodes.calc.RightShiftNode;
  40 import org.graalvm.compiler.nodes.calc.SignExtendNode;
  41 import org.graalvm.compiler.nodes.calc.SignedDivNode;
  42 import org.graalvm.compiler.nodes.calc.SignedRemNode;
  43 import org.graalvm.compiler.nodes.calc.UnsignedRightShiftNode;
  44 import org.graalvm.compiler.phases.Phase;
  45 
  46 import jdk.vm.ci.code.CodeUtil;
  47 
  48 public class OptimizeDivPhase extends Phase {
  49 
  50     @Override
  51     protected void run(StructuredGraph graph) {
  52         for (IntegerDivRemNode rem : graph.getNodes(IntegerDivRemNode.TYPE)) {
  53             if (rem instanceof SignedRemNode && divByNonZeroConstant(rem)) {
  54                 optimizeRem(rem);
  55             }
  56         }
  57         for (IntegerDivRemNode div : graph.getNodes(IntegerDivRemNode.TYPE)) {
  58             if (div instanceof SignedDivNode && divByNonZeroConstant(div)) {
  59                 optimizeSignedDiv((SignedDivNode) div);
  60             }
  61         }
  62     }
  63 
  64     @Override
  65     public float codeSizeIncrease() {
  66         return 5.0f;
  67     }
  68 
  69     protected static boolean divByNonZeroConstant(IntegerDivRemNode divRemNode) {
  70         return divRemNode.getY().isConstant() && divRemNode.getY().asJavaConstant().asLong() != 0;
  71     }
  72 
  73     protected final void optimizeRem(IntegerDivRemNode rem) {
  74         assert rem.getOp() == IntegerDivRemNode.Op.REM;
  75         // Java spec 15.17.3.: (a/b)*b+(a%b) == a
  76         // so a%b == a-(a/b)*b
  77         StructuredGraph graph = rem.graph();
  78         ValueNode div = findDivForRem(rem);
  79         ValueNode mul = BinaryArithmeticNode.mul(graph, div, rem.getY(), NodeView.DEFAULT);
  80         ValueNode result = BinaryArithmeticNode.sub(graph, rem.getX(), mul, NodeView.DEFAULT);
  81         graph.replaceFixedWithFloating(rem, result);
  82     }
  83 
  84     private ValueNode findDivForRem(IntegerDivRemNode rem) {
  85         if (rem.next() instanceof IntegerDivRemNode) {
  86             IntegerDivRemNode div = (IntegerDivRemNode) rem.next();
  87             if (div.getOp() == IntegerDivRemNode.Op.DIV && div.getType() == rem.getType() && div.getX() == rem.getX() && div.getY() == rem.getY()) {
  88                 return div;
  89             }
  90         }
  91         if (rem.predecessor() instanceof IntegerDivRemNode) {
  92             IntegerDivRemNode div = (IntegerDivRemNode) rem.predecessor();
  93             if (div.getOp() == IntegerDivRemNode.Op.DIV && div.getType() == rem.getType() && div.getX() == rem.getX() && div.getY() == rem.getY()) {
  94                 return div;
  95             }
  96         }
  97 
  98         // not found, create a new one (will be optimized away later)
  99         ValueNode div = rem.graph().addOrUniqueWithInputs(createDiv(rem));
 100         if (div instanceof FixedNode) {
 101             rem.graph().addAfterFixed(rem, (FixedNode) div);
 102         }
 103         return div;
 104     }
 105 
 106     protected ValueNode createDiv(IntegerDivRemNode rem) {
 107         assert rem instanceof SignedRemNode;
 108         return SignedDivNode.create(rem.getX(), rem.getY(), rem.getZeroCheck(), NodeView.DEFAULT);
 109     }
 110 
 111     protected static void optimizeSignedDiv(SignedDivNode div) {
 112         ValueNode forX = div.getX();
 113         long c = div.getY().asJavaConstant().asLong();
 114         assert c != 1 && c != -1 && c != 0;
 115 
 116         IntegerStamp dividendStamp = (IntegerStamp) forX.stamp(NodeView.DEFAULT);
 117         int bitSize = dividendStamp.getBits();
 118         Pair<Long, Integer> nums = magicDivideConstants(c, bitSize);
 119         long magicNum = nums.getLeft().longValue();
 120         int shiftNum = nums.getRight().intValue();
 121         assert shiftNum >= 0;
 122         ConstantNode m = ConstantNode.forLong(magicNum);
 123 
 124         ValueNode value;
 125         if (bitSize == 32) {
 126             value = new MulNode(new SignExtendNode(forX, 64), m);
 127             if ((c > 0 && magicNum < 0) || (c < 0 && magicNum > 0)) {
 128                 // Get upper 32-bits of the result
 129                 value = NarrowNode.create(new RightShiftNode(value, ConstantNode.forInt(32)), 32, NodeView.DEFAULT);
 130                 if (c > 0) {
 131                     value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
 132                 } else {
 133                     value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
 134                 }
 135                 if (shiftNum > 0) {
 136                     value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
 137                 }
 138             } else {
 139                 value = new RightShiftNode(value, ConstantNode.forInt(32 + shiftNum));
 140                 value = new NarrowNode(value, Integer.SIZE);
 141             }
 142         } else {
 143             assert bitSize == 64;
 144             value = new IntegerMulHighNode(forX, m);
 145             if (c > 0 && magicNum < 0) {
 146                 value = BinaryArithmeticNode.add(value, forX, NodeView.DEFAULT);
 147             } else if (c < 0 && magicNum > 0) {
 148                 value = BinaryArithmeticNode.sub(value, forX, NodeView.DEFAULT);
 149             }
 150             if (shiftNum > 0) {
 151                 value = new RightShiftNode(value, ConstantNode.forInt(shiftNum));
 152             }
 153         }
 154 
 155         if (c < 0) {
 156             ConstantNode s = ConstantNode.forInt(bitSize - 1);
 157             ValueNode sign = UnsignedRightShiftNode.create(value, s, NodeView.DEFAULT);
 158             value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
 159         } else if (dividendStamp.canBeNegative()) {
 160             ConstantNode s = ConstantNode.forInt(bitSize - 1);
 161             ValueNode sign = UnsignedRightShiftNode.create(forX, s, NodeView.DEFAULT);
 162             value = BinaryArithmeticNode.add(value, sign, NodeView.DEFAULT);
 163         }
 164 
 165         StructuredGraph graph = div.graph();
 166         graph.replaceFixed(div, graph.addOrUniqueWithInputs(value));
 167     }
 168 
 169     /**
 170      * Borrowed from Hacker's Delight by Henry S. Warren, Jr. Figure 10-1.
 171      */
 172     private static Pair<Long, Integer> magicDivideConstants(long divisor, int size) {
 173         final long twoW = 1L << (size - 1);                // 2 ^ (size - 1).
 174         long t = twoW + (divisor >>> 63);
 175         long ad = Math.abs(divisor);
 176         long anc = t - 1 - Long.remainderUnsigned(t, ad);  // Absolute value of nc.
 177         long q1 = Long.divideUnsigned(twoW, anc);          // Init. q1 = 2**p/|nc|.
 178         long r1 = Long.remainderUnsigned(twoW, anc);       // Init. r1 = rem(2**p, |nc|).
 179         long q2 = Long.divideUnsigned(twoW, ad);           // Init. q2 = 2**p/|d|.
 180         long r2 = Long.remainderUnsigned(twoW, ad);        // Init. r2 = rem(2**p, |d|).
 181         long delta;
 182 
 183         int p = size - 1;                                  // Init. p.
 184         do {
 185             p = p + 1;
 186             q1 = 2 * q1;                                   // Update q1 = 2**p/|nc|.
 187             r1 = 2 * r1;                                   // Update r1 = rem(2**p, |nc|).
 188             if (Long.compareUnsigned(r1, anc) >= 0) {      // Must be an unsigned comparison.
 189                 q1 = q1 + 1;
 190                 r1 = r1 - anc;
 191             }
 192             q2 = 2 * q2;                                   // Update q2 = 2**p/|d|.
 193             r2 = 2 * r2;                                   // Update r2 = rem(2**p, |d|).
 194             if (Long.compareUnsigned(r2, ad) >= 0) {       // Must be an unsigned comparison.
 195                 q2 = q2 + 1;
 196                 r2 = r2 - ad;
 197             }
 198             delta = ad - r2;
 199         } while (Long.compareUnsigned(q1, delta) < 0 || (q1 == delta && r1 == 0));
 200 
 201         long magic = CodeUtil.signExtend(q2 + 1, size);
 202         if (divisor < 0) {
 203             magic = -magic;
 204         }
 205         return Pair.create(magic, p - size);
 206     }
 207 
 208 }