1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have
  23  * questions.
  24  */
  25 package benchmark.jdk.incubator.vector;
  26 
  27 import jdk.incubator.vector.IntVector;
  28 import jdk.incubator.vector.Vector;
  29 import jdk.incubator.vector.VectorMask;
  30 import jdk.incubator.vector.VectorSpecies;
  31 import jdk.incubator.vector.VectorShuffle;
  32 import org.openjdk.jmh.annotations.*;
  33 
  34 import java.util.concurrent.TimeUnit;
  35 
  36 /**
  37  * Inspired by "Sorting an AVX512 register"
  38  *   http://0x80.pl/articles/avx512-sort-register.html
  39  */
  40 @BenchmarkMode(Mode.Throughput)
  41 @Warmup(iterations = 3, time = 1)
  42 @Measurement(iterations = 5, time = 1)
  43 @OutputTimeUnit(TimeUnit.MILLISECONDS)
  44 @State(Scope.Benchmark)
  45 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  46 public class SortVector extends AbstractVectorBenchmark {
  47     @Param({"64", "1024", "65536"})
  48     int size;
  49 
  50     int[] in, out;
  51 
  52     @Setup
  53     public void setup() {
  54         size = size + (size % 16); // FIXME: process tails
  55         in  = fillInt(size, i -> RANDOM.nextInt());
  56         out = new int[size];
  57     }
  58 
  59     @Benchmark
  60     public void sortVectorI64() {
  61         sort(I64);
  62     }
  63 
  64     @Benchmark
  65     public void sortVectorI128() {
  66         sort(I128);
  67     }
  68 
  69     @Benchmark
  70     public void sortVectorI256() {
  71         sort(I256);
  72     }
  73 
  74     @Benchmark
  75     public void sortVectorI512() {
  76         sort(I512);
  77     }
  78 
  79 
  80     void sort(VectorSpecies<Integer> spec) {
  81         var iota = (IntVector) VectorShuffle.shuffleIota(spec).toVector(); // [ 0 1 ... n ]
  82 
  83         var result = IntVector.broadcast(spec, 0);
  84         var index = IntVector.broadcast(spec, 0);
  85         var incr = IntVector.broadcast(spec, 1);
  86 
  87         for (int i = 0; i < in.length; i += spec.length()) {
  88             var input = IntVector.fromArray(spec, in, i);
  89 
  90             for (int j = 0; j < input.length(); j++) {
  91                 var shuf = index.toShuffle();
  92                 var b = input.rearrange(shuf); // broadcast j-th element
  93                 var lt = input.lessThan(b).trueCount();
  94                 var eq = input.equal(b).trueCount();
  95 
  96                 // int/long -> mask?
  97                 // int m = (1 << (lt + eq)) - (1 << lt);
  98                 // var mask = masks[lt + eq].xor(masks[lt]);
  99                 // var mask = masks[lt + eq].and(masks[lt].not());
 100                 //
 101                 // masks[i] =  [ 0 0 ... 0 1 ... 1 ]
 102                 //                      i-th
 103                 var m = iota.lessThan(lt + eq).and(iota.lessThan(lt).not());
 104 
 105                 result = result.blend(b, m);
 106                 index = index.add(incr);
 107             }
 108             result.intoArray(out, i);
 109         }
 110     }
 111 }