/*
 * Copyright (c) 2010, 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/*
 * @test
 *
 * @summary converted from VM Testbase jit/escape/LockElision/MatMul.
 * VM Testbase keywords: [jit, quick]
 * VM Testbase readme:
 * DESCRIPTION
 *     The test multiplies 2 matrices twice: first by directly calculating the matrix
 *     product elements, and second by calculating them in parallel in different threads.
 *     The two results are then compared.
 *     In addition to the required locks, the test introduces locks on local variables
 *     and on variables that do not escape the executing thread, and nests them deeply.
 *     With a buggy compiler, lock elimination may also remove some code required by
 *     the calculation, or the code may be over-optimized in some other way, causing
 *     differences in the execution results.
 *     The -dim option specifies the dimension of the (square) matrices, and the
 *     -threadCount option specifies the base number of worker threads.
40 * 41 * @library /vmTestbase 42 * /test/lib 43 * @run driver jdk.test.lib.FileInstaller . . 44 * @build jit.escape.LockElision.MatMul.MatMul 45 * @run driver ExecDriver --java jit.escape.LockElision.MatMul.MatMul -dim 30 -threadCount 10 46 */ 47 48 package jit.escape.LockElision.MatMul; 49 50 import java.util.*; 51 import java.util.concurrent.CountDownLatch; 52 import java.util.concurrent.ExecutorService; 53 import java.util.concurrent.Executors; 54 55 import nsk.share.Consts; 56 import nsk.share.Log; 57 import nsk.share.Pair; 58 59 import nsk.share.test.StressOptions; 60 import vm.share.options.Option; 61 import vm.share.options.OptionSupport; 62 import vm.share.options.Options; 63 64 65 class MatMul { 66 67 @Option(name = "dim", description = "dimension of matrices") 68 int dim; 69 70 @Option(name = "verbose", default_value = "false", 71 description = "verbose mode") 72 boolean verbose; 73 74 @Option(name = "threadCount", description = "thread count") 75 int threadCount; 76 77 @Options 78 StressOptions stressOptions = new StressOptions(); 79 80 private Log log; 81 82 public static void main(String[] args) { 83 MatMul test = new MatMul(); 84 OptionSupport.setup(test, args); 85 System.exit(Consts.JCK_STATUS_BASE + test.run()); 86 } 87 88 public int run() { 89 log = new Log(System.out, verbose); 90 log.display("Parallel matrix multiplication test"); 91 92 Matrix a = Matrix.randomMatrix(dim); 93 Matrix b = Matrix.randomMatrix(dim); 94 long t1, t2; 95 96 t1 = System.currentTimeMillis(); 97 Matrix serialResult = serialMul(a, b); 98 t2 = System.currentTimeMillis(); 99 log.display("serial time: " + (t2 - t1) + "ms"); 100 101 try { 102 t1 = System.currentTimeMillis(); 103 Matrix parallelResult = parallelMul(a, b, 104 threadCount * stressOptions.getThreadsFactor()); 105 t2 = System.currentTimeMillis(); 106 log.display("parallel time: " + (t2 - t1) + "ms"); 107 108 if (!serialResult.equals(parallelResult)) { 109 log.complain("a = \n" + a); 110 log.complain("b = \n" + b); 
111 112 log.complain("serial: a * b = \n" + serialResult); 113 log.complain("serial: a * b = \n" + parallelResult); 114 return Consts.TEST_FAILED; 115 } 116 return Consts.TEST_PASSED; 117 118 } catch (CounterIncorrectStateException e) { 119 log.complain("incorrect state of counter " + e.counter.name); 120 log.complain("expected = " + e.counter.expected); 121 log.complain("actual " + e.counter.state()); 122 return Consts.TEST_FAILED; 123 } 124 } 125 126 public static int convolution(Seq<Integer> one, Seq<Integer> two) { 127 int res = 0; 128 int upperBound = Math.min(one.size(), two.size()); 129 for (int i = 0; i < upperBound; i++) { 130 res += one.get(i) * two.get(i); 131 } 132 return res; 133 } 134 135 /** 136 * calculate chunked convolutuion of two sequences 137 * <p/> 138 * This special version of this method: 139 * <pre>{@code 140 * public static int chunkedConvolution(Seq<Integer> one, Seq<Integer> two, int from, int to) { 141 * int res = 0; 142 * int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1); 143 * for (int i = from; i < upperBound; i++) { 144 * res += one.get(i) * two.get(i); 145 * } 146 * return res; 147 * }}</pre> 148 * <p/> 149 * that tries to fool the Lock Elision optimization: 150 * Most lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed. 151 * But several synchronized blocks (all that protected by Counter instances) are really necessary, and removing them we obtain 152 * an incorrect result. 
153 * 154 * @param one 155 * @param two 156 * @param from - lower bound of sum 157 * @param to - upper bound of sum 158 * @param local - reference ThreadLocal that will be used for calculations 159 * @param bCounter - Counter instance, need to perfom checks 160 */ 161 public static int chunkedConvolutionWithDummy(Seq<Integer> one, 162 Seq<Integer> two, int from, int to, ThreadLocals local, 163 Counter bCounter) { 164 ThreadLocals conv_local1 = new ThreadLocals(local, "conv_local1"); 165 ThreadLocals conv_local2 = new ThreadLocals(conv_local1, "conv_local2"); 166 ThreadLocals conv_local3 = new ThreadLocals(null, "conv_local3"); 167 int res = 0; 168 synchronized (local) { 169 local.updateHash(); 170 int upperBound = 0; 171 synchronized (conv_local1) { 172 upperBound = local.min(one.size(), two.size()); 173 synchronized (two) { 174 //int upperBound = Math.min(Math.min(one.size(), two.size()), to + 1) : 175 upperBound = conv_local1.min(upperBound, to + 1); 176 synchronized (bCounter) { 177 bCounter.inc(); 178 } 179 } 180 for (int i = from; i < upperBound; i++) { 181 synchronized (conv_local2) { 182 conv_local1.updateHash(); 183 int prod = 0; 184 synchronized (one) { 185 int t = conv_local2.mult(one.get(i), two.get(i)); 186 synchronized (conv_local3) { 187 prod = t; 188 189 } 190 //res += one.get(i) * two.get(i) 191 res = conv_local3.sum(res, prod); 192 } 193 } 194 } 195 } 196 return res; 197 } 198 } 199 200 public boolean productCheck(Matrix a, Matrix b) { 201 if (a == null || b == null) { 202 log.complain("null matrix!"); 203 return false; 204 } 205 206 if (a.dim != b.dim) { 207 log.complain("matrices dimension are differs"); 208 return false; 209 } 210 return true; 211 } 212 213 public Matrix serialMul(Matrix a, Matrix b) { 214 if (!productCheck(a, b)) { 215 throw new IllegalArgumentException(); 216 } 217 218 Matrix result = Matrix.zeroMatrix(a.dim); 219 for (int i = 0; i < a.dim; i++) { 220 for (int j = 0; j < a.dim; j++) { 221 result.set(i, j, convolution(a.row(i), 
b.column(j))); 222 } 223 } 224 return result; 225 } 226 227 228 /** 229 * Parallel multiplication of matrices. 230 * <p/> 231 * This special version of this method: 232 * <pre>{@code 233 * public Matrix parallelMul1(final Matrix a, final Matrix b, int threadCount) { 234 * if (!productCheck(a, b)) { 235 * throw new IllegalArgumentException(); 236 * } 237 * final int dim = a.dim; 238 * final Matrix result = Matrix.zeroMatrix(dim); 239 * <p/> 240 * ExecutorService threadPool = Executors.newFixedThreadPool(threadCount); 241 * final CountDownLatch latch = new CountDownLatch(threadCount); 242 * List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1), threadCount); 243 * for (final Pair<Integer, Integer> part : parts) { 244 * threadPool.submit(new Runnable() { 245 * @Override 246 * public void run() { 247 * for (int i = 0; i < dim; i++) { 248 * for (int j = 0; j < dim; j++) { 249 * synchronized (result) { 250 * int from = part.first; 251 * int to = part.second; 252 * result.add(i, j, chunkedConvolution(a.row(i), b.column(j), from, to)); 253 * } 254 * } 255 * } 256 * latch.countDown(); 257 * } 258 * }); 259 * } 260 * <p/> 261 * try { 262 * latch.await(); 263 * } catch (InterruptedException e) { 264 * e.printStackTrace(); 265 * } 266 * threadPool.shutdown(); 267 * return result; 268 * }}</pre> 269 * Lines marked with NOP comments need to fool the Lock Elision optimization: 270 * All lock objects in these lines are really thread local, so related synchronized blocks (dummy blocks) can be removed. 271 * But several synchronized blocks (that are nested in dummy blocks) are really necessary, and removing them we obtain 272 * an incorrect result. 
273 * 274 * @param a first operand 275 * @param b second operand 276 * @param threadCount number of threads that will be used for calculations 277 * @return product of matrices a and b 278 */ 279 public Matrix parallelMul(final Matrix a, final Matrix b, int threadCount) 280 throws CounterIncorrectStateException { 281 if (!productCheck(a, b)) { 282 throw new IllegalArgumentException(); 283 } 284 final int dim = a.dim; 285 final Matrix result = Matrix.zeroMatrix(dim); 286 287 ExecutorService threadPool = Executors.newFixedThreadPool(threadCount); 288 final CountDownLatch latch = new CountDownLatch(threadCount); 289 List<Pair<Integer, Integer>> parts = splitInterval(Pair.of(0, dim - 1), 290 threadCount); 291 292 final Counter lCounter1 = new Counter(threadCount, "lCounter1"); 293 final Counter lCounter2 = new Counter(threadCount, "lCounter2"); 294 final Counter lCounter3 = new Counter(threadCount, "lCounter3"); 295 296 final Counter bCounter1 = new Counter(threadCount * dim * dim, 297 "bCounter1"); 298 final Counter bCounter2 = new Counter(threadCount * dim * dim, 299 "bCounter2"); 300 final Counter bCounter3 = new Counter(threadCount * dim * dim, 301 "bCounter3"); 302 303 final Counter[] counters = {lCounter1, lCounter2, lCounter3, 304 bCounter1, bCounter2, bCounter3}; 305 306 final Map<Pair<Integer, Integer>, ThreadLocals> locals1 307 = CollectionsUtils.newHashMap(); 308 final Map<Pair<Integer, Integer>, ThreadLocals> locals2 309 = CollectionsUtils.newHashMap(); 310 final Map<Pair<Integer, Integer>, ThreadLocals> locals3 311 = CollectionsUtils.newHashMap(); 312 313 for (final Pair<Integer, Integer> part : parts) { 314 315 ThreadLocals local1 = new ThreadLocals(null, 316 "locals1[" + part + "]"); 317 ThreadLocals local2 = new ThreadLocals(local1, 318 "locals2[" + part + "]"); 319 ThreadLocals local3 = new ThreadLocals(local2, 320 "locals3[" + part + "]"); 321 322 locals1.put(part, local1); 323 locals2.put(part, local2); 324 locals3.put(part, local3); 325 } 326 327 
for (final Pair<Integer, Integer> part : parts) { 328 threadPool.submit(new Runnable() { 329 @Override 330 public void run() { 331 ThreadLocals local1 = locals1.get(part); 332 ThreadLocals local2 = locals2.get(part); 333 ThreadLocals local3 = locals3.get(part); 334 ThreadLocals local4 = locals3.get(part); 335 synchronized (local1) { 336 local1.updateHash(); 337 synchronized (lCounter1) { 338 lCounter1.inc(); 339 } 340 synchronized (lCounter3) { 341 synchronized (local2) { 342 local2.updateHash(); 343 lCounter3.inc(); 344 } 345 } 346 synchronized (new Object()) { 347 synchronized (lCounter2) { 348 lCounter2.inc(); 349 } 350 for (int i = 0; i < dim; i++) { 351 for (int j = 0; j < dim; j++) { 352 synchronized (bCounter1) { 353 synchronized (new Object()) { 354 bCounter1.inc(); 355 } 356 } 357 synchronized (local3) { 358 local3.updateHash(); 359 synchronized (bCounter2) { 360 bCounter2.inc(); 361 } 362 synchronized (result) { 363 local1.updateHash(); 364 synchronized (local2) { 365 local2.updateHash(); 366 int from = part.first; 367 int to = part.second; 368 result.add(i, j, 369 chunkedConvolutionWithDummy( 370 a.row(i), 371 b.column(j), 372 from, to, 373 local4, 374 bCounter3)); 375 } 376 } 377 } 378 } 379 } 380 } 381 } 382 latch.countDown(); 383 } 384 }); 385 } 386 387 try { 388 latch.await(); 389 } catch (InterruptedException e) { 390 e.printStackTrace(); 391 } 392 393 threadPool.shutdown(); 394 for (final Pair<Integer, Integer> part : parts) { 395 log.display( 396 "hash for " + part + " = " + locals1.get(part).getHash()); 397 } 398 399 400 for (Counter counter : counters) { 401 if (!counter.check()) { 402 throw new CounterIncorrectStateException(counter); 403 } 404 } 405 return result; 406 } 407 408 /** 409 * Split interval into parts 410 * 411 * @param interval - pair than encode bounds of interval 412 * @param partCount - count of parts 413 * @return list of pairs than encode bounds of parts 414 */ 415 public static List<Pair<Integer, Integer>> splitInterval( 416 
Pair<Integer, Integer> interval, int partCount) { 417 if (partCount == 0) { 418 throw new IllegalArgumentException(); 419 } 420 421 if (partCount == 1) { 422 return CollectionsUtils.asList(interval); 423 } 424 425 int intervalSize = interval.second - interval.first + 1; 426 int partSize = intervalSize / partCount; 427 428 List<Pair<Integer, Integer>> init = splitInterval( 429 Pair.of(interval.first, interval.second - partSize), 430 partCount - 1); 431 Pair<Integer, Integer> lastPart = Pair 432 .of(interval.second - partSize + 1, interval.second); 433 434 return CollectionsUtils.append(init, lastPart); 435 } 436 437 public static class Counter { 438 private int state; 439 440 public final int expected; 441 public final String name; 442 443 public void inc() { 444 state++; 445 } 446 447 public int state() { 448 return state; 449 } 450 451 public boolean check() { 452 return state == expected; 453 } 454 455 public Counter(int expected, String name) { 456 this.expected = expected; 457 this.name = name; 458 } 459 } 460 461 private static class CounterIncorrectStateException extends Exception { 462 public final Counter counter; 463 464 public CounterIncorrectStateException(Counter counter) { 465 this.counter = counter; 466 } 467 } 468 469 private static abstract class Seq<E> implements Iterable<E> { 470 @Override 471 public Iterator<E> iterator() { 472 return new Iterator<E>() { 473 private int p = 0; 474 475 @Override 476 public boolean hasNext() { 477 return p < size(); 478 } 479 480 @Override 481 public E next() { 482 return get(p++); 483 } 484 485 @Override 486 public void remove() { 487 } 488 }; 489 } 490 491 public abstract E get(int i); 492 493 public abstract int size(); 494 } 495 496 private static class CollectionsUtils { 497 498 public static <K, V> Map<K, V> newHashMap() { 499 return new HashMap<K, V>(); 500 } 501 502 public static <E> List<E> newArrayList() { 503 return new ArrayList<E>(); 504 } 505 506 public static <E> List<E> newArrayList(Collection<E> 
collection) { 507 return new ArrayList<E>(collection); 508 } 509 510 public static <E> List<E> asList(E e) { 511 List<E> result = newArrayList(); 512 result.add(e); 513 return result; 514 } 515 516 public static <E> List<E> append(List<E> init, E last) { 517 List<E> result = newArrayList(init); 518 result.add(last); 519 return result; 520 } 521 } 522 523 private static class Matrix { 524 525 public final int dim; 526 private int[] coeffs; 527 528 private Matrix(int dim) { 529 this.dim = dim; 530 this.coeffs = new int[dim * dim]; 531 } 532 533 public void set(int i, int j, int value) { 534 coeffs[i * dim + j] = value; 535 } 536 537 public void add(int i, int j, int value) { 538 coeffs[i * dim + j] += value; 539 } 540 541 public int get(int i, int j) { 542 return coeffs[i * dim + j]; 543 } 544 545 public Seq<Integer> row(final int i) { 546 return new Seq<Integer>() { 547 @Override 548 public Integer get(int j) { 549 return Matrix.this.get(i, j); 550 } 551 552 @Override 553 public int size() { 554 return Matrix.this.dim; 555 } 556 }; 557 } 558 559 public Seq<Integer> column(final int j) { 560 return new Seq<Integer>() { 561 @Override 562 public Integer get(int i) { 563 return Matrix.this.get(i, j); 564 } 565 566 @Override 567 public int size() { 568 return Matrix.this.dim; 569 } 570 }; 571 } 572 573 @Override 574 public String toString() { 575 StringBuilder builder = new StringBuilder(); 576 for (int i = 0; i < dim; i++) { 577 for (int j = 0; j < dim; j++) { 578 builder.append((j == 0) ? 
"" : "\t\t"); 579 builder.append(get(i, j)); 580 } 581 builder.append("\n"); 582 } 583 return builder.toString(); 584 } 585 586 @Override 587 public boolean equals(Object other) { 588 if (!(other instanceof Matrix)) { 589 return false; 590 } 591 592 Matrix b = (Matrix) other; 593 if (b.dim != this.dim) { 594 return false; 595 } 596 for (int i = 0; i < dim; i++) { 597 for (int j = 0; j < dim; j++) { 598 if (this.get(i, j) != b.get(i, j)) { 599 return false; 600 } 601 } 602 } 603 return true; 604 } 605 606 private static Random random = new Random(); 607 608 public static Matrix randomMatrix(int dim) { 609 Matrix result = new Matrix(dim); 610 for (int i = 0; i < dim; i++) { 611 for (int j = 0; j < dim; j++) { 612 result.set(i, j, random.nextInt(50)); 613 } 614 } 615 return result; 616 } 617 618 public static Matrix zeroMatrix(int dim) { 619 Matrix result = new Matrix(dim); 620 for (int i = 0; i < dim; i++) { 621 for (int j = 0; j < dim; j++) { 622 result.set(i, j, 0); 623 } 624 } 625 return result; 626 } 627 } 628 629 /** 630 * All instances of this class will be used in thread local context 631 */ 632 private static class ThreadLocals { 633 private static final int HASH_BOUND = 424242; 634 635 private ThreadLocals parent; 636 private int hash = 42; 637 public final String name; 638 639 public ThreadLocals(ThreadLocals parent, String name) { 640 this.parent = parent; 641 this.name = name; 642 } 643 644 public int min(int a, int b) { 645 updateHash(a + b + 1); 646 return Math.min(a, b); 647 } 648 649 public int mult(int a, int b) { 650 updateHash(a + b + 2); 651 return a * b; 652 } 653 654 public int sum(int a, int b) { 655 updateHash(a + b + 3); 656 return a + b; 657 } 658 659 660 public int updateHash() { 661 return updateHash(42); 662 } 663 664 public int updateHash(int data) { 665 hash = (hash + data) % HASH_BOUND; 666 if (parent != null) { 667 hash = parent.updateHash(hash) % HASH_BOUND; 668 } 669 return hash; 670 } 671 672 public int getHash() { 673 return hash; 
674 } 675 } 676 }