11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have
23 * questions.
24 */
25 package benchmark.jdk.incubator.vector;
26
27 import jdk.incubator.vector.ByteVector;
28 import jdk.incubator.vector.ShortVector;
29 import jdk.incubator.vector.IntVector;
30 import jdk.incubator.vector.LongVector;
31 import jdk.incubator.vector.Vector.Species;
32 import org.openjdk.jmh.annotations.*;
33
34 import java.util.concurrent.TimeUnit;
35
36 import static org.junit.jupiter.api.Assertions.assertEquals;
37
38 /**
39 * Population count algorithms from "Faster Population Counts Using AVX2 Instructions", 2018 by Mula, Kurz, Lemire
40 */
41 @BenchmarkMode(Mode.Throughput)
42 @Warmup(iterations = 3, time = 1)
43 @Measurement(iterations = 5, time = 1)
44 @OutputTimeUnit(TimeUnit.MILLISECONDS)
45 @State(Scope.Benchmark)
46 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
47 public class PopulationCount extends AbstractVectorBenchmark {
48 @Param({"64", "1024", "65536"})
49 int size;
50
51 private long[] data;
333 + 1 * Long.bitCount(ones);
334
335 return total + tail(upper);
336 }
337
338 /* ============================================================================================================== */
339
340 // FIGURE 9. A C function using SSE intrinsics implementing Mula’s algorithm to compute sixteen population counts,
341 // corresponding to sixteen input bytes.
342
343 static final ByteVector MULA128_LOOKUP = (ByteVector)(IntVector.scalars(I128, 0x02_01_01_00, // 0, 1, 1, 2,
344 0x03_02_02_01, // 1, 2, 2, 3,
345 0x03_02_02_01, // 1, 2, 2, 3,
346 0x04_03_03_02 // 2, 3, 3, 4
347 ).reinterpret(B128));
348
349 ByteVector popcntB128(ByteVector v) {
350 var low_mask = ByteVector.broadcast(B128, (byte)0x0f);
351
352 var lo = v .and(low_mask);
353 var hi = v.shiftR(4).and(low_mask);
354
355 var cnt1 = MULA128_LOOKUP.rearrange(lo.toShuffle());
356 var cnt2 = MULA128_LOOKUP.rearrange(hi.toShuffle());
357
358 return cnt1.add(cnt2);
359 }
360
361 @Benchmark
362 public long Mula128() {
363 var acc = LongVector.zero(L128); // IntVector
364 int step = 32; // % B128.length() == 0!
365 int upper = data.length - (data.length % step);
366 for (int i = 0; i < upper; i += step) {
367 var bacc = ByteVector.zero(B128);
368 for (int j = 0; j < step; j += L128.length()) {
369 var v1 = LongVector.fromArray(L128, data, i + j);
370 var v2 = (ByteVector)v1.reinterpret(B128);
371 var v3 = popcntB128(v2);
372 bacc = bacc.add(v3);
373 }
374 acc = acc.add(sumUnsignedBytes(bacc));
375 }
376 var r = acc.addAll() + tail(upper);
377 return r;
378 }
379
380 /* ============================================================================================================== */
381
382 // FIGURE 10. A C function using AVX2 intrinsics implementing Mula’s algorithm to compute the four population counts
383 // of the four 64-bit words in a 256-bit vector. The 32 B output vector should be interpreted as four separate
384 // 64-bit counts that need to be summed to obtain the final population count.
385
386 static final ByteVector MULA256_LOOKUP =
387 (ByteVector)(join(I128, I256, (IntVector)(MULA128_LOOKUP.reinterpret(I128)), (IntVector)(MULA128_LOOKUP.reinterpret(I128))).reinterpret(B256));
388
389 ByteVector popcntB256(ByteVector v) {
390 var low_mask = ByteVector.broadcast(B256, (byte)0x0F);
391
392 var lo = v .and(low_mask);
393 var hi = v.shiftR(4).and(low_mask);
394
395 var cnt1 = MULA256_LOOKUP.rearrange(lo.toShuffle());
396 var cnt2 = MULA256_LOOKUP.rearrange(hi.toShuffle());
397 var cnt = cnt1.add(cnt2);
398
399 return cnt;
400 }
401
402 // Horizontally sum each consecutive 8 differences to produce four unsigned 16-bit integers,
403 // and pack these unsigned 16-bit integers in the low 16 bits of 64-bit elements in dst:
404 // _mm256_sad_epu8(total, _mm256_setzero_si256())
405 LongVector sumUnsignedBytes(ByteVector vb) {
406 return sumUnsignedBytesShapes(vb);
407 // return sumUnsignedBytesShifts(vb);
408 }
409
410 LongVector sumUnsignedBytesShapes(ByteVector vb) {
411 Species<Short> shortSpecies = Species.of(short.class, vb.shape());
412 Species<Integer> intSpecies = Species.of(int.class, vb.shape());
413 Species<Long> longSpecies = Species.of(long.class, vb.shape());
414
415 var low_short_mask = ShortVector.broadcast(shortSpecies, (short) 0xFF);
416 var low_int_mask = IntVector.broadcast(intSpecies, 0xFFFF);
417 var low_long_mask = LongVector.broadcast(longSpecies, 0xFFFFFFFFL);
418
419 var vs = (ShortVector)vb.reinterpret(shortSpecies); // 16-bit
420 var vs0 = vs.and(low_short_mask);
421 var vs1 = vs.shiftR(8).and(low_short_mask);
422 var vs01 = vs0.add(vs1);
423
424 var vi = (IntVector)vs01.reinterpret(intSpecies); // 32-bit
425 var vi0 = vi.and(low_int_mask);
426 var vi1 = vi.shiftR(16).and(low_int_mask);
427 var vi01 = vi0.add(vi1);
428
429 var vl = (LongVector)vi01.reinterpret(longSpecies); // 64-bit
430 var vl0 = vl.and(low_long_mask);
431 var vl1 = vl.shiftR(32).and(low_long_mask);
432 var vl01 = vl0.add(vl1);
433
434 return vl01;
435 }
436
437 LongVector sumUnsignedBytesShifts(ByteVector vb) {
438 Species<Long> to = Species.of(long.class, vb.shape());
439
440 var low_mask = LongVector.broadcast(to, 0xFF);
441
442 var vl = (LongVector)vb.reinterpret(to);
443
444 var v0 = vl .and(low_mask); // 8-bit
445 var v1 = vl.shiftR( 8).and(low_mask); // 8-bit
446 var v2 = vl.shiftR(16).and(low_mask); // 8-bit
447 var v3 = vl.shiftR(24).and(low_mask); // 8-bit
448 var v4 = vl.shiftR(32).and(low_mask); // 8-bit
449 var v5 = vl.shiftR(40).and(low_mask); // 8-bit
450 var v6 = vl.shiftR(48).and(low_mask); // 8-bit
451 var v7 = vl.shiftR(56).and(low_mask); // 8-bit
452
453 var v01 = v0.add(v1);
454 var v23 = v2.add(v3);
455 var v45 = v4.add(v5);
456 var v67 = v6.add(v7);
457
458 var v03 = v01.add(v23);
459 var v47 = v45.add(v67);
460
461 var sum = v03.add(v47); // 64-bit
462 return sum;
463 }
464
465 @Benchmark
466 public long Mula256() {
467 var acc = LongVector.zero(L256);
468 int step = 32; // % B256.length() == 0!
469 int upper = data.length - (data.length % step);
470 for (int i = 0; i < upper; i += step) {
471 var bacc = ByteVector.zero(B256);
472 for (int j = 0; j < step; j += L256.length()) {
473 var v1 = LongVector.fromArray(L256, data, i + j);
474 var v2 = popcntB256((ByteVector)(v1.reinterpret(B256)));
475 bacc = bacc.add(v2);
476 }
477 acc = acc.add(sumUnsignedBytes(bacc));
478 }
479 return acc.addAll() + tail(upper);
480 }
481
482
483 /* ============================================================================================================== */
484
485 // FIGURE 11. A C function using AVX2 intrinsics implementing a bitwise parallel carry-save adder (CSA).
486
487 LongVector csaLow(LongVector a, LongVector b, LongVector c) {
488 var u = a.xor(b);
489 var r = u.xor(c);
490 return r;
491 }
492
493 LongVector csaHigh(LongVector a, LongVector b, LongVector c) {
494 var u = a.xor(b);
495 var ab = a.and(b);
496 var uc = u.and(c);
497 var r = ab.or(uc); // (a & b) | ((a ^ b) & c)
498 return r;
499 }
597
598 // CSA(&eightsB, &fours, fours, foursA, foursB);
599 eightsB = csaHigh(fours, foursA, foursB);
600 fours = csaLow(fours, foursA, foursB);
601
602 // ====================================
603
604 // CSA(&sixteens, &eights, eights, eightsA, eightsB);
605 sixteens = csaHigh(eights, eightsA, eightsB);
606 eights = csaLow(eights, eightsA, eightsB);
607
608 vtotal = vtotal.add(popcntL256(sixteens));
609 }
610
611 vtotal = vtotal.mul(16); // << 4
612 vtotal = vtotal.add(popcntL256(eights).mul(8)); // << 3
613 vtotal = vtotal.add(popcntL256(fours).mul(4)); // << 2
614 vtotal = vtotal.add(popcntL256(twos).mul(2)); // << 1
615 vtotal = vtotal.add(popcntL256(ones)); // << 0
616
617 var total = vtotal.addAll();
618
619 return total + tail(upper);
620 }
621
622 /* ============================================================================================================== */
623
624 // ByteVector csaLow512(ByteVector a, ByteVector b, ByteVector c) {
625 // return _mm512_ternarylogic_epi32(c, b, a, 0x96); // vpternlogd
626 // }
627 //
628 // ByteVector csaHigh512(ByteVector a, ByteVector b, ByteVector c) {
629 // return _mm512_ternarylogic_epi32(c, b, a, 0xe8); // vpternlogd
630 // }
631 }
|
11 * This code is distributed in the hope that it will be useful, but WITHOUT
12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * version 2 for more details (a copy is included in the LICENSE file that
15 * accompanied this code).
16 *
17 * You should have received a copy of the GNU General Public License version
18 * 2 along with this work; if not, write to the Free Software Foundation,
19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20 *
21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22 * or visit www.oracle.com if you need additional information or have
23 * questions.
24 */
25 package benchmark.jdk.incubator.vector;
26
27 import jdk.incubator.vector.ByteVector;
28 import jdk.incubator.vector.ShortVector;
29 import jdk.incubator.vector.IntVector;
30 import jdk.incubator.vector.LongVector;
31 import jdk.incubator.vector.VectorSpecies;
32 import org.openjdk.jmh.annotations.*;
33
34 import java.util.concurrent.TimeUnit;
35
36 import static org.junit.jupiter.api.Assertions.assertEquals;
37
38 /**
39 * Population count algorithms from "Faster Population Counts Using AVX2 Instructions", 2018 by Mula, Kurz, Lemire
40 */
41 @BenchmarkMode(Mode.Throughput)
42 @Warmup(iterations = 3, time = 1)
43 @Measurement(iterations = 5, time = 1)
44 @OutputTimeUnit(TimeUnit.MILLISECONDS)
45 @State(Scope.Benchmark)
46 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
47 public class PopulationCount extends AbstractVectorBenchmark {
48 @Param({"64", "1024", "65536"})
49 int size;
50
51 private long[] data;
333 + 1 * Long.bitCount(ones);
334
335 return total + tail(upper);
336 }
337
338 /* ============================================================================================================== */
339
340 // FIGURE 9. A C function using SSE intrinsics implementing Mula’s algorithm to compute sixteen population counts,
341 // corresponding to sixteen input bytes.
342
343 static final ByteVector MULA128_LOOKUP = (ByteVector)(IntVector.scalars(I128, 0x02_01_01_00, // 0, 1, 1, 2,
344 0x03_02_02_01, // 1, 2, 2, 3,
345 0x03_02_02_01, // 1, 2, 2, 3,
346 0x04_03_03_02 // 2, 3, 3, 4
347 ).reinterpret(B128));
348
349 ByteVector popcntB128(ByteVector v) {
350 var low_mask = ByteVector.broadcast(B128, (byte)0x0f);
351
352 var lo = v .and(low_mask);
353 var hi = v.shiftRight(4).and(low_mask);
354
355 var cnt1 = MULA128_LOOKUP.rearrange(lo.toShuffle());
356 var cnt2 = MULA128_LOOKUP.rearrange(hi.toShuffle());
357
358 return cnt1.add(cnt2);
359 }
360
361 @Benchmark
362 public long Mula128() {
363 var acc = LongVector.zero(L128); // IntVector
364 int step = 32; // % B128.length() == 0!
365 int upper = data.length - (data.length % step);
366 for (int i = 0; i < upper; i += step) {
367 var bacc = ByteVector.zero(B128);
368 for (int j = 0; j < step; j += L128.length()) {
369 var v1 = LongVector.fromArray(L128, data, i + j);
370 var v2 = (ByteVector)v1.reinterpret(B128);
371 var v3 = popcntB128(v2);
372 bacc = bacc.add(v3);
373 }
374 acc = acc.add(sumUnsignedBytes(bacc));
375 }
376 var r = acc.addLanes() + tail(upper);
377 return r;
378 }
379
380 /* ============================================================================================================== */
381
382 // FIGURE 10. A C function using AVX2 intrinsics implementing Mula’s algorithm to compute the four population counts
383 // of the four 64-bit words in a 256-bit vector. The 32 B output vector should be interpreted as four separate
384 // 64-bit counts that need to be summed to obtain the final population count.
385
386 static final ByteVector MULA256_LOOKUP =
387 (ByteVector)(join(I128, I256, (IntVector)(MULA128_LOOKUP.reinterpret(I128)), (IntVector)(MULA128_LOOKUP.reinterpret(I128))).reinterpret(B256));
388
389 ByteVector popcntB256(ByteVector v) {
390 var low_mask = ByteVector.broadcast(B256, (byte)0x0F);
391
392 var lo = v .and(low_mask);
393 var hi = v.shiftRight(4).and(low_mask);
394
395 var cnt1 = MULA256_LOOKUP.rearrange(lo.toShuffle());
396 var cnt2 = MULA256_LOOKUP.rearrange(hi.toShuffle());
397 var cnt = cnt1.add(cnt2);
398
399 return cnt;
400 }
401
402 // Horizontally sum each consecutive 8 differences to produce four unsigned 16-bit integers,
403 // and pack these unsigned 16-bit integers in the low 16 bits of 64-bit elements in dst:
404 // _mm256_sad_epu8(total, _mm256_setzero_si256())
405 LongVector sumUnsignedBytes(ByteVector vb) {
406 return sumUnsignedBytesShapes(vb);
407 // return sumUnsignedBytesShifts(vb);
408 }
409
410 LongVector sumUnsignedBytesShapes(ByteVector vb) {
411 VectorSpecies<Short> shortSpecies = VectorSpecies.of(short.class, vb.shape());
412 VectorSpecies<Integer> intSpecies = VectorSpecies.of(int.class, vb.shape());
413 VectorSpecies<Long> longSpecies = VectorSpecies.of(long.class, vb.shape());
414
415 var low_short_mask = ShortVector.broadcast(shortSpecies, (short) 0xFF);
416 var low_int_mask = IntVector.broadcast(intSpecies, 0xFFFF);
417 var low_long_mask = LongVector.broadcast(longSpecies, 0xFFFFFFFFL);
418
419 var vs = (ShortVector)vb.reinterpret(shortSpecies); // 16-bit
420 var vs0 = vs.and(low_short_mask);
421 var vs1 = vs.shiftRight(8).and(low_short_mask);
422 var vs01 = vs0.add(vs1);
423
424 var vi = (IntVector)vs01.reinterpret(intSpecies); // 32-bit
425 var vi0 = vi.and(low_int_mask);
426 var vi1 = vi.shiftRight(16).and(low_int_mask);
427 var vi01 = vi0.add(vi1);
428
429 var vl = (LongVector)vi01.reinterpret(longSpecies); // 64-bit
430 var vl0 = vl.and(low_long_mask);
431 var vl1 = vl.shiftRight(32).and(low_long_mask);
432 var vl01 = vl0.add(vl1);
433
434 return vl01;
435 }
436
437 LongVector sumUnsignedBytesShifts(ByteVector vb) {
438 VectorSpecies<Long> to = VectorSpecies.of(long.class, vb.shape());
439
440 var low_mask = LongVector.broadcast(to, 0xFF);
441
442 var vl = (LongVector)vb.reinterpret(to);
443
444 var v0 = vl .and(low_mask); // 8-bit
445 var v1 = vl.shiftRight( 8).and(low_mask); // 8-bit
446 var v2 = vl.shiftRight(16).and(low_mask); // 8-bit
447 var v3 = vl.shiftRight(24).and(low_mask); // 8-bit
448 var v4 = vl.shiftRight(32).and(low_mask); // 8-bit
449 var v5 = vl.shiftRight(40).and(low_mask); // 8-bit
450 var v6 = vl.shiftRight(48).and(low_mask); // 8-bit
451 var v7 = vl.shiftRight(56).and(low_mask); // 8-bit
452
453 var v01 = v0.add(v1);
454 var v23 = v2.add(v3);
455 var v45 = v4.add(v5);
456 var v67 = v6.add(v7);
457
458 var v03 = v01.add(v23);
459 var v47 = v45.add(v67);
460
461 var sum = v03.add(v47); // 64-bit
462 return sum;
463 }
464
465 @Benchmark
466 public long Mula256() {
467 var acc = LongVector.zero(L256);
468 int step = 32; // % B256.length() == 0!
469 int upper = data.length - (data.length % step);
470 for (int i = 0; i < upper; i += step) {
471 var bacc = ByteVector.zero(B256);
472 for (int j = 0; j < step; j += L256.length()) {
473 var v1 = LongVector.fromArray(L256, data, i + j);
474 var v2 = popcntB256((ByteVector)(v1.reinterpret(B256)));
475 bacc = bacc.add(v2);
476 }
477 acc = acc.add(sumUnsignedBytes(bacc));
478 }
479 return acc.addLanes() + tail(upper);
480 }
481
482
483 /* ============================================================================================================== */
484
485 // FIGURE 11. A C function using AVX2 intrinsics implementing a bitwise parallel carry-save adder (CSA).
486
487 LongVector csaLow(LongVector a, LongVector b, LongVector c) {
488 var u = a.xor(b);
489 var r = u.xor(c);
490 return r;
491 }
492
493 LongVector csaHigh(LongVector a, LongVector b, LongVector c) {
494 var u = a.xor(b);
495 var ab = a.and(b);
496 var uc = u.and(c);
497 var r = ab.or(uc); // (a & b) | ((a ^ b) & c)
498 return r;
499 }
597
598 // CSA(&eightsB, &fours, fours, foursA, foursB);
599 eightsB = csaHigh(fours, foursA, foursB);
600 fours = csaLow(fours, foursA, foursB);
601
602 // ====================================
603
604 // CSA(&sixteens, &eights, eights, eightsA, eightsB);
605 sixteens = csaHigh(eights, eightsA, eightsB);
606 eights = csaLow(eights, eightsA, eightsB);
607
608 vtotal = vtotal.add(popcntL256(sixteens));
609 }
610
611 vtotal = vtotal.mul(16); // << 4
612 vtotal = vtotal.add(popcntL256(eights).mul(8)); // << 3
613 vtotal = vtotal.add(popcntL256(fours).mul(4)); // << 2
614 vtotal = vtotal.add(popcntL256(twos).mul(2)); // << 1
615 vtotal = vtotal.add(popcntL256(ones)); // << 0
616
617 var total = vtotal.addLanes();
618
619 return total + tail(upper);
620 }
621
622 /* ============================================================================================================== */
623
624 // ByteVector csaLow512(ByteVector a, ByteVector b, ByteVector c) {
625 // return _mm512_ternarylogic_epi32(c, b, a, 0x96); // vpternlogd
626 // }
627 //
628 // ByteVector csaHigh512(ByteVector a, ByteVector b, ByteVector c) {
629 // return _mm512_ternarylogic_epi32(c, b, a, 0xe8); // vpternlogd
630 // }
631 }
|