1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have
  23  * questions.
  24  */
  25 package benchmark.jdk.incubator.vector;
  26 
  27 import jdk.incubator.vector.IntVector;
  28 import jdk.incubator.vector.IntVector.IntSpecies;
  29 import jdk.incubator.vector.Vector;
  30 import jdk.incubator.vector.Vector.Mask;
  31 import org.openjdk.jmh.annotations.*;
  32 
  33 import java.util.concurrent.TimeUnit;
  34 
  35 /**
  36  * Inspired by "Sorting an AVX512 register"
  37  *   http://0x80.pl/articles/avx512-sort-register.html
  38  */
  39 @BenchmarkMode(Mode.Throughput)
  40 @Warmup(iterations = 3, time = 1)
  41 @Measurement(iterations = 5, time = 1)
  42 @OutputTimeUnit(TimeUnit.MILLISECONDS)
  43 @State(Scope.Benchmark)
  44 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  45 public class SortVector extends AbstractVectorBenchmark {
  46     @Param({"64", "1024", "65536"})
  47     int size;
  48 
  49     int[] in, out;
  50 
  51     @Setup
  52     public void setup() {
  53         size = size + (size % 16); // FIXME: process tails
  54         in  = fillInt(size, i -> RANDOM.nextInt());
  55         out = new int[size];
  56     }
  57 
  58     @Benchmark
  59     public void sortVectorI64() {
  60         sort(I64);
  61     }
  62 
  63     @Benchmark
  64     public void sortVectorI128() {
  65         sort(I128);
  66     }
  67 
  68     @Benchmark
  69     public void sortVectorI256() {
  70         sort(I256);
  71     }
  72 
  73     @Benchmark
  74     public void sortVectorI512() {
  75         sort(I512);
  76     }
  77 
  78 
  79     void sort(IntSpecies spec) {
  80         var iota = (IntVector) IntVector.shuffleIota(spec).toVector(); // [ 0 1 ... n ]
  81 
  82         var result = spec.broadcast(0);
  83         var index = spec.broadcast(0);
  84         var incr = spec.broadcast(1);
  85 
  86         for (int i = 0; i < in.length; i += spec.length()) {
  87             var input = IntVector.fromArray(spec, in, i);
  88 
  89             for (int j = 0; j < input.length(); j++) {
  90                 var shuf = index.toShuffle();
  91                 var b = input.rearrange(shuf); // broadcast j-th element
  92                 var lt = input.lessThan(b).trueCount();
  93                 var eq = input.equal(b).trueCount();
  94 
  95                 // int/long -> mask?
  96                 // int m = (1 << (lt + eq)) - (1 << lt);
  97                 // var mask = masks[lt + eq].xor(masks[lt]);
  98                 // var mask = masks[lt + eq].and(masks[lt].not());
  99                 //
 100                 // masks[i] =  [ 0 0 ... 0 1 ... 1 ]
 101                 //                      i-th
 102                 var m = iota.lessThan(lt + eq).and(iota.lessThan(lt).not());
 103 
 104                 result = result.blend(b, m);
 105                 index = index.add(incr);
 106             }
 107             result.intoArray(out, i);
 108         }
 109     }
 110 }