1 /*
   2  * Copyright (c) 2010, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  *
  27  * @summary converted from VM Testbase jit/escape/LockElision/MatMul.
  28  * VM Testbase keywords: [jit, quick]
  29  * VM Testbase readme:
  30  * DESCRIPTION
 *     The test multiplies 2 matrices, first by directly calculating the matrix
 *     product elements, and second by calculating them in parallel in different
 *     threads. The results are then compared.
 *     In addition to the required locks, the test introduces locks on local variables
 *     or on variables that do not escape from the executing thread, and nests them in
 *     many ways. With a buggy compiler, lock elimination may also remove some code
 *     required by the calculation, or the code may be over-optimized in some other
 *     way, causing a difference in the execution results.
 *     The test takes the parameters -dim, which specifies the dimension of the matrices,
 *     and -threadCount, which specifies the number of worker threads.
  40  *
  41  * @library /vmTestbase
  42  *          /test/lib
  43  * @run driver jdk.test.lib.FileInstaller . .
  44  * @build jit.escape.LockElision.MatMul.MatMul
  45  * @run driver ExecDriver --java jit.escape.LockElision.MatMul.MatMul -dim 30 -threadCount 10
  46  */
  47 
  48 package jit.escape.LockElision.MatMul;
  49 
  50 import java.util.*;
  51 import java.util.concurrent.CountDownLatch;
  52 import java.util.concurrent.ExecutorService;
  53 import java.util.concurrent.Executors;
  54 
  55 import nsk.share.Consts;
  56 import nsk.share.Log;
  57 import nsk.share.Pair;
  58 
  59 import nsk.share.test.StressOptions;
  60 import vm.share.options.Option;
  61 import vm.share.options.OptionSupport;
  62 import vm.share.options.Options;
  63 
  64 
  65 class MatMul {
  66 
  67     @Option(name = "dim", description = "dimension of matrices")
  68     int dim;
  69 
  70     @Option(name = "verbose", default_value = "false",
  71             description = "verbose mode")
  72     boolean verbose;
  73 
  74     @Option(name = "threadCount", description = "thread count")
  75     int threadCount;
  76 
  77     @Options
  78     StressOptions stressOptions = new StressOptions();
  79 
  80     private Log log;
  81 
  82     public static void main(String[] args) {
  83         MatMul test = new MatMul();
  84         OptionSupport.setup(test, args);
  85         System.exit(Consts.JCK_STATUS_BASE + test.run());
  86     }
  87 
  88     public int run() {
  89         log = new Log(System.out, verbose);
  90         log.display("Parallel matrix multiplication test");
  91 
  92         Matrix a = Matrix.randomMatrix(dim);
  93         Matrix b = Matrix.randomMatrix(dim);
  94         long t1, t2;
  95 
  96         t1 = System.currentTimeMillis();
  97         Matrix serialResult = serialMul(a, b);
  98         t2 = System.currentTimeMillis();
  99         log.display("serial time: " + (t2 - t1) + "ms");
 100 
 101         try {
 102             t1 = System.currentTimeMillis();
 103             Matrix parallelResult = parallelMul(a, b,
 104                     threadCount * stressOptions.getThreadsFactor());
 105             t2 = System.currentTimeMillis();
 106             log.display("parallel time: " + (t2 - t1) + "ms");
 107 
 108             if (!serialResult.equals(parallelResult)) {
 109                 log.complain("a = \n" + a);
 110                 log.complain("b = \n" + b);
 111 
 112                 log.complain("serial: a * b = \n" + serialResult);
 113                 log.complain("serial: a * b = \n" + parallelResult);
 114                 return Consts.TEST_FAILED;
 115             }
 116             return Consts.TEST_PASSED;
 117 
 118         } catch (CounterIncorrectStateException e) {
 119             log.complain("incorrect state of counter " + e.counter.name);
 120             log.complain("expected = " + e.counter.expected);
 121             log.complain("actual " + e.counter.state());
 122             return Consts.TEST_FAILED;
 123         }
 124     }
 125 
 126     public static int convolution(Seq<Integer> one, Seq<Integer> two) {
 127         int res = 0;
 128         int upperBound = Math.min(one.size(), two.size());
 129         for (int i = 0; i < upperBound; i++) {
 130             res += one.get(i) * two.get(i);
 131         }
 132         return res;
 133     }
 134 
 135     /**
 136      * calculate chunked convolutuion of two sequences
 137      * <p/>
 138      * This special version of this method:
 139      * <pre>{@code
 140      * public static int chunkedConvolution(Seq<Integer> one, Seq<Integer> two, int from, int to) {
 141      * int res = 0;
 142      *  int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1);
 143      *  for (int i = from; i < upperBound; i++) {
 144      *    res += one.get(i) * two.get(i);
 145      *   }
 146      *  return res;
 147      * }}</pre>
 148      * <p/>
 149      * that tries to fool the Lock Elision optimization:
 150      * Most lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed.
 151      * But several synchronized blocks (all that protected by Counter instances) are really necessary, and removing them we obtain
 152      * an incorrect result.
 153      *
 154      * @param one
 155      * @param two
 156      * @param from     - lower bound of sum
 157      * @param to       - upper bound of sum
 158      * @param local    - reference ThreadLocal that will be used for calculations
 159      * @param bCounter - Counter instance, need to perfom checks
 160      */
 161     public static int chunkedConvolutionWithDummy(Seq<Integer> one,
 162             Seq<Integer> two, int from, int to, ThreadLocals local,
 163             Counter bCounter) {
 164         ThreadLocals conv_local1 = new ThreadLocals(local, "conv_local1");
 165         ThreadLocals conv_local2 = new ThreadLocals(conv_local1, "conv_local2");
 166         ThreadLocals conv_local3 = new ThreadLocals(null, "conv_local3");
 167         int res = 0;
 168         synchronized (local) {
 169             local.updateHash();
 170             int upperBound = 0;
 171             synchronized (conv_local1) {
 172                 upperBound = local.min(one.size(), two.size());
 173                 synchronized (two) {
 174                     //int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1) :
 175                     upperBound = conv_local1.min(upperBound, to + 1);
 176                     synchronized (bCounter) {
 177                         bCounter.inc();
 178                     }
 179                 }
 180                 for (int i = from; i < upperBound; i++) {
 181                     synchronized (conv_local2) {
 182                         conv_local1.updateHash();
 183                         int prod = 0;
 184                         synchronized (one) {
 185                             int t = conv_local2.mult(one.get(i), two.get(i));
 186                             synchronized (conv_local3) {
 187                                 prod = t;
 188 
 189                             }
 190                             //res += one.get(i) * two.get(i)
 191                             res = conv_local3.sum(res, prod);
 192                         }
 193                     }
 194                 }
 195             }
 196             return res;
 197         }
 198     }
 199 
 200     public boolean productCheck(Matrix a, Matrix b) {
 201         if (a == null || b == null) {
 202             log.complain("null matrix!");
 203             return false;
 204         }
 205 
 206         if (a.dim != b.dim) {
 207             log.complain("matrices dimension are differs");
 208             return false;
 209         }
 210         return true;
 211     }
 212 
 213     public Matrix serialMul(Matrix a, Matrix b) {
 214         if (!productCheck(a, b)) {
 215             throw new IllegalArgumentException();
 216         }
 217 
 218         Matrix result = Matrix.zeroMatrix(a.dim);
 219         for (int i = 0; i < a.dim; i++) {
 220             for (int j = 0; j < a.dim; j++) {
 221                 result.set(i, j, convolution(a.row(i), b.column(j)));
 222             }
 223         }
 224         return result;
 225     }
 226 
 227 
 228     /**
 229      * Parallel multiplication of matrices.
 230      * <p/>
 231      * This special version of this method:
 232      * <pre>{@code
 233      *  public Matrix parallelMul1(final Matrix a, final Matrix b, int threadCount) {
 234      *   if (!productCheck(a, b)) {
 235      *       throw new IllegalArgumentException();
 236      *   }
 237      *   final int dim = a.dim;
 238      *   final Matrix result = Matrix.zeroMatrix(dim);
 239      * <p/>
 240      *   ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
 241      *   final CountDownLatch latch = new CountDownLatch(threadCount);
 242      *   List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1), threadCount);
 243      *   for (final Pair<Integer, Integer> part : parts) {
 244      *       threadPool.submit(new Runnable() {
 245      *           @Override
 246      *           public void run() {
 247      *               for (int i = 0; i < dim; i++) {
 248      *                   for (int j = 0; j < dim; j++) {
 249      *                       synchronized (result) {
 250      *                           int from = part.first;
 251      *                           int to = part.second;
 252      *                           result.add(i, j, chunkedConvolution(a.row(i), b.column(j), from, to));
 253      *                       }
 254      *                   }
 255      *               }
 256      *               latch.countDown();
 257      *           }
 258      *       });
 259      *   }
 260      * <p/>
 261      *   try {
 262      *       latch.await();
 263      *   } catch (InterruptedException e) {
 264      *       e.printStackTrace();
 265      *   }
 266      *   threadPool.shutdown();
 267      *   return result;
 268      * }}</pre>
 269      * Lines marked with NOP comments need to fool the Lock Elision optimization:
 270      * All lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed.
 271      * But several synchronized blocks (that are nested in dummy blocks) are really necessary, and removing them we obtain
 272      * an incorrect result.
 273      *
 274      * @param a           first operand
 275      * @param b           second operand
 276      * @param threadCount number of threads that will be used for calculations
 277      * @return product of matrices a and b
 278      */
 279     public Matrix parallelMul(final Matrix a, final Matrix b, int threadCount)
 280             throws CounterIncorrectStateException {
 281         if (!productCheck(a, b)) {
 282             throw new IllegalArgumentException();
 283         }
 284         final int dim = a.dim;
 285         final Matrix result = Matrix.zeroMatrix(dim);
 286 
 287         ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
 288         final CountDownLatch latch = new CountDownLatch(threadCount);
 289         List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1),
 290                 threadCount);
 291 
 292         final Counter lCounter1 = new Counter(threadCount, "lCounter1");
 293         final Counter lCounter2 = new Counter(threadCount, "lCounter2");
 294         final Counter lCounter3 = new Counter(threadCount, "lCounter3");
 295 
 296         final Counter bCounter1 = new Counter(threadCount * dim * dim,
 297                 "bCounter1");
 298         final Counter bCounter2 = new Counter(threadCount * dim * dim,
 299                 "bCounter2");
 300         final Counter bCounter3 = new Counter(threadCount * dim * dim,
 301                 "bCounter3");
 302 
 303         final Counter[] counters = {lCounter1, lCounter2, lCounter3,
 304                 bCounter1, bCounter2, bCounter3};
 305 
 306         final Map<Pair<Integer, Integer>, ThreadLocals> locals1
 307                 = CollectionsUtils.newHashMap();
 308         final Map<Pair<Integer, Integer>, ThreadLocals> locals2
 309                 = CollectionsUtils.newHashMap();
 310         final Map<Pair<Integer, Integer>, ThreadLocals> locals3
 311                 = CollectionsUtils.newHashMap();
 312 
 313         for (final Pair<Integer, Integer> part : parts) {
 314 
 315             ThreadLocals local1 = new ThreadLocals(null,
 316                     "locals1[" + part + "]");
 317             ThreadLocals local2 = new ThreadLocals(local1,
 318                     "locals2[" + part + "]");
 319             ThreadLocals local3 = new ThreadLocals(local2,
 320                     "locals3[" + part + "]");
 321 
 322             locals1.put(part, local1);
 323             locals2.put(part, local2);
 324             locals3.put(part, local3);
 325         }
 326 
 327         for (final Pair<Integer, Integer> part : parts) {
 328             threadPool.submit(new Runnable() {
 329                 @Override
 330                 public void run() {
 331                     ThreadLocals local1 = locals1.get(part);
 332                     ThreadLocals local2 = locals2.get(part);
 333                     ThreadLocals local3 = locals3.get(part);
 334                     ThreadLocals local4 = locals3.get(part);
 335                     synchronized (local1) {
 336                         local1.updateHash();
 337                         synchronized (lCounter1) {
 338                             lCounter1.inc();
 339                         }
 340                         synchronized (lCounter3) {
 341                             synchronized (local2) {
 342                                 local2.updateHash();
 343                                 lCounter3.inc();
 344                             }
 345                         }
 346                         synchronized (new Object()) {
 347                             synchronized (lCounter2) {
 348                                 lCounter2.inc();
 349                             }
 350                             for (int i = 0; i < dim; i++) {
 351                                 for (int j = 0; j < dim; j++) {
 352                                     synchronized (bCounter1) {
 353                                         synchronized (new Object()) {
 354                                             bCounter1.inc();
 355                                         }
 356                                     }
 357                                     synchronized (local3) {
 358                                         local3.updateHash();
 359                                         synchronized (bCounter2) {
 360                                             bCounter2.inc();
 361                                         }
 362                                         synchronized (result) {
 363                                             local1.updateHash();
 364                                             synchronized (local2) {
 365                                                 local2.updateHash();
 366                                                 int from = part.first;
 367                                                 int to = part.second;
 368                                                 result.add(i, j,
 369                                                         chunkedConvolutionWithDummy(
 370                                                                 a.row(i),
 371                                                                 b.column(j),
 372                                                                 from, to,
 373                                                                 local4,
 374                                                                 bCounter3));
 375                                             }
 376                                         }
 377                                     }
 378                                 }
 379                             }
 380                         }
 381                     }
 382                     latch.countDown();
 383                 }
 384             });
 385         }
 386 
 387         try {
 388             latch.await();
 389         } catch (InterruptedException e) {
 390             e.printStackTrace();
 391         }
 392 
 393         threadPool.shutdown();
 394         for (final Pair<Integer, Integer> part : parts) {
 395             log.display(
 396                     "hash for " + part + " = " + locals1.get(part).getHash());
 397         }
 398 
 399 
 400         for (Counter counter : counters) {
 401             if (!counter.check()) {
 402                 throw new CounterIncorrectStateException(counter);
 403             }
 404         }
 405         return result;
 406     }
 407 
 408     /**
 409      * Split interval into parts
 410      *
 411      * @param interval  - pair than encode bounds of interval
 412      * @param partCount - count of parts
 413      * @return list of pairs than encode bounds of parts
 414      */
 415     public static List<Pair<Integer, Integer>> splitInterval(
 416             Pair<Integer, Integer> interval, int partCount) {
 417         if (partCount == 0) {
 418             throw new IllegalArgumentException();
 419         }
 420 
 421         if (partCount == 1) {
 422             return CollectionsUtils.asList(interval);
 423         }
 424 
 425         int intervalSize = interval.second - interval.first + 1;
 426         int partSize = intervalSize / partCount;
 427 
 428         List<Pair<Integer, Integer>> init = splitInterval(
 429                 Pair.of(interval.first, interval.second - partSize),
 430                 partCount - 1);
 431         Pair<Integer, Integer> lastPart = Pair
 432                 .of(interval.second - partSize + 1, interval.second);
 433 
 434         return CollectionsUtils.append(init, lastPart);
 435     }
 436 
 437     public static class Counter {
 438         private int state;
 439 
 440         public final int expected;
 441         public final String name;
 442 
 443         public void inc() {
 444             state++;
 445         }
 446 
 447         public int state() {
 448             return state;
 449         }
 450 
 451         public boolean check() {
 452             return state == expected;
 453         }
 454 
 455         public Counter(int expected, String name) {
 456             this.expected = expected;
 457             this.name = name;
 458         }
 459     }
 460 
 461     private static class CounterIncorrectStateException extends Exception {
 462         public final Counter counter;
 463 
 464         public CounterIncorrectStateException(Counter counter) {
 465             this.counter = counter;
 466         }
 467     }
 468 
 469     private static abstract class Seq<E> implements Iterable<E> {
 470         @Override
 471         public Iterator<E> iterator() {
 472             return new Iterator<E>() {
 473                 private int p = 0;
 474 
 475                 @Override
 476                 public boolean hasNext() {
 477                     return p < size();
 478                 }
 479 
 480                 @Override
 481                 public E next() {
 482                     return get(p++);
 483                 }
 484 
 485                 @Override
 486                 public void remove() {
 487                 }
 488             };
 489         }
 490 
 491         public abstract E get(int i);
 492 
 493         public abstract int size();
 494     }
 495 
 496     private static class CollectionsUtils {
 497 
 498         public static <K, V> Map<K, V> newHashMap() {
 499             return new HashMap<K, V>();
 500         }
 501 
 502         public static <E> List<E> newArrayList() {
 503             return new ArrayList<E>();
 504         }
 505 
 506         public static <E> List<E> newArrayList(Collection<E> collection) {
 507             return new ArrayList<E>(collection);
 508         }
 509 
 510         public static <E> List<E> asList(E e) {
 511             List<E> result = newArrayList();
 512             result.add(e);
 513             return result;
 514         }
 515 
 516         public static <E> List<E> append(List<E> init, E last) {
 517             List<E> result = newArrayList(init);
 518             result.add(last);
 519             return result;
 520         }
 521     }
 522 
 523     private static class Matrix {
 524 
 525         public final int dim;
 526         private int[] coeffs;
 527 
 528         private Matrix(int dim) {
 529             this.dim = dim;
 530             this.coeffs = new int[dim * dim];
 531         }
 532 
 533         public void set(int i, int j, int value) {
 534             coeffs[i * dim + j] = value;
 535         }
 536 
 537         public void add(int i, int j, int value) {
 538             coeffs[i * dim + j] += value;
 539         }
 540 
 541         public int get(int i, int j) {
 542             return coeffs[i * dim + j];
 543         }
 544 
 545         public Seq<Integer> row(final int i) {
 546             return new Seq<Integer>() {
 547                 @Override
 548                 public Integer get(int j) {
 549                     return Matrix.this.get(i, j);
 550                 }
 551 
 552                 @Override
 553                 public int size() {
 554                     return Matrix.this.dim;
 555                 }
 556             };
 557         }
 558 
 559         public Seq<Integer> column(final int j) {
 560             return new Seq<Integer>() {
 561                 @Override
 562                 public Integer get(int i) {
 563                     return Matrix.this.get(i, j);
 564                 }
 565 
 566                 @Override
 567                 public int size() {
 568                     return Matrix.this.dim;
 569                 }
 570             };
 571         }
 572 
 573         @Override
 574         public String toString() {
 575             StringBuilder builder = new StringBuilder();
 576             for (int i = 0; i < dim; i++) {
 577                 for (int j = 0; j < dim; j++) {
 578                     builder.append((j == 0) ? "" : "\t\t");
 579                     builder.append(get(i, j));
 580                 }
 581                 builder.append("\n");
 582             }
 583             return builder.toString();
 584         }
 585 
 586         @Override
 587         public boolean equals(Object other) {
 588             if (!(other instanceof Matrix)) {
 589                 return false;
 590             }
 591 
 592             Matrix b = (Matrix) other;
 593             if (b.dim != this.dim) {
 594                 return false;
 595             }
 596             for (int i = 0; i < dim; i++) {
 597                 for (int j = 0; j < dim; j++) {
 598                     if (this.get(i, j) != b.get(i, j)) {
 599                         return false;
 600                     }
 601                 }
 602             }
 603             return true;
 604         }
 605 
 606         private static Random random = new Random();
 607 
 608         public static Matrix randomMatrix(int dim) {
 609             Matrix result = new Matrix(dim);
 610             for (int i = 0; i < dim; i++) {
 611                 for (int j = 0; j < dim; j++) {
 612                     result.set(i, j, random.nextInt(50));
 613                 }
 614             }
 615             return result;
 616         }
 617 
 618         public static Matrix zeroMatrix(int dim) {
 619             Matrix result = new Matrix(dim);
 620             for (int i = 0; i < dim; i++) {
 621                 for (int j = 0; j < dim; j++) {
 622                     result.set(i, j, 0);
 623                 }
 624             }
 625             return result;
 626         }
 627     }
 628 
 629     /**
 630      * All instances of this class will be used in thread local context
 631      */
 632     private static class ThreadLocals {
 633         private static final int HASH_BOUND = 424242;
 634 
 635         private ThreadLocals parent;
 636         private int hash = 42;
 637         public final String name;
 638 
 639         public ThreadLocals(ThreadLocals parent, String name) {
 640             this.parent = parent;
 641             this.name = name;
 642         }
 643 
 644         public int min(int a, int b) {
 645             updateHash(a + b + 1);
 646             return Math.min(a, b);
 647         }
 648 
 649         public int mult(int a, int b) {
 650             updateHash(a + b + 2);
 651             return a * b;
 652         }
 653 
 654         public int sum(int a, int b) {
 655             updateHash(a + b + 3);
 656             return a + b;
 657         }
 658 
 659 
 660         public int updateHash() {
 661             return updateHash(42);
 662         }
 663 
 664         public int updateHash(int data) {
 665             hash = (hash + data) % HASH_BOUND;
 666             if (parent != null) {
 667                 hash = parent.updateHash(hash) % HASH_BOUND;
 668             }
 669             return hash;
 670         }
 671 
 672         public int getHash() {
 673             return hash;
 674         }
 675     }
 676 }