/*
 * Copyright (c) 2013, 2015, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package org.graalvm.compiler.lir.amd64;

import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.ILLEGAL;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.REG;
import static jdk.vm.ci.code.ValueUtil.asRegister;

import java.lang.reflect.Array;
import java.lang.reflect.Field;

import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.amd64.AMD64Address;
import org.graalvm.compiler.asm.amd64.AMD64Address.Scale;
import org.graalvm.compiler.asm.amd64.AMD64Assembler.ConditionFlag;
import org.graalvm.compiler.asm.amd64.AMD64MacroAssembler;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;

import jdk.vm.ci.amd64.AMD64;
import jdk.vm.ci.amd64.AMD64.CPUFeature;
import jdk.vm.ci.amd64.AMD64Kind;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.code.TargetDescription;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.Value;
import sun.misc.Unsafe;

/**
 * Emits code which compares two arrays of the same length. If the CPU supports vector
 * instructions, specialized code is emitted to leverage them.
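 * <p>
 * The bulk of the comparison is done with the widest compare available (32-byte AVX2 or
 * 16-byte SSE4.1 vectors), falling back to 8-byte and finally 1/2/4-byte compares for short
 * arrays and tails.
 * <p>
 * A minimal usage sketch, assuming a generator {@code gen} that exposes the usual
 * {@code newVariable}/{@code append} methods and already holds {@code array1}, {@code array2}
 * and {@code length} values (these names are illustrative, not part of this class):
 *
 * <pre>
 * Value result = gen.newVariable(LIRKind.value(AMD64Kind.DWORD));
 * gen.append(new AMD64ArrayEqualsOp(gen, JavaKind.Byte, result, array1, array2, length));
 * </pre>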
 */
@Opcode("ARRAY_EQUALS")
public final class AMD64ArrayEqualsOp extends AMD64LIRInstruction {
    public static final LIRInstructionClass<AMD64ArrayEqualsOp> TYPE = LIRInstructionClass.create(AMD64ArrayEqualsOp.class);

    private final JavaKind kind;
    private final int arrayBaseOffset;
    private final int arrayIndexScale;

    @Def({REG}) protected Value resultValue;
    @Alive({REG}) protected Value array1Value;
    @Alive({REG}) protected Value array2Value;
    @Alive({REG}) protected Value lengthValue;
    @Temp({REG}) protected Value temp1;
    @Temp({REG}) protected Value temp2;
    @Temp({REG}) protected Value temp3;
    @Temp({REG}) protected Value temp4;
    @Temp({REG, ILLEGAL}) protected Value vectorTemp1;
    @Temp({REG, ILLEGAL}) protected Value vectorTemp2;

    public AMD64ArrayEqualsOp(LIRGeneratorTool tool, JavaKind kind, Value result, Value array1, Value array2, Value length) {
        super(TYPE);
        this.kind = kind;

        Class<?> arrayClass = Array.newInstance(kind.toJavaClass(), 0).getClass();
        this.arrayBaseOffset = UNSAFE.arrayBaseOffset(arrayClass);
        this.arrayIndexScale = UNSAFE.arrayIndexScale(arrayClass);

        this.resultValue = result;
        this.array1Value = array1;
        this.array2Value = array2;
        this.lengthValue = length;

        // Allocate some temporaries.
        this.temp1 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
        this.temp2 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
        this.temp3 = tool.newVariable(LIRKind.value(tool.target().arch.getWordKind()));
        this.temp4 = tool.newVariable(LIRKind.value(tool.target().arch.getWordKind()));

        // We only need the vector temporaries if we generate SSE 4.1 or AVX2 code.
        if (supportsSSE41(tool.target())) {
            this.vectorTemp1 = tool.newVariable(LIRKind.value(AMD64Kind.DOUBLE));
            this.vectorTemp2 = tool.newVariable(LIRKind.value(AMD64Kind.DOUBLE));
        } else {
            this.vectorTemp1 = Value.ILLEGAL;
            this.vectorTemp2 = Value.ILLEGAL;
        }
    }

    @Override
    public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
        Register result = asRegister(resultValue);
        Register array1 = asRegister(temp1);
        Register array2 = asRegister(temp2);
        Register length = asRegister(temp3);

        Label trueLabel = new Label();
        Label falseLabel = new Label();
        Label done = new Label();

        // Load array base addresses.
        masm.leaq(array1, new AMD64Address(asRegister(array1Value), arrayBaseOffset));
        masm.leaq(array2, new AMD64Address(asRegister(array2Value), arrayBaseOffset));

        // Get array length in bytes.
        masm.imull(length, asRegister(lengthValue), arrayIndexScale);
        masm.movl(result, length); // copy

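        /*
         * The vector routines handle arrays spanning at least one full vector entirely
         * (including the tail, via an overlapping load aligned to the end of the array) and
         * jump straight to trueLabel/falseLabel. Only arrays shorter than one vector fall
         * through to the scalar compares below, with their byte count left in length.
         */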
        if (supportsAVX2(crb.target)) {
            emitAVXCompare(crb, masm, result, array1, array2, length, trueLabel, falseLabel);
        } else if (supportsSSE41(crb.target)) {
            emitSSE41Compare(crb, masm, result, array1, array2, length, trueLabel, falseLabel);
        }

        emit8ByteCompare(crb, masm, result, array1, array2, length, trueLabel, falseLabel);
        emitTailCompares(masm, result, array1, array2, length, trueLabel, falseLabel);

        // Return true
        masm.bind(trueLabel);
        masm.movl(result, 1);
        masm.jmpb(done);

        // Return false
        masm.bind(falseLabel);
        masm.xorl(result, result);

        // That's it
        masm.bind(done);
    }

    /**
     * Returns whether the underlying AMD64 architecture supports SSE 4.1 instructions.
     *
     * @param target target description of the underlying architecture
     * @return true if the underlying architecture supports SSE 4.1
     */
    private static boolean supportsSSE41(TargetDescription target) {
        AMD64 arch = (AMD64) target.arch;
        return arch.getFeatures().contains(CPUFeature.SSE4_1);
    }

    /**
     * Vector size used in {@link #emitSSE41Compare}.
     */
    private static final int SSE4_1_VECTOR_SIZE = 16;

    /**
     * Emits code that uses SSE4.1 128-bit (16-byte) vector compares.
     */
    private void emitSSE41Compare(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register result, Register array1, Register array2, Register length, Label trueLabel, Label falseLabel) {
        assert supportsSSE41(crb.target);

        Register vector1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
        Register vector2 = asRegister(vectorTemp2, AMD64Kind.DOUBLE);

        Label loop = new Label();
        Label compareTail = new Label();

        // Compare 16-byte vectors
        masm.andl(result, SSE4_1_VECTOR_SIZE - 1); // tail count (in bytes)
        masm.andl(length, ~(SSE4_1_VECTOR_SIZE - 1)); // vector count (in bytes)
        masm.jccb(ConditionFlag.Zero, compareTail);

        masm.leaq(array1, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.leaq(array2, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.negq(length);
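        /*
         * array1/array2 now point to the end of the vectorized region and length is negative;
         * addressing with (base + length) therefore walks forward through the arrays as length
         * is incremented back towards zero.
         */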

        // Align the main loop
        masm.align(crb.target.wordSize * 2);
        masm.bind(loop);
        masm.movdqu(vector1, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.movdqu(vector2, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.pxor(vector1, vector2);
        masm.ptest(vector1, vector1);
        masm.jcc(ConditionFlag.NotZero, falseLabel);
        masm.addq(length, SSE4_1_VECTOR_SIZE);
        masm.jcc(ConditionFlag.NotZero, loop);

        masm.testl(result, result);
        masm.jcc(ConditionFlag.Zero, trueLabel);

        /*
         * Compare the remaining bytes with an unaligned memory load aligned to the end of the
         * array.
         */
        masm.movdqu(vector1, new AMD64Address(array1, result, Scale.Times1, -SSE4_1_VECTOR_SIZE));
        masm.movdqu(vector2, new AMD64Address(array2, result, Scale.Times1, -SSE4_1_VECTOR_SIZE));
        masm.pxor(vector1, vector2);
        masm.ptest(vector1, vector1);
        masm.jcc(ConditionFlag.NotZero, falseLabel);
        masm.jmp(trueLabel);

        masm.bind(compareTail);
        masm.movl(length, result);
    }

    /**
     * Returns whether the underlying AMD64 architecture supports AVX2 instructions.
     *
     * @param target target description of the underlying architecture
     * @return true if the underlying architecture supports AVX2
     */
    private static boolean supportsAVX2(TargetDescription target) {
        AMD64 arch = (AMD64) target.arch;
        return arch.getFeatures().contains(CPUFeature.AVX2);
    }

    /**
     * Vector size used in {@link #emitAVXCompare}.
     */
    private static final int AVX_VECTOR_SIZE = 32;

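    /**
     * Emits code that uses AVX2 256-bit (32-byte) vector compares.
     */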
    private void emitAVXCompare(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register result, Register array1, Register array2, Register length, Label trueLabel, Label falseLabel) {
        assert supportsAVX2(crb.target);

        Register vector1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
        Register vector2 = asRegister(vectorTemp2, AMD64Kind.DOUBLE);

        Label loop = new Label();
        Label compareTail = new Label();

        // Compare 32-byte vectors
        masm.andl(result, AVX_VECTOR_SIZE - 1); // tail count (in bytes)
        masm.andl(length, ~(AVX_VECTOR_SIZE - 1)); // vector count (in bytes)
        masm.jccb(ConditionFlag.Zero, compareTail);

        masm.leaq(array1, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.leaq(array2, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.negq(length);

        // Align the main loop
        masm.align(crb.target.wordSize * 2);
        masm.bind(loop);
        masm.vmovdqu(vector1, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.vmovdqu(vector2, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.vpxor(vector1, vector1, vector2);
        masm.vptest(vector1, vector1);
        masm.jcc(ConditionFlag.NotZero, falseLabel);
        masm.addq(length, AVX_VECTOR_SIZE);
        masm.jcc(ConditionFlag.NotZero, loop);

        masm.testl(result, result);
        masm.jcc(ConditionFlag.Zero, trueLabel);

        /*
         * Compare the remaining bytes with an unaligned memory load aligned to the end of the
         * array.
         */
        masm.vmovdqu(vector1, new AMD64Address(array1, result, Scale.Times1, -AVX_VECTOR_SIZE));
        masm.vmovdqu(vector2, new AMD64Address(array2, result, Scale.Times1, -AVX_VECTOR_SIZE));
        masm.vpxor(vector1, vector1, vector2);
        masm.vptest(vector1, vector1);
        masm.jcc(ConditionFlag.NotZero, falseLabel);
        masm.jmp(trueLabel);

        masm.bind(compareTail);
        masm.movl(length, result);
    }

    /**
     * Vector size used in {@link #emit8ByteCompare}.
     */
    private static final int VECTOR_SIZE = 8;

    /**
     * Emits code that uses 8-byte vector compares.
     */
    private void emit8ByteCompare(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register result, Register array1, Register array2, Register length, Label trueLabel, Label falseLabel) {
        Label loop = new Label();
        Label compareTail = new Label();

        Register temp = asRegister(temp4);

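        // Compare 8-byte chunks; the remainder (at most 7 bytes) is left to emitTailCompares.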
        masm.andl(result, VECTOR_SIZE - 1); // tail count (in bytes)
        masm.andl(length, ~(VECTOR_SIZE - 1));  // vector count (in bytes)
        masm.jccb(ConditionFlag.Zero, compareTail);

        masm.leaq(array1, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.leaq(array2, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.negq(length);

        // Align the main loop
        masm.align(crb.target.wordSize * 2);
        masm.bind(loop);
        masm.movq(temp, new AMD64Address(array1, length, Scale.Times1, 0));
        masm.cmpq(temp, new AMD64Address(array2, length, Scale.Times1, 0));
        masm.jccb(ConditionFlag.NotEqual, falseLabel);
        masm.addq(length, VECTOR_SIZE);
        masm.jccb(ConditionFlag.NotZero, loop);

        masm.testl(result, result);
        masm.jccb(ConditionFlag.Zero, trueLabel);

        /*
         * Compare the remaining bytes with an unaligned memory load aligned to the end of the
         * array.
         */
        masm.movq(temp, new AMD64Address(array1, result, Scale.Times1, -VECTOR_SIZE));
        masm.cmpq(temp, new AMD64Address(array2, result, Scale.Times1, -VECTOR_SIZE));
        masm.jccb(ConditionFlag.NotEqual, falseLabel);
        masm.jmpb(trueLabel);

        masm.bind(compareTail);
        masm.movl(length, result);
    }

    /**
     * Emits code to compare the remaining 1 to 7 tail bytes left over by the 8-byte loop.
     */
    private void emitTailCompares(AMD64MacroAssembler masm, Register result, Register array1, Register array2, Register length, Label trueLabel, Label falseLabel) {
        Label compare2Bytes = new Label();
        Label compare1Byte = new Label();

        Register temp = asRegister(temp4);

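        /*
         * For 8-byte element kinds (long, double) the byte length is always a multiple of 8,
         * so the 8-byte loop leaves no tail and no further compares need to be emitted.
         */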
        if (kind.getByteCount() <= 4) {
            // Compare trailing 4 bytes, if any.
            masm.testl(result, 4);
            masm.jccb(ConditionFlag.Zero, compare2Bytes);
            masm.movl(temp, new AMD64Address(array1, 0));
            masm.cmpl(temp, new AMD64Address(array2, 0));
            masm.jccb(ConditionFlag.NotEqual, falseLabel);

            if (kind.getByteCount() <= 2) {
                // Move array pointers forward.
                masm.leaq(array1, new AMD64Address(array1, 4));
                masm.leaq(array2, new AMD64Address(array2, 4));

                // Compare trailing 2 bytes, if any.
                masm.bind(compare2Bytes);
                masm.testl(result, 2);
                masm.jccb(ConditionFlag.Zero, compare1Byte);
                masm.movzwl(temp, new AMD64Address(array1, 0));
                masm.movzwl(length, new AMD64Address(array2, 0));
                masm.cmpl(temp, length);
                masm.jccb(ConditionFlag.NotEqual, falseLabel);

                // The one-byte tail compare is only required for boolean and byte arrays.
                if (kind.getByteCount() <= 1) {
                    // Move array pointers forward before we compare the last trailing byte.
                    masm.leaq(array1, new AMD64Address(array1, 2));
                    masm.leaq(array2, new AMD64Address(array2, 2));

                    // Compare trailing byte, if any.
                    masm.bind(compare1Byte);
                    masm.testl(result, 1);
                    masm.jccb(ConditionFlag.Zero, trueLabel);
                    masm.movzbl(temp, new AMD64Address(array1, 0));
                    masm.movzbl(length, new AMD64Address(array2, 0));
                    masm.cmpl(temp, length);
                    masm.jccb(ConditionFlag.NotEqual, falseLabel);
                } else {
                    masm.bind(compare1Byte);
                }
            } else {
                masm.bind(compare2Bytes);
            }
        }
    }

    private static final Unsafe UNSAFE = initUnsafe();

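    /**
     * {@link Unsafe#getUnsafe()} throws a {@link SecurityException} when the caller is not
     * loaded by the boot class loader; in that case the {@code theUnsafe} field is read
     * reflectively instead.
     */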
    private static Unsafe initUnsafe() {
        try {
            return Unsafe.getUnsafe();
        } catch (SecurityException se) {
            try {
                Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
                theUnsafe.setAccessible(true);
                return (Unsafe) theUnsafe.get(Unsafe.class);
            } catch (Exception e) {
                throw new RuntimeException("exception while trying to get Unsafe", e);
            }
        }
    }
}