1 /* 2 * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 package java.util; 26 27 /* 28 * Written by Doug Lea with assistance from members of JCP JSR-166 29 * Expert Group and released to the public domain, as explained at 30 * http://creativecommons.org/publicdomain/zero/1.0/ 31 */ 32 33 import java.util.concurrent.ForkJoinPool; 34 import java.util.concurrent.CountedCompleter; 35 import java.util.function.BinaryOperator; 36 import java.util.function.IntBinaryOperator; 37 import java.util.function.LongBinaryOperator; 38 import java.util.function.DoubleBinaryOperator; 39 40 /** 41 * ForkJoin tasks to perform Arrays.parallelPrefix operations. 42 * 43 * @author Doug Lea 44 * @since 1.8 45 */ 46 class ArrayPrefixHelpers { 47 private ArrayPrefixHelpers() {}; // non-instantiable 48 49 /* 50 * Parallel prefix (aka cumulate, scan) task classes 51 * are based loosely on Guy Blelloch's original 52 * algorithm (http://www.cs.cmu.edu/~scandal/alg/scan.html): 53 * Keep dividing by two to threshold segment size, and then: 54 * Pass 1: Create tree of partial sums for each segment 55 * Pass 2: For each segment, cumulate with offset of left sibling 56 * 57 * This version improves performance within FJ framework mainly by 58 * allowing the second pass of ready left-hand sides to proceed 59 * even if some right-hand side first passes are still executing. 60 * It also combines first and second pass for leftmost segment, 61 * and skips the first pass for rightmost segment (whose result is 62 * not needed for second pass). It similarly manages to avoid 63 * requiring that users supply an identity basis for accumulations 64 * by tracking those segments/subtasks for which the first 65 * existing element is used as base. 66 * 67 * Managing this relies on ORing some bits in the pendingCount for 68 * phases/states: CUMULATE, SUMMED, and FINISHED. CUMULATE is the 69 * main phase bit. When false, segments compute only their sum. 70 * When true, they cumulate array elements. CUMULATE is set at 71 * root at beginning of second pass and then propagated down. But 72 * it may also be set earlier for subtrees with lo==0 (the left 73 * spine of tree). SUMMED is a one bit join count. For leafs, it 74 * is set when summed. For internal nodes, it becomes true when 75 * one child is summed. When the second child finishes summing, 76 * we then moves up tree to trigger the cumulate phase. FINISHED 77 * is also a one bit join count. For leafs, it is set when 78 * cumulated. For internal nodes, it becomes true when one child 79 * is cumulated. When the second child finishes cumulating, it 80 * then moves up tree, completing at the root. 81 * 82 * To better exploit locality and reduce overhead, the compute 83 * method loops starting with the current task, moving if possible 84 * to one of its subtasks rather than forking. 85 * 86 * As usual for this sort of utility, there are 4 versions, that 87 * are simple copy/paste/adapt variants of each other. (The 88 * double and int versions differ from long version soley by 89 * replacing "long" (with case-matching)). 90 */ 91 92 // see above 93 static final int CUMULATE = 1; 94 static final int SUMMED = 2; 95 static final int FINISHED = 4; 96 97 /** The smallest subtask array partition size to use as threshold */ 98 static final int MIN_PARTITION = 16; 99 100 static final class CumulateTask<T> extends CountedCompleter<Void> { 101 final T[] array; 102 final BinaryOperator<T> function; 103 CumulateTask<T> left, right; 104 T in, out; 105 final int lo, hi, origin, fence, threshold; 106 107 /** Root task constructor */ 108 public CumulateTask(CumulateTask<T> parent, 109 BinaryOperator<T> function, 110 T[] array, int lo, int hi) { 111 super(parent); 112 this.function = function; this.array = array; 113 this.lo = this.origin = lo; this.hi = this.fence = hi; 114 int p; 115 this.threshold = 116 (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3)) 117 <= MIN_PARTITION ? MIN_PARTITION : p; 118 } 119 120 /** Subtask constructor */ 121 CumulateTask(CumulateTask<T> parent, BinaryOperator<T> function, 122 T[] array, int origin, int fence, int threshold, 123 int lo, int hi) { 124 super(parent); 125 this.function = function; this.array = array; 126 this.origin = origin; this.fence = fence; 127 this.threshold = threshold; 128 this.lo = lo; this.hi = hi; 129 } 130 131 public final void compute() { 132 final BinaryOperator<T> fn; 133 final T[] a; 134 if ((fn = this.function) == null || (a = this.array) == null) 135 throw new NullPointerException(); // hoist checks 136 int th = threshold, org = origin, fnc = fence, l, h; 137 CumulateTask<T> t = this; 138 outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) { 139 if (h - l > th) { 140 CumulateTask<T> lt = t.left, rt = t.right, f; 141 if (lt == null) { // first pass 142 int mid = (l + h) >>> 1; 143 f = rt = t.right = 144 new CumulateTask<T>(t, fn, a, org, fnc, th, mid, h); 145 t = lt = t.left = 146 new CumulateTask<T>(t, fn, a, org, fnc, th, l, mid); 147 } 148 else { // possibly refork 149 T pin = t.in; 150 lt.in = pin; 151 f = t = null; 152 if (rt != null) { 153 T lout = lt.out; 154 rt.in = (l == org ? lout : 155 fn.apply(pin, lout)); 156 for (int c;;) { 157 if (((c = rt.getPendingCount()) & CUMULATE) != 0) 158 break; 159 if (rt.compareAndSetPendingCount(c, c|CUMULATE)){ 160 t = rt; 161 break; 162 } 163 } 164 } 165 for (int c;;) { 166 if (((c = lt.getPendingCount()) & CUMULATE) != 0) 167 break; 168 if (lt.compareAndSetPendingCount(c, c|CUMULATE)) { 169 if (t != null) 170 f = t; 171 t = lt; 172 break; 173 } 174 } 175 if (t == null) 176 break; 177 } 178 if (f != null) 179 f.fork(); 180 } 181 else { 182 int state; // Transition to sum, cumulate, or both 183 for (int b;;) { 184 if (((b = t.getPendingCount()) & FINISHED) != 0) 185 break outer; // already done 186 state = ((b & CUMULATE) != 0? FINISHED : 187 (l > org) ? SUMMED : (SUMMED|FINISHED)); 188 if (t.compareAndSetPendingCount(b, b|state)) 189 break; 190 } 191 192 T sum; 193 if (state != SUMMED) { 194 int first; 195 if (l == org) { // leftmost; no in 196 sum = a[org]; 197 first = org + 1; 198 } 199 else { 200 sum = t.in; 201 first = l; 202 } 203 for (int i = first; i < h; ++i) // cumulate 204 a[i] = sum = fn.apply(sum, a[i]); 205 } 206 else if (h < fnc) { // skip rightmost 207 sum = a[l]; 208 for (int i = l + 1; i < h; ++i) // sum only 209 sum = fn.apply(sum, a[i]); 210 } 211 else 212 sum = t.in; 213 t.out = sum; 214 for (CumulateTask<T> par;;) { // propagate 215 if ((par = (CumulateTask<T>)t.getCompleter()) == null) { 216 if ((state & FINISHED) != 0) // enable join 217 t.quietlyComplete(); 218 break outer; 219 } 220 int b = par.getPendingCount(); 221 if ((b & state & FINISHED) != 0) 222 t = par; // both done 223 else if ((b & state & SUMMED) != 0) { // both summed 224 int nextState; CumulateTask<T> lt, rt; 225 if ((lt = par.left) != null && 226 (rt = par.right) != null) { 227 T lout = lt.out; 228 par.out = (rt.hi == fnc ? lout : 229 fn.apply(lout, rt.out)); 230 } 231 int refork = (((b & CUMULATE) == 0 && 232 par.lo == org) ? CUMULATE : 0); 233 if ((nextState = b|state|refork) == b || 234 par.compareAndSetPendingCount(b, nextState)) { 235 state = SUMMED; // drop finished 236 t = par; 237 if (refork != 0) 238 par.fork(); 239 } 240 } 241 else if (par.compareAndSetPendingCount(b, b|state)) 242 break outer; // sib not ready 243 } 244 } 245 } 246 } 247 } 248 249 static final class LongCumulateTask extends CountedCompleter<Void> { 250 final long[] array; 251 final LongBinaryOperator function; 252 LongCumulateTask left, right; 253 long in, out; 254 final int lo, hi, origin, fence, threshold; 255 256 /** Root task constructor */ 257 public LongCumulateTask(LongCumulateTask parent, 258 LongBinaryOperator function, 259 long[] array, int lo, int hi) { 260 super(parent); 261 this.function = function; this.array = array; 262 this.lo = this.origin = lo; this.hi = this.fence = hi; 263 int p; 264 this.threshold = 265 (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3)) 266 <= MIN_PARTITION ? MIN_PARTITION : p; 267 } 268 269 /** Subtask constructor */ 270 LongCumulateTask(LongCumulateTask parent, LongBinaryOperator function, 271 long[] array, int origin, int fence, int threshold, 272 int lo, int hi) { 273 super(parent); 274 this.function = function; this.array = array; 275 this.origin = origin; this.fence = fence; 276 this.threshold = threshold; 277 this.lo = lo; this.hi = hi; 278 } 279 280 public final void compute() { 281 final LongBinaryOperator fn; 282 final long[] a; 283 if ((fn = this.function) == null || (a = this.array) == null) 284 throw new NullPointerException(); // hoist checks 285 int th = threshold, org = origin, fnc = fence, l, h; 286 LongCumulateTask t = this; 287 outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) { 288 if (h - l > th) { 289 LongCumulateTask lt = t.left, rt = t.right, f; 290 if (lt == null) { // first pass 291 int mid = (l + h) >>> 1; 292 f = rt = t.right = 293 new LongCumulateTask(t, fn, a, org, fnc, th, mid, h); 294 t = lt = t.left = 295 new LongCumulateTask(t, fn, a, org, fnc, th, l, mid); 296 } 297 else { // possibly refork 298 long pin = t.in; 299 lt.in = pin; 300 f = t = null; 301 if (rt != null) { 302 long lout = lt.out; 303 rt.in = (l == org ? lout : 304 fn.applyAsLong(pin, lout)); 305 for (int c;;) { 306 if (((c = rt.getPendingCount()) & CUMULATE) != 0) 307 break; 308 if (rt.compareAndSetPendingCount(c, c|CUMULATE)){ 309 t = rt; 310 break; 311 } 312 } 313 } 314 for (int c;;) { 315 if (((c = lt.getPendingCount()) & CUMULATE) != 0) 316 break; 317 if (lt.compareAndSetPendingCount(c, c|CUMULATE)) { 318 if (t != null) 319 f = t; 320 t = lt; 321 break; 322 } 323 } 324 if (t == null) 325 break; 326 } 327 if (f != null) 328 f.fork(); 329 } 330 else { 331 int state; // Transition to sum, cumulate, or both 332 for (int b;;) { 333 if (((b = t.getPendingCount()) & FINISHED) != 0) 334 break outer; // already done 335 state = ((b & CUMULATE) != 0? FINISHED : 336 (l > org) ? SUMMED : (SUMMED|FINISHED)); 337 if (t.compareAndSetPendingCount(b, b|state)) 338 break; 339 } 340 341 long sum; 342 if (state != SUMMED) { 343 int first; 344 if (l == org) { // leftmost; no in 345 sum = a[org]; 346 first = org + 1; 347 } 348 else { 349 sum = t.in; 350 first = l; 351 } 352 for (int i = first; i < h; ++i) // cumulate 353 a[i] = sum = fn.applyAsLong(sum, a[i]); 354 } 355 else if (h < fnc) { // skip rightmost 356 sum = a[l]; 357 for (int i = l + 1; i < h; ++i) // sum only 358 sum = fn.applyAsLong(sum, a[i]); 359 } 360 else 361 sum = t.in; 362 t.out = sum; 363 for (LongCumulateTask par;;) { // propagate 364 if ((par = (LongCumulateTask)t.getCompleter()) == null) { 365 if ((state & FINISHED) != 0) // enable join 366 t.quietlyComplete(); 367 break outer; 368 } 369 int b = par.getPendingCount(); 370 if ((b & state & FINISHED) != 0) 371 t = par; // both done 372 else if ((b & state & SUMMED) != 0) { // both summed 373 int nextState; LongCumulateTask lt, rt; 374 if ((lt = par.left) != null && 375 (rt = par.right) != null) { 376 long lout = lt.out; 377 par.out = (rt.hi == fnc ? lout : 378 fn.applyAsLong(lout, rt.out)); 379 } 380 int refork = (((b & CUMULATE) == 0 && 381 par.lo == org) ? CUMULATE : 0); 382 if ((nextState = b|state|refork) == b || 383 par.compareAndSetPendingCount(b, nextState)) { 384 state = SUMMED; // drop finished 385 t = par; 386 if (refork != 0) 387 par.fork(); 388 } 389 } 390 else if (par.compareAndSetPendingCount(b, b|state)) 391 break outer; // sib not ready 392 } 393 } 394 } 395 } 396 } 397 398 static final class DoubleCumulateTask extends CountedCompleter<Void> { 399 final double[] array; 400 final DoubleBinaryOperator function; 401 DoubleCumulateTask left, right; 402 double in, out; 403 final int lo, hi, origin, fence, threshold; 404 405 /** Root task constructor */ 406 public DoubleCumulateTask(DoubleCumulateTask parent, 407 DoubleBinaryOperator function, 408 double[] array, int lo, int hi) { 409 super(parent); 410 this.function = function; this.array = array; 411 this.lo = this.origin = lo; this.hi = this.fence = hi; 412 int p; 413 this.threshold = 414 (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3)) 415 <= MIN_PARTITION ? MIN_PARTITION : p; 416 } 417 418 /** Subtask constructor */ 419 DoubleCumulateTask(DoubleCumulateTask parent, DoubleBinaryOperator function, 420 double[] array, int origin, int fence, int threshold, 421 int lo, int hi) { 422 super(parent); 423 this.function = function; this.array = array; 424 this.origin = origin; this.fence = fence; 425 this.threshold = threshold; 426 this.lo = lo; this.hi = hi; 427 } 428 429 public final void compute() { 430 final DoubleBinaryOperator fn; 431 final double[] a; 432 if ((fn = this.function) == null || (a = this.array) == null) 433 throw new NullPointerException(); // hoist checks 434 int th = threshold, org = origin, fnc = fence, l, h; 435 DoubleCumulateTask t = this; 436 outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) { 437 if (h - l > th) { 438 DoubleCumulateTask lt = t.left, rt = t.right, f; 439 if (lt == null) { // first pass 440 int mid = (l + h) >>> 1; 441 f = rt = t.right = 442 new DoubleCumulateTask(t, fn, a, org, fnc, th, mid, h); 443 t = lt = t.left = 444 new DoubleCumulateTask(t, fn, a, org, fnc, th, l, mid); 445 } 446 else { // possibly refork 447 double pin = t.in; 448 lt.in = pin; 449 f = t = null; 450 if (rt != null) { 451 double lout = lt.out; 452 rt.in = (l == org ? lout : 453 fn.applyAsDouble(pin, lout)); 454 for (int c;;) { 455 if (((c = rt.getPendingCount()) & CUMULATE) != 0) 456 break; 457 if (rt.compareAndSetPendingCount(c, c|CUMULATE)){ 458 t = rt; 459 break; 460 } 461 } 462 } 463 for (int c;;) { 464 if (((c = lt.getPendingCount()) & CUMULATE) != 0) 465 break; 466 if (lt.compareAndSetPendingCount(c, c|CUMULATE)) { 467 if (t != null) 468 f = t; 469 t = lt; 470 break; 471 } 472 } 473 if (t == null) 474 break; 475 } 476 if (f != null) 477 f.fork(); 478 } 479 else { 480 int state; // Transition to sum, cumulate, or both 481 for (int b;;) { 482 if (((b = t.getPendingCount()) & FINISHED) != 0) 483 break outer; // already done 484 state = ((b & CUMULATE) != 0? FINISHED : 485 (l > org) ? SUMMED : (SUMMED|FINISHED)); 486 if (t.compareAndSetPendingCount(b, b|state)) 487 break; 488 } 489 490 double sum; 491 if (state != SUMMED) { 492 int first; 493 if (l == org) { // leftmost; no in 494 sum = a[org]; 495 first = org + 1; 496 } 497 else { 498 sum = t.in; 499 first = l; 500 } 501 for (int i = first; i < h; ++i) // cumulate 502 a[i] = sum = fn.applyAsDouble(sum, a[i]); 503 } 504 else if (h < fnc) { // skip rightmost 505 sum = a[l]; 506 for (int i = l + 1; i < h; ++i) // sum only 507 sum = fn.applyAsDouble(sum, a[i]); 508 } 509 else 510 sum = t.in; 511 t.out = sum; 512 for (DoubleCumulateTask par;;) { // propagate 513 if ((par = (DoubleCumulateTask)t.getCompleter()) == null) { 514 if ((state & FINISHED) != 0) // enable join 515 t.quietlyComplete(); 516 break outer; 517 } 518 int b = par.getPendingCount(); 519 if ((b & state & FINISHED) != 0) 520 t = par; // both done 521 else if ((b & state & SUMMED) != 0) { // both summed 522 int nextState; DoubleCumulateTask lt, rt; 523 if ((lt = par.left) != null && 524 (rt = par.right) != null) { 525 double lout = lt.out; 526 par.out = (rt.hi == fnc ? lout : 527 fn.applyAsDouble(lout, rt.out)); 528 } 529 int refork = (((b & CUMULATE) == 0 && 530 par.lo == org) ? CUMULATE : 0); 531 if ((nextState = b|state|refork) == b || 532 par.compareAndSetPendingCount(b, nextState)) { 533 state = SUMMED; // drop finished 534 t = par; 535 if (refork != 0) 536 par.fork(); 537 } 538 } 539 else if (par.compareAndSetPendingCount(b, b|state)) 540 break outer; // sib not ready 541 } 542 } 543 } 544 } 545 } 546 547 static final class IntCumulateTask extends CountedCompleter<Void> { 548 final int[] array; 549 final IntBinaryOperator function; 550 IntCumulateTask left, right; 551 int in, out; 552 final int lo, hi, origin, fence, threshold; 553 554 /** Root task constructor */ 555 public IntCumulateTask(IntCumulateTask parent, 556 IntBinaryOperator function, 557 int[] array, int lo, int hi) { 558 super(parent); 559 this.function = function; this.array = array; 560 this.lo = this.origin = lo; this.hi = this.fence = hi; 561 int p; 562 this.threshold = 563 (p = (hi - lo) / (ForkJoinPool.getCommonPoolParallelism() << 3)) 564 <= MIN_PARTITION ? MIN_PARTITION : p; 565 } 566 567 /** Subtask constructor */ 568 IntCumulateTask(IntCumulateTask parent, IntBinaryOperator function, 569 int[] array, int origin, int fence, int threshold, 570 int lo, int hi) { 571 super(parent); 572 this.function = function; this.array = array; 573 this.origin = origin; this.fence = fence; 574 this.threshold = threshold; 575 this.lo = lo; this.hi = hi; 576 } 577 578 public final void compute() { 579 final IntBinaryOperator fn; 580 final int[] a; 581 if ((fn = this.function) == null || (a = this.array) == null) 582 throw new NullPointerException(); // hoist checks 583 int th = threshold, org = origin, fnc = fence, l, h; 584 IntCumulateTask t = this; 585 outer: while ((l = t.lo) >= 0 && (h = t.hi) <= a.length) { 586 if (h - l > th) { 587 IntCumulateTask lt = t.left, rt = t.right, f; 588 if (lt == null) { // first pass 589 int mid = (l + h) >>> 1; 590 f = rt = t.right = 591 new IntCumulateTask(t, fn, a, org, fnc, th, mid, h); 592 t = lt = t.left = 593 new IntCumulateTask(t, fn, a, org, fnc, th, l, mid); 594 } 595 else { // possibly refork 596 int pin = t.in; 597 lt.in = pin; 598 f = t = null; 599 if (rt != null) { 600 int lout = lt.out; 601 rt.in = (l == org ? lout : 602 fn.applyAsInt(pin, lout)); 603 for (int c;;) { 604 if (((c = rt.getPendingCount()) & CUMULATE) != 0) 605 break; 606 if (rt.compareAndSetPendingCount(c, c|CUMULATE)){ 607 t = rt; 608 break; 609 } 610 } 611 } 612 for (int c;;) { 613 if (((c = lt.getPendingCount()) & CUMULATE) != 0) 614 break; 615 if (lt.compareAndSetPendingCount(c, c|CUMULATE)) { 616 if (t != null) 617 f = t; 618 t = lt; 619 break; 620 } 621 } 622 if (t == null) 623 break; 624 } 625 if (f != null) 626 f.fork(); 627 } 628 else { 629 int state; // Transition to sum, cumulate, or both 630 for (int b;;) { 631 if (((b = t.getPendingCount()) & FINISHED) != 0) 632 break outer; // already done 633 state = ((b & CUMULATE) != 0? FINISHED : 634 (l > org) ? SUMMED : (SUMMED|FINISHED)); 635 if (t.compareAndSetPendingCount(b, b|state)) 636 break; 637 } 638 639 int sum; 640 if (state != SUMMED) { 641 int first; 642 if (l == org) { // leftmost; no in 643 sum = a[org]; 644 first = org + 1; 645 } 646 else { 647 sum = t.in; 648 first = l; 649 } 650 for (int i = first; i < h; ++i) // cumulate 651 a[i] = sum = fn.applyAsInt(sum, a[i]); 652 } 653 else if (h < fnc) { // skip rightmost 654 sum = a[l]; 655 for (int i = l + 1; i < h; ++i) // sum only 656 sum = fn.applyAsInt(sum, a[i]); 657 } 658 else 659 sum = t.in; 660 t.out = sum; 661 for (IntCumulateTask par;;) { // propagate 662 if ((par = (IntCumulateTask)t.getCompleter()) == null) { 663 if ((state & FINISHED) != 0) // enable join 664 t.quietlyComplete(); 665 break outer; 666 } 667 int b = par.getPendingCount(); 668 if ((b & state & FINISHED) != 0) 669 t = par; // both done 670 else if ((b & state & SUMMED) != 0) { // both summed 671 int nextState; IntCumulateTask lt, rt; 672 if ((lt = par.left) != null && 673 (rt = par.right) != null) { 674 int lout = lt.out; 675 par.out = (rt.hi == fnc ? lout : 676 fn.applyAsInt(lout, rt.out)); 677 } 678 int refork = (((b & CUMULATE) == 0 && 679 par.lo == org) ? CUMULATE : 0); 680 if ((nextState = b|state|refork) == b || 681 par.compareAndSetPendingCount(b, nextState)) { 682 state = SUMMED; // drop finished 683 t = par; 684 if (refork != 0) 685 par.fork(); 686 } 687 } 688 else if (par.compareAndSetPendingCount(b, b|state)) 689 break outer; // sib not ready 690 } 691 } 692 } 693 } 694 } 695 696