1 /* 2 * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have 23 * questions. 24 */ 25 package benchmark.jdk.incubator.vector; 26 27 import jdk.incubator.vector.*; 28 import org.openjdk.jmh.annotations.*; 29 30 import java.util.concurrent.TimeUnit; 31 32 import static org.junit.jupiter.api.Assertions.*; 33 34 // Inspired by "SIMDized sum of all bytes in the array" 35 // http://0x80.pl/notesen/2018-10-24-sse-sumbytes.html 36 // 37 // C/C++ equivalent: https://github.com/WojciechMula/toys/tree/master/sse-sumbytes 38 // 39 @BenchmarkMode(Mode.Throughput) 40 @Warmup(iterations = 3, time = 1) 41 @Measurement(iterations = 5, time = 1) 42 @OutputTimeUnit(TimeUnit.MILLISECONDS) 43 @State(Scope.Benchmark) 44 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) 45 public class SumOfUnsignedBytes extends AbstractVectorBenchmark { 46 47 @Param({"64", "1024", "4096"}) 48 int size; 49 50 private byte[] data; 51 52 @Setup 53 public void init() { 54 size = size + size % 32; // FIXME: process tails 55 data = fillByte(size, i -> (byte)(int)i); 56 57 int sum = scalar(); 58 assertEquals(vectorInt(), sum); 59 assertEquals(vectorShort(), sum); 60 //assertEquals(vectorByte(), sum); 61 //assertEquals(vectorSAD(), sum); 62 } 63 64 @Benchmark 65 public int scalar() { 66 int sum = 0; 67 for (int i = 0; i < data.length; i++) { 68 sum += data[i] & 0xFF; 69 } 70 return sum; 71 } 72 73 // 1. 32-bit accumulators 74 @Benchmark 75 public int vectorInt() { 76 final var lobyte_mask = IntVector.broadcast(I256, 0x000000FF); 77 78 var acc = IntVector.zero(I256); 79 for (int i = 0; i < data.length; i += B256.length()) { 80 var vb = ByteVector.fromArray(B256, data, i); 81 var vi = (IntVector)vb.reinterpret(I256); 82 for (int j = 0; j < 4; j++) { 83 var tj = vi.shiftR(j * 8).and(lobyte_mask); 84 acc = acc.add(tj); 85 } 86 } 87 return (int)Integer.toUnsignedLong(acc.addAll()); 88 } 89 90 // 2. 16-bit accumulators 91 @Benchmark 92 public int vectorShort() { 93 final var lobyte_mask = ShortVector.broadcast(S256, (short) 0x00FF); 94 95 // FIXME: overflow 96 var acc = ShortVector.zero(S256); 97 for (int i = 0; i < data.length; i += B256.length()) { 98 var vb = ByteVector.fromArray(B256, data, i); 99 var vs = (ShortVector)vb.reinterpret(S256); 100 for (int j = 0; j < 2; j++) { 101 var tj = vs.shiftR(j * 8).and(lobyte_mask); 102 acc = acc.add(tj); 103 } 104 } 105 106 int mid = S128.length(); 107 var accLo = ((IntVector)(acc .reshape(S128).cast(I256))).and(0xFFFF); // low half as ints 108 var accHi = ((IntVector)(acc.shiftEL(mid).reshape(S128).cast(I256))).and(0xFFFF); // high half as ints 109 return accLo.addAll() + accHi.addAll(); 110 } 111 112 /* 113 // 3. 8-bit halves (MISSING: _mm_adds_epu8) 114 @Benchmark 115 public int vectorByte() { 116 int window = 256; 117 var acc_hi = IntVector.zero(I256); 118 var acc8_lo = ByteVector.zero(B256); 119 for (int i = 0; i < data.length; i += window) { 120 var acc8_hi = ByteVector.zero(B256); 121 int limit = Math.min(window, data.length - i); 122 for (int j = 0; j < limit; j += B256.length()) { 123 var vb = ByteVector.fromArray(B256, data, i + j); 124 125 var t0 = acc8_lo.add(vb); 126 var t1 = addSaturated(acc8_lo, vb); // MISSING 127 var overflow = t0.notEqual(t1); 128 129 acc8_lo = t0; 130 acc8_hi = acc8_hi.add((byte) 1, overflow); 131 } 132 acc_hi = acc_hi.add(sum(acc8_hi)); 133 } 134 return sum(acc8_lo) 135 .add(acc_hi.mul(256)) // overflow 136 .addAll(); 137 } 138 139 // 4. Sum Of Absolute Differences (SAD) (MISSING: VPSADBW, _mm256_sad_epu8) 140 public int vectorSAD() { 141 var acc = IntVector.zero(I256); 142 for (int i = 0; i < data.length; i += B256.length()) { 143 var v = ByteVector.fromArray(B256, data, i); 144 var sad = sumOfAbsoluteDifferences(v, ByteVector.zero(B256)); // MISSING 145 acc = acc.add(sad); 146 } 147 return acc.addAll(); 148 } */ 149 150 // Helpers 151 /* 152 static ByteVector addSaturated(ByteVector va, ByteVector vb) { 153 var vc = ByteVector.zero(B256); 154 for (int i = 0; i < B256.length(); i++) { 155 if ((va.get(i) & 0xFF) + (vb.get(i) & 0xFF) < 0xFF) { 156 vc = vc.with(i, (byte)(va.get(i) + vb.get(i))); 157 } else { 158 vc = vc.with(i, (byte)0xFF); 159 } 160 } 161 return vc; 162 } 163 IntVector sumOfAbsoluteDifferences(ByteVector va, ByteVector vb) { 164 var vc = ByteVector.zero(B256); 165 for (int i = 0; i < B256.length(); i++) { 166 if ((va.get(i) & 0xFF) > (vb.get(i) & 0xFF)) { 167 vc = vc.with(i, (byte)(va.get(i) - vb.get(i))); 168 } else { 169 vc = vc.with(i, (byte)(vb.get(i) - va.get(i))); 170 } 171 } 172 return sum(vc); 173 } */ 174 }