< prev index next >

test/jdk/jdk/incubator/vector/benchmark/src/main/java/benchmark/jdk/incubator/vector/PopulationCount.java

Print this page
rev 55894 : 8222897: [vector] Renaming of shift, rotate operations. Few other api changes.
Summary: Renaming of shift, rotate operations. Few other api changes.
Reviewed-by: jrose, briangoetz


  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have
  23  * questions.
  24  */
  25 package benchmark.jdk.incubator.vector;
  26 
  27 import jdk.incubator.vector.ByteVector;
  28 import jdk.incubator.vector.ShortVector;
  29 import jdk.incubator.vector.IntVector;
  30 import jdk.incubator.vector.LongVector;
  31 import jdk.incubator.vector.Vector.Species;
  32 import org.openjdk.jmh.annotations.*;
  33 
  34 import java.util.concurrent.TimeUnit;
  35 
  36 import static org.junit.jupiter.api.Assertions.assertEquals;
  37 
  38 /**
  39  * Population count algorithms from "Faster Population Counts Using AVX2 Instructions", 2018 by Mula, Kurz, Lemire
  40  */
  41 @BenchmarkMode(Mode.Throughput)
  42 @Warmup(iterations = 3, time = 1)
  43 @Measurement(iterations = 5, time = 1)
  44 @OutputTimeUnit(TimeUnit.MILLISECONDS)
  45 @State(Scope.Benchmark)
  46 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  47 public class PopulationCount extends AbstractVectorBenchmark {
  48     @Param({"64", "1024", "65536"})
  49     int size;
  50 
  51     private long[] data;


 333                 + 1 * Long.bitCount(ones);
 334 
 335         return total + tail(upper);
 336     }
 337 
 338     /* ============================================================================================================== */
 339 
 340     // FIGURE 9. A C function using SSE intrinsics implementing Mula’s algorithm to compute sixteen population counts,
 341     // corresponding to sixteen input bytes.
 342 
 343     static final ByteVector MULA128_LOOKUP = (ByteVector)(IntVector.scalars(I128, 0x02_01_01_00, // 0, 1, 1, 2,
 344                                                           0x03_02_02_01, // 1, 2, 2, 3,
 345                                                           0x03_02_02_01, // 1, 2, 2, 3,
 346                                                           0x04_03_03_02  // 2, 3, 3, 4
 347                                                           ).reinterpret(B128));
 348 
 349     ByteVector popcntB128(ByteVector v) {
 350         var low_mask = ByteVector.broadcast(B128, (byte)0x0f);
 351 
 352         var lo = v          .and(low_mask);
 353         var hi = v.shiftR(4).and(low_mask);
 354 
 355         var cnt1 = MULA128_LOOKUP.rearrange(lo.toShuffle());
 356         var cnt2 = MULA128_LOOKUP.rearrange(hi.toShuffle());
 357 
 358         return cnt1.add(cnt2);
 359     }
 360 
 361     @Benchmark
 362     public long Mula128() {
 363         var acc = LongVector.zero(L128); // IntVector
 364         int step = 32; // % B128.length() == 0!
 365         int upper = data.length - (data.length % step);
 366         for (int i = 0; i < upper; i += step) {
 367             var bacc = ByteVector.zero(B128);
 368             for (int j = 0; j < step; j += L128.length()) {
 369                 var v1 = LongVector.fromArray(L128, data, i + j);
 370                 var v2 = (ByteVector)v1.reinterpret(B128);
 371                 var v3 = popcntB128(v2);
 372                 bacc = bacc.add(v3);
 373             }
 374             acc = acc.add(sumUnsignedBytes(bacc));
 375         }
 376         var r = acc.addAll() + tail(upper);
 377         return r;
 378     }
 379 
 380     /* ============================================================================================================== */
 381 
 382     // FIGURE 10. A C function using AVX2 intrinsics implementing Mula’s algorithm to compute the four population counts
 383     // of the four 64-bit words in a 256-bit vector. The 32 B output vector should be interpreted as four separate
 384     // 64-bit counts that need to be summed to obtain the final population count.
 385 
 386     static final ByteVector MULA256_LOOKUP = 
 387             (ByteVector)(join(I128, I256, (IntVector)(MULA128_LOOKUP.reinterpret(I128)), (IntVector)(MULA128_LOOKUP.reinterpret(I128))).reinterpret(B256));
 388 
 389     ByteVector popcntB256(ByteVector v) {
 390         var low_mask = ByteVector.broadcast(B256, (byte)0x0F);
 391 
 392         var lo = v          .and(low_mask);
 393         var hi = v.shiftR(4).and(low_mask);
 394 
 395         var cnt1 = MULA256_LOOKUP.rearrange(lo.toShuffle());
 396         var cnt2 = MULA256_LOOKUP.rearrange(hi.toShuffle());
 397         var cnt = cnt1.add(cnt2);
 398 
 399         return cnt;
 400     }
 401 
 402     // Horizontally sum each consecutive 8 differences to produce four unsigned 16-bit integers,
 403     // and pack these unsigned 16-bit integers in the low 16 bits of 64-bit elements in dst:
 404     //   _mm256_sad_epu8(total, _mm256_setzero_si256())
 405     LongVector sumUnsignedBytes(ByteVector vb) {
 406         return sumUnsignedBytesShapes(vb);
 407 //        return sumUnsignedBytesShifts(vb);
 408     }
 409 
 410     LongVector sumUnsignedBytesShapes(ByteVector vb) {
 411         Species<Short> shortSpecies = Species.of(short.class, vb.shape());
 412         Species<Integer> intSpecies = Species.of(int.class, vb.shape());
 413         Species<Long> longSpecies = Species.of(long.class, vb.shape());
 414 
 415         var low_short_mask = ShortVector.broadcast(shortSpecies, (short) 0xFF);
 416         var low_int_mask = IntVector.broadcast(intSpecies, 0xFFFF);
 417         var low_long_mask = LongVector.broadcast(longSpecies, 0xFFFFFFFFL);
 418 
 419         var vs = (ShortVector)vb.reinterpret(shortSpecies); // 16-bit
 420         var vs0 = vs.and(low_short_mask);
 421         var vs1 = vs.shiftR(8).and(low_short_mask);
 422         var vs01 = vs0.add(vs1);
 423 
 424         var vi = (IntVector)vs01.reinterpret(intSpecies); // 32-bit
 425         var vi0 = vi.and(low_int_mask);
 426         var vi1 = vi.shiftR(16).and(low_int_mask);
 427         var vi01 = vi0.add(vi1);
 428 
 429         var vl = (LongVector)vi01.reinterpret(longSpecies); // 64-bit
 430         var vl0 = vl.and(low_long_mask);
 431         var vl1 = vl.shiftR(32).and(low_long_mask);
 432         var vl01 = vl0.add(vl1);
 433 
 434         return vl01;
 435     }
 436 
 437     LongVector sumUnsignedBytesShifts(ByteVector vb) {
 438         Species<Long> to = Species.of(long.class, vb.shape());
 439 
 440         var low_mask = LongVector.broadcast(to, 0xFF);
 441 
 442         var vl = (LongVector)vb.reinterpret(to);
 443 
 444         var v0 = vl           .and(low_mask); // 8-bit
 445         var v1 = vl.shiftR( 8).and(low_mask); // 8-bit
 446         var v2 = vl.shiftR(16).and(low_mask); // 8-bit
 447         var v3 = vl.shiftR(24).and(low_mask); // 8-bit
 448         var v4 = vl.shiftR(32).and(low_mask); // 8-bit
 449         var v5 = vl.shiftR(40).and(low_mask); // 8-bit
 450         var v6 = vl.shiftR(48).and(low_mask); // 8-bit
 451         var v7 = vl.shiftR(56).and(low_mask); // 8-bit
 452 
 453         var v01 = v0.add(v1);
 454         var v23 = v2.add(v3);
 455         var v45 = v4.add(v5);
 456         var v67 = v6.add(v7);
 457 
 458         var v03 = v01.add(v23);
 459         var v47 = v45.add(v67);
 460 
 461         var sum = v03.add(v47); // 64-bit
 462         return sum;
 463     }
 464 
 465     @Benchmark
 466     public long Mula256() {
 467         var acc = LongVector.zero(L256);
 468         int step = 32; // % B256.length() == 0!
 469         int upper = data.length - (data.length % step);
 470         for (int i = 0; i < upper; i += step) {
 471             var bacc = ByteVector.zero(B256);
 472             for (int j = 0; j < step; j += L256.length()) {
 473                 var v1 = LongVector.fromArray(L256, data, i + j);
 474                 var v2 = popcntB256((ByteVector)(v1.reinterpret(B256)));
 475                 bacc = bacc.add(v2);
 476             }
 477             acc = acc.add(sumUnsignedBytes(bacc));
 478         }
 479         return acc.addAll() + tail(upper);
 480     }
 481 
 482 
 483     /* ============================================================================================================== */
 484 
 485     // FIGURE 11. A C function using AVX2 intrinsics implementing a bitwise parallel carry-save adder (CSA).
 486 
 487     LongVector csaLow(LongVector a, LongVector b, LongVector c) {
 488         var u = a.xor(b);
 489         var r = u.xor(c);
 490         return r;
 491     }
 492 
 493     LongVector csaHigh(LongVector a, LongVector b, LongVector c) {
 494         var u  = a.xor(b);
 495         var ab = a.and(b);
 496         var uc = u.and(c);
 497         var r  = ab.or(uc); // (a & b) | ((a ^ b) & c)
 498         return r;
 499     }


 597 
 598             // CSA(&eightsB, &fours, fours, foursA, foursB);
 599             eightsB = csaHigh(fours, foursA, foursB);
 600             fours   = csaLow(fours, foursA, foursB);
 601 
 602             // ====================================
 603 
 604             // CSA(&sixteens, &eights, eights, eightsA, eightsB);
 605             sixteens = csaHigh(eights, eightsA, eightsB);
 606             eights   = csaLow(eights, eightsA, eightsB);
 607 
 608             vtotal = vtotal.add(popcntL256(sixteens));
 609         }
 610 
 611         vtotal = vtotal.mul(16);                       // << 4
 612         vtotal = vtotal.add(popcntL256(eights).mul(8)); // << 3
 613         vtotal = vtotal.add(popcntL256(fours).mul(4));  // << 2
 614         vtotal = vtotal.add(popcntL256(twos).mul(2));   // << 1
 615         vtotal = vtotal.add(popcntL256(ones));          // << 0
 616 
 617         var total = vtotal.addAll();
 618 
 619         return total + tail(upper);
 620     }
 621 
 622     /* ============================================================================================================== */
 623 
 624 //    ByteVector csaLow512(ByteVector a, ByteVector b, ByteVector c) {
 625 //        return _mm512_ternarylogic_epi32(c, b, a, 0x96); // vpternlogd
 626 //    }
 627 //
 628 //    ByteVector csaHigh512(ByteVector a, ByteVector b, ByteVector c) {
 629 //        return _mm512_ternarylogic_epi32(c, b, a, 0xe8); // vpternlogd
 630 //    }
 631 }


  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have
  23  * questions.
  24  */
  25 package benchmark.jdk.incubator.vector;
  26 
  27 import jdk.incubator.vector.ByteVector;
  28 import jdk.incubator.vector.ShortVector;
  29 import jdk.incubator.vector.IntVector;
  30 import jdk.incubator.vector.LongVector;
  31 import jdk.incubator.vector.VectorSpecies;
  32 import org.openjdk.jmh.annotations.*;
  33 
  34 import java.util.concurrent.TimeUnit;
  35 
  36 import static org.junit.jupiter.api.Assertions.assertEquals;
  37 
  38 /**
  39  * Population count algorithms from "Faster Population Counts Using AVX2 Instructions", 2018 by Mula, Kurz, Lemire
  40  */
  41 @BenchmarkMode(Mode.Throughput)
  42 @Warmup(iterations = 3, time = 1)
  43 @Measurement(iterations = 5, time = 1)
  44 @OutputTimeUnit(TimeUnit.MILLISECONDS)
  45 @State(Scope.Benchmark)
  46 @Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
  47 public class PopulationCount extends AbstractVectorBenchmark {
  48     @Param({"64", "1024", "65536"})
  49     int size;
  50 
  51     private long[] data;


 333                 + 1 * Long.bitCount(ones);
 334 
 335         return total + tail(upper);
 336     }
 337 
 338     /* ============================================================================================================== */
 339 
 340     // FIGURE 9. A C function using SSE intrinsics implementing Mula’s algorithm to compute sixteen population counts,
 341     // corresponding to sixteen input bytes.
 342 
 343     static final ByteVector MULA128_LOOKUP = (ByteVector)(IntVector.scalars(I128, 0x02_01_01_00, // 0, 1, 1, 2,
 344                                                           0x03_02_02_01, // 1, 2, 2, 3,
 345                                                           0x03_02_02_01, // 1, 2, 2, 3,
 346                                                           0x04_03_03_02  // 2, 3, 3, 4
 347                                                           ).reinterpret(B128));
 348 
 349     ByteVector popcntB128(ByteVector v) {
 350         var low_mask = ByteVector.broadcast(B128, (byte)0x0f);
 351 
 352         var lo = v          .and(low_mask);
 353         var hi = v.shiftRight(4).and(low_mask);
 354 
 355         var cnt1 = MULA128_LOOKUP.rearrange(lo.toShuffle());
 356         var cnt2 = MULA128_LOOKUP.rearrange(hi.toShuffle());
 357 
 358         return cnt1.add(cnt2);
 359     }
 360 
 361     @Benchmark
 362     public long Mula128() {
 363         var acc = LongVector.zero(L128); // IntVector
 364         int step = 32; // % B128.length() == 0!
 365         int upper = data.length - (data.length % step);
 366         for (int i = 0; i < upper; i += step) {
 367             var bacc = ByteVector.zero(B128);
 368             for (int j = 0; j < step; j += L128.length()) {
 369                 var v1 = LongVector.fromArray(L128, data, i + j);
 370                 var v2 = (ByteVector)v1.reinterpret(B128);
 371                 var v3 = popcntB128(v2);
 372                 bacc = bacc.add(v3);
 373             }
 374             acc = acc.add(sumUnsignedBytes(bacc));
 375         }
 376         var r = acc.addLanes() + tail(upper);
 377         return r;
 378     }
 379 
 380     /* ============================================================================================================== */
 381 
 382     // FIGURE 10. A C function using AVX2 intrinsics implementing Mula’s algorithm to compute the four population counts
 383     // of the four 64-bit words in a 256-bit vector. The 32 B output vector should be interpreted as four separate
 384     // 64-bit counts that need to be summed to obtain the final population count.
 385 
 386     static final ByteVector MULA256_LOOKUP = 
 387             (ByteVector)(join(I128, I256, (IntVector)(MULA128_LOOKUP.reinterpret(I128)), (IntVector)(MULA128_LOOKUP.reinterpret(I128))).reinterpret(B256));
 388 
 389     ByteVector popcntB256(ByteVector v) {
 390         var low_mask = ByteVector.broadcast(B256, (byte)0x0F);
 391 
 392         var lo = v          .and(low_mask);
 393         var hi = v.shiftRight(4).and(low_mask);
 394 
 395         var cnt1 = MULA256_LOOKUP.rearrange(lo.toShuffle());
 396         var cnt2 = MULA256_LOOKUP.rearrange(hi.toShuffle());
 397         var cnt = cnt1.add(cnt2);
 398 
 399         return cnt;
 400     }
 401 
 402     // Horizontally sum each consecutive 8 differences to produce four unsigned 16-bit integers,
 403     // and pack these unsigned 16-bit integers in the low 16 bits of 64-bit elements in dst:
 404     //   _mm256_sad_epu8(total, _mm256_setzero_si256())
 405     LongVector sumUnsignedBytes(ByteVector vb) {
 406         return sumUnsignedBytesShapes(vb);
 407 //        return sumUnsignedBytesShifts(vb);
 408     }
 409 
 410     LongVector sumUnsignedBytesShapes(ByteVector vb) {
 411         VectorSpecies<Short> shortSpecies = VectorSpecies.of(short.class, vb.shape());
 412         VectorSpecies<Integer> intSpecies = VectorSpecies.of(int.class, vb.shape());
 413         VectorSpecies<Long> longSpecies = VectorSpecies.of(long.class, vb.shape());
 414 
 415         var low_short_mask = ShortVector.broadcast(shortSpecies, (short) 0xFF);
 416         var low_int_mask = IntVector.broadcast(intSpecies, 0xFFFF);
 417         var low_long_mask = LongVector.broadcast(longSpecies, 0xFFFFFFFFL);
 418 
 419         var vs = (ShortVector)vb.reinterpret(shortSpecies); // 16-bit
 420         var vs0 = vs.and(low_short_mask);
 421         var vs1 = vs.shiftRight(8).and(low_short_mask);
 422         var vs01 = vs0.add(vs1);
 423 
 424         var vi = (IntVector)vs01.reinterpret(intSpecies); // 32-bit
 425         var vi0 = vi.and(low_int_mask);
 426         var vi1 = vi.shiftRight(16).and(low_int_mask);
 427         var vi01 = vi0.add(vi1);
 428 
 429         var vl = (LongVector)vi01.reinterpret(longSpecies); // 64-bit
 430         var vl0 = vl.and(low_long_mask);
 431         var vl1 = vl.shiftRight(32).and(low_long_mask);
 432         var vl01 = vl0.add(vl1);
 433 
 434         return vl01;
 435     }
 436 
 437     LongVector sumUnsignedBytesShifts(ByteVector vb) {
 438         VectorSpecies<Long> to = VectorSpecies.of(long.class, vb.shape());
 439 
 440         var low_mask = LongVector.broadcast(to, 0xFF);
 441 
 442         var vl = (LongVector)vb.reinterpret(to);
 443 
 444         var v0 = vl           .and(low_mask); // 8-bit
 445         var v1 = vl.shiftRight( 8).and(low_mask); // 8-bit
 446         var v2 = vl.shiftRight(16).and(low_mask); // 8-bit
 447         var v3 = vl.shiftRight(24).and(low_mask); // 8-bit
 448         var v4 = vl.shiftRight(32).and(low_mask); // 8-bit
 449         var v5 = vl.shiftRight(40).and(low_mask); // 8-bit
 450         var v6 = vl.shiftRight(48).and(low_mask); // 8-bit
 451         var v7 = vl.shiftRight(56).and(low_mask); // 8-bit
 452 
 453         var v01 = v0.add(v1);
 454         var v23 = v2.add(v3);
 455         var v45 = v4.add(v5);
 456         var v67 = v6.add(v7);
 457 
 458         var v03 = v01.add(v23);
 459         var v47 = v45.add(v67);
 460 
 461         var sum = v03.add(v47); // 64-bit
 462         return sum;
 463     }
 464 
 465     @Benchmark
 466     public long Mula256() {
 467         var acc = LongVector.zero(L256);
 468         int step = 32; // % B256.length() == 0!
 469         int upper = data.length - (data.length % step);
 470         for (int i = 0; i < upper; i += step) {
 471             var bacc = ByteVector.zero(B256);
 472             for (int j = 0; j < step; j += L256.length()) {
 473                 var v1 = LongVector.fromArray(L256, data, i + j);
 474                 var v2 = popcntB256((ByteVector)(v1.reinterpret(B256)));
 475                 bacc = bacc.add(v2);
 476             }
 477             acc = acc.add(sumUnsignedBytes(bacc));
 478         }
 479         return acc.addLanes() + tail(upper);
 480     }
 481 
 482 
 483     /* ============================================================================================================== */
 484 
 485     // FIGURE 11. A C function using AVX2 intrinsics implementing a bitwise parallel carry-save adder (CSA).
 486 
 487     LongVector csaLow(LongVector a, LongVector b, LongVector c) {
 488         var u = a.xor(b);
 489         var r = u.xor(c);
 490         return r;
 491     }
 492 
 493     LongVector csaHigh(LongVector a, LongVector b, LongVector c) {
 494         var u  = a.xor(b);
 495         var ab = a.and(b);
 496         var uc = u.and(c);
 497         var r  = ab.or(uc); // (a & b) | ((a ^ b) & c)
 498         return r;
 499     }


 597 
 598             // CSA(&eightsB, &fours, fours, foursA, foursB);
 599             eightsB = csaHigh(fours, foursA, foursB);
 600             fours   = csaLow(fours, foursA, foursB);
 601 
 602             // ====================================
 603 
 604             // CSA(&sixteens, &eights, eights, eightsA, eightsB);
 605             sixteens = csaHigh(eights, eightsA, eightsB);
 606             eights   = csaLow(eights, eightsA, eightsB);
 607 
 608             vtotal = vtotal.add(popcntL256(sixteens));
 609         }
 610 
 611         vtotal = vtotal.mul(16);                       // << 4
 612         vtotal = vtotal.add(popcntL256(eights).mul(8)); // << 3
 613         vtotal = vtotal.add(popcntL256(fours).mul(4));  // << 2
 614         vtotal = vtotal.add(popcntL256(twos).mul(2));   // << 1
 615         vtotal = vtotal.add(popcntL256(ones));          // << 0
 616 
 617         var total = vtotal.addLanes();
 618 
 619         return total + tail(upper);
 620     }
 621 
 622     /* ============================================================================================================== */
 623 
 624 //    ByteVector csaLow512(ByteVector a, ByteVector b, ByteVector c) {
 625 //        return _mm512_ternarylogic_epi32(c, b, a, 0x96); // vpternlogd
 626 //    }
 627 //
 628 //    ByteVector csaHigh512(ByteVector a, ByteVector b, ByteVector c) {
 629 //        return _mm512_ternarylogic_epi32(c, b, a, 0xe8); // vpternlogd
 630 //    }
 631 }
< prev index next >