< prev index next >

src/jdk.internal.vm.compiler/share/classes/org.graalvm.compiler.core.amd64/src/org/graalvm/compiler/core/amd64/AMD64AddressLowering.java

Print this page

        

@@ -21,51 +21,65 @@
  * questions.
  */
 
 package org.graalvm.compiler.core.amd64;
 
-import jdk.vm.ci.meta.JavaConstant;
-
-import org.graalvm.compiler.core.common.NumUtil;
 import org.graalvm.compiler.asm.amd64.AMD64Address.Scale;
+import org.graalvm.compiler.core.common.NumUtil;
 import org.graalvm.compiler.core.common.type.IntegerStamp;
 import org.graalvm.compiler.debug.DebugContext;
+import org.graalvm.compiler.nodes.StructuredGraph;
 import org.graalvm.compiler.nodes.ValueNode;
 import org.graalvm.compiler.nodes.calc.AddNode;
 import org.graalvm.compiler.nodes.calc.LeftShiftNode;
+import org.graalvm.compiler.nodes.calc.NegateNode;
 import org.graalvm.compiler.nodes.calc.ZeroExtendNode;
 import org.graalvm.compiler.nodes.memory.address.AddressNode;
 import org.graalvm.compiler.phases.common.AddressLoweringPhase.AddressLowering;
 
-public class AMD64AddressLowering extends AddressLowering {
+import jdk.vm.ci.meta.JavaConstant;
 
+public class AMD64AddressLowering extends AddressLowering {
     @Override
     public AddressNode lower(ValueNode address) {
         return lower(address, null);
     }
 
     @Override
     public AddressNode lower(ValueNode base, ValueNode offset) {
         AMD64AddressNode ret = new AMD64AddressNode(base, offset);
+        StructuredGraph graph = base.graph();
+
         boolean changed;
         do {
-            changed = improve(base.getDebug(), ret);
+            changed = improve(graph, base.getDebug(), ret, false, false);
         } while (changed);
-        return base.graph().unique(ret);
+
+        return graph.unique(ret);
     }
 
     /**
-     * @param debug
+     * Tries to optimize addresses so that they match the AMD64-specific addressing mode better
+     * (base + index * scale + displacement).
+     *
+     * @param graph the current graph
+     * @param debug the current debug context
+     * @param ret the address that should be optimized
+     * @param isBaseNegated whether the address base is negated. If so, all values that are
+     *            extracted from the base will be negated as well
+     * @param isIndexNegated whether the address index is negated. If so, all values that are
+     *            extracted from the index will be negated as well
+     * @return true if the address was modified
      */
-    protected boolean improve(DebugContext debug, AMD64AddressNode ret) {
-        ValueNode newBase = improveInput(ret, ret.getBase(), 0);
+    protected boolean improve(StructuredGraph graph, DebugContext debug, AMD64AddressNode ret, boolean isBaseNegated, boolean isIndexNegated) {
+        ValueNode newBase = improveInput(ret, ret.getBase(), 0, isBaseNegated);
         if (newBase != ret.getBase()) {
             ret.setBase(newBase);
             return true;
         }
 
-        ValueNode newIdx = improveInput(ret, ret.getIndex(), ret.getScale().log2);
+        ValueNode newIdx = improveInput(ret, ret.getIndex(), ret.getScale().log2, isIndexNegated);
         if (newIdx != ret.getIndex()) {
             ret.setIndex(newIdx);
             return true;
         }
 

@@ -81,74 +95,156 @@
                 }
             }
         }
 
         if (ret.getScale() == Scale.Times1) {
-            if (ret.getBase() == null || ret.getIndex() == null) {
-                if (ret.getBase() instanceof AddNode) {
+            if (ret.getIndex() == null && ret.getBase() instanceof AddNode) {
                     AddNode add = (AddNode) ret.getBase();
                     ret.setBase(add.getX());
-                    ret.setIndex(add.getY());
+                ret.setIndex(considerNegation(graph, add.getY(), isBaseNegated));
                     return true;
-                } else if (ret.getIndex() instanceof AddNode) {
+            } else if (ret.getBase() == null && ret.getIndex() instanceof AddNode) {
                     AddNode add = (AddNode) ret.getIndex();
-                    ret.setBase(add.getX());
+                ret.setBase(considerNegation(graph, add.getX(), isIndexNegated));
                     ret.setIndex(add.getY());
                     return true;
                 }
-            }
 
             if (ret.getBase() instanceof LeftShiftNode && !(ret.getIndex() instanceof LeftShiftNode)) {
                 ValueNode tmp = ret.getBase();
-                ret.setBase(ret.getIndex());
-                ret.setIndex(tmp);
+                ret.setBase(considerNegation(graph, ret.getIndex(), isIndexNegated != isBaseNegated));
+                ret.setIndex(considerNegation(graph, tmp, isIndexNegated != isBaseNegated));
                 return true;
             }
         }
 
+        return improveNegation(graph, debug, ret, isBaseNegated, isIndexNegated);
+    }
+
+    private boolean improveNegation(StructuredGraph graph, DebugContext debug, AMD64AddressNode ret, boolean originalBaseNegated, boolean originalIndexNegated) {
+        boolean baseNegated = originalBaseNegated;
+        boolean indexNegated = originalIndexNegated;
+
+        ValueNode originalBase = ret.getBase();
+        ValueNode originalIndex = ret.getIndex();
+
+        if (ret.getBase() instanceof NegateNode) {
+            NegateNode negate = (NegateNode) ret.getBase();
+            ret.setBase(negate.getValue());
+            baseNegated = !baseNegated;
+        }
+
+        if (ret.getIndex() instanceof NegateNode) {
+            NegateNode negate = (NegateNode) ret.getIndex();
+            ret.setIndex(negate.getValue());
+            indexNegated = !indexNegated;
+        }
+
+        if (baseNegated != originalBaseNegated || indexNegated != originalIndexNegated) {
+            ValueNode base = ret.getBase();
+            ValueNode index = ret.getIndex();
+
+            boolean improved = improve(graph, debug, ret, baseNegated, indexNegated);
+            if (baseNegated != originalBaseNegated) {
+                if (base == ret.getBase()) {
+                    ret.setBase(originalBase);
+                } else if (ret.getBase() != null) {
+                    ret.setBase(graph.maybeAddOrUnique(NegateNode.create(ret.getBase())));
+                }
+            }
+
+            if (indexNegated != originalIndexNegated) {
+                if (index == ret.getIndex()) {
+                    ret.setIndex(originalIndex);
+                } else if (ret.getIndex() != null) {
+                    ret.setIndex(graph.maybeAddOrUnique(NegateNode.create(ret.getIndex())));
+                }
+            }
+            return improved;
+        } else {
+            assert ret.getBase() == originalBase && ret.getIndex() == originalIndex;
+        }
         return false;
     }
 
-    private static ValueNode improveInput(AMD64AddressNode address, ValueNode node, int shift) {
+    private static ValueNode considerNegation(StructuredGraph graph, ValueNode value, boolean negate) {
+        if (negate && value != null) {
+            return graph.maybeAddOrUnique(NegateNode.create(value));
+        }
+        return value;
+    }
+
+    private ValueNode improveInput(AMD64AddressNode address, ValueNode node, int shift, boolean negateExtractedDisplacement) {
         if (node == null) {
             return null;
         }
 
         JavaConstant c = node.asJavaConstant();
         if (c != null) {
-            return improveConstDisp(address, node, c, null, shift);
+            return improveConstDisp(address, node, c, null, shift, negateExtractedDisplacement);
         } else {
-            if (node.stamp() instanceof IntegerStamp && ((IntegerStamp) node.stamp()).getBits() == 64) {
-                if (node instanceof ZeroExtendNode) {
-                    if (((ZeroExtendNode) node).getInputBits() == 32) {
+            if (node.stamp() instanceof IntegerStamp) {
+                if (node instanceof ZeroExtendNode && (((ZeroExtendNode) node).getInputBits() == 32)) {
                         /*
-                         * We can just swallow a zero-extend from 32 bit to 64 bit because the upper
-                         * half of the register will always be zero.
+                     * We can't just swallow all zero-extends as we might encounter something like
+                     * the following: ZeroExtend(Add(negativeValue, positiveValue)).
+                     *
+                     * If we swallow the zero-extend in this case and subsequently optimize the add,
+                     * we might end up with a negative value that has fewer than 64 bits in base or
+                     * index. Such a value would require sign extension instead of zero-extension,
+                     * but the backend can only do zero-extension. If we ever want to optimize that
+                     * further, we would also need to be careful about over-/underflows.
+                     *
+                     * Furthermore, we also can't swallow zero-extends with fewer than 32 bits, as
+                     * most of these values are immediately sign-extended to 32 bits by the backend
+                     * (therefore, the subsequent implicit zero-extension to 64 bits won't do what
+                     * we expect).
                          */
-                        return ((ZeroExtendNode) node).getValue();
+                    ValueNode value = ((ZeroExtendNode) node).getValue();
+                    if (!mightBeOptimized(value)) {
+                        // If the value is not optimized further by the address lowering, then we
+                        // can safely rely on the backend doing the implicit zero-extension.
+                        return value;
+                    }
                     }
-                } else if (node instanceof AddNode) {
+
+                if (node instanceof AddNode) {
                     AddNode add = (AddNode) node;
                     if (add.getX().isConstant()) {
-                        return improveConstDisp(address, node, add.getX().asJavaConstant(), add.getY(), shift);
+                        return improveConstDisp(address, node, add.getX().asJavaConstant(), add.getY(), shift, negateExtractedDisplacement);
                     } else if (add.getY().isConstant()) {
-                        return improveConstDisp(address, node, add.getY().asJavaConstant(), add.getX(), shift);
+                        return improveConstDisp(address, node, add.getY().asJavaConstant(), add.getX(), shift, negateExtractedDisplacement);
                     }
                 }
             }
         }
 
         return node;
     }
 
-    private static ValueNode improveConstDisp(AMD64AddressNode address, ValueNode original, JavaConstant c, ValueNode other, int shift) {
+    /**
+     * This method returns true for all nodes that might be optimized by the address lowering.
+     */
+    protected boolean mightBeOptimized(ValueNode value) {
+        return value instanceof AddNode || value instanceof LeftShiftNode || value instanceof NegateNode || value instanceof ZeroExtendNode;
+    }
+
+    private static ValueNode improveConstDisp(AMD64AddressNode address, ValueNode original, JavaConstant c, ValueNode other, int shift, boolean negateExtractedDisplacement) {
         if (c.getJavaKind().isNumericInteger()) {
-            long disp = address.getDisplacement();
-            disp += c.asLong() << shift;
-            if (NumUtil.isInt(disp)) {
-                address.setDisplacement((int) disp);
+            long delta = c.asLong() << shift;
+            if (updateDisplacement(address, delta, negateExtractedDisplacement)) {
                 return other;
             }
         }
         return original;
     }
+
+    protected static boolean updateDisplacement(AMD64AddressNode address, long displacementDelta, boolean negateDelta) {
+        long sign = negateDelta ? -1 : 1;
+        long disp = address.getDisplacement() + displacementDelta * sign;
+        if (NumUtil.isInt(disp)) {
+            address.setDisplacement((int) disp);
+            return true;
+        }
+        return false;
+    }
 }
< prev index next >