1 /*
   2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   3  *
   4  * This code is free software; you can redistribute it and/or modify it
   5  * under the terms of the GNU General Public License version 2 only, as
   6  * published by the Free Software Foundation.
   7  *
   8  * This code is distributed in the hope that it will be useful, but WITHOUT
   9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  10  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  11  * version 2 for more details (a copy is included in the LICENSE file that
  12  * accompanied this code).
  13  *
  14  * You should have received a copy of the GNU General Public License version
  15  * 2 along with this work; if not, write to the Free Software Foundation,
  16  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  17  *
  18  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  19  * or visit www.oracle.com if you need additional information or have any
  20  * questions.
  21  */
  22 
  23 /*
  24  * This file is available under and governed by the GNU General Public
  25  * License version 2 only, as published by the Free Software Foundation.
  26  * However, the following notice accompanied the original version of this
  27  * file:
  28  *
  29  * Written by Doug Lea with assistance from members of JCP JSR-166
  30  * Expert Group and released to the public domain, as explained at
  31  * http://creativecommons.org/publicdomain/zero/1.0/
  32  */
  33 
  34 /*
  35  * @test
  36  * @bug 6445158
  37  * @summary Basic tests for Phaser
  38  * @author Chris Hegarty
  39  */
  40 
  41 import java.util.Iterator;
  42 import java.util.LinkedList;
  43 import java.util.concurrent.Phaser;
  44 import java.util.concurrent.TimeUnit;
  45 import java.util.concurrent.TimeoutException;
  46 import java.util.concurrent.atomic.AtomicInteger;
  47 import static java.util.concurrent.TimeUnit.*;
  48 
  49 public class Basic {
  50 
  51     private static void checkTerminated(final Phaser phaser) {
  52         check(phaser.isTerminated());
  53         int unarriverParties = phaser.getUnarrivedParties();
  54         int registeredParties = phaser.getRegisteredParties();
  55         int phase = phaser.getPhase();
  56         check(phase < 0);
  57         equal(phase, phaser.arrive());
  58         equal(phase, phaser.arriveAndDeregister());
  59         equal(phase, phaser.arriveAndAwaitAdvance());
  60         equal(phase, phaser.bulkRegister(10));
  61         equal(phase, phaser.register());
  62         try {
  63             equal(phase, phaser.awaitAdvanceInterruptibly(0));
  64             equal(phase, phaser.awaitAdvanceInterruptibly(0, 10, SECONDS));
  65         } catch (Exception ie) {
  66             unexpected(ie);
  67         }
  68         equal(phaser.getUnarrivedParties(), unarriverParties);
  69         equal(phaser.getRegisteredParties(), registeredParties);
  70     }
  71 
  72     private static void checkResult(Arriver a, Class<? extends Throwable> c) {
  73         Throwable t = a.result();
  74         if (! ((t == null && c == null) || (c != null && c.isInstance(t)))) {
  75             //      t.printStackTrace();
  76             fail("Mismatch in thread " +
  77                  a.getName() + ": " +
  78                  t + ", " +
  79                  (c == null ? "<null>" : c.getName()));
  80         } else {
  81             pass();
  82         }
  83     }
  84 
  85     //----------------------------------------------------------------
  86     // Mechanism to get all test threads into "running" mode.
  87     //----------------------------------------------------------------
  88     private static Phaser atTheStartingGate = new Phaser(3);
  89 
  90     private static void toTheStartingGate() {
  91         try {
  92             boolean expectNextPhase = false;
  93             if (atTheStartingGate.getUnarrivedParties() == 1) {
  94                 expectNextPhase = true;
  95             }
  96             int phase = atTheStartingGate.getPhase();
  97             equal(phase, atTheStartingGate.arrive());
  98             int awaitPhase = atTheStartingGate.awaitAdvanceInterruptibly
  99                 (phase, 10, SECONDS);
 100             if (expectNextPhase) check(awaitPhase == (phase + 1));
 101 
 102             pass();
 103         } catch (Throwable t) {
 104             unexpected(t);
 105            // reset(atTheStartingGate);
 106             throw new Error(t);
 107         }
 108     }
 109 
 110     //----------------------------------------------------------------
 111     // Convenience methods for creating threads that call arrive,
 112     // awaitAdvance, arriveAndAwaitAdvance, awaitAdvanceInterruptibly
 113     //----------------------------------------------------------------
 114     private static abstract class Arriver extends Thread {
 115         static AtomicInteger count = new AtomicInteger(1);
 116 
 117         Arriver() {
 118             this("Arriver");
 119         }
 120 
 121         Arriver(String name) {
 122             this.setName(name + ":" + count.getAndIncrement());
 123             this.setDaemon(true);
 124         }
 125 
 126         private volatile Throwable result;
 127         private volatile int phase;
 128         protected void result(Throwable result) { this.result = result; }
 129         public Throwable result() { return this.result; }
 130         protected void phase(int phase) { this.phase = phase; }
 131         public int phase() { return this.phase; }
 132     }
 133 
 134     private static abstract class Awaiter extends Arriver {
 135         Awaiter() { super("Awaiter"); }
 136         Awaiter(String name) { super(name); }
 137     }
 138 
 139     private static Arriver arriver(final Phaser phaser) {
 140         return new Arriver() { public void run() {
 141             toTheStartingGate();
 142 
 143             try { phase(phaser.arrive()); }
 144             catch (Throwable result) { result(result); }}};
 145     }
 146 
 147     private static AtomicInteger cycleArriveAwaitAdvance = new AtomicInteger(1);
 148 
 149     private static Awaiter awaiter(final Phaser phaser) {
 150         return new Awaiter() { public void run() {
 151             toTheStartingGate();
 152 
 153             try {
 154                 if (cycleArriveAwaitAdvance.getAndIncrement() % 2 == 0)
 155                     phase(phaser.awaitAdvance(phaser.arrive()));
 156                 else
 157                     phase(phaser.arriveAndAwaitAdvance());
 158             } catch (Throwable result) { result(result); }}};
 159     }
 160 
 161     private static Awaiter awaiter(final Phaser phaser,
 162                                    final long timeout,
 163                                    final TimeUnit unit) {
 164         return new Awaiter("InterruptibleWaiter") { public void run() {
 165             toTheStartingGate();
 166 
 167             try {
 168                 if (timeout < 0)
 169                     phase(phaser.awaitAdvanceInterruptibly(phaser.arrive()));
 170                 else
 171                     phase(phaser.awaitAdvanceInterruptibly(phaser.arrive(),
 172                                                      timeout,
 173                                                      unit));
 174             } catch (Throwable result) { result(result); }}};
 175     }
 176 
 177     // Returns an infinite lazy list of all possible arriver/awaiter combinations.
 178     private static Iterator<Arriver> arriverIterator(final Phaser phaser) {
 179         return new Iterator<Arriver>() {
 180             int i = 0;
 181             public boolean hasNext() { return true; }
 182             public Arriver next() {
 183                 switch ((i++)&7) {
 184                     case 0: case 4:
 185                         return arriver(phaser);
 186                     case 1: case 5:
 187                         return awaiter(phaser);
 188                     case 2: case 6: case 7:
 189                         return awaiter(phaser, -1, SECONDS);
 190                     default:
 191                         return awaiter(phaser, 10, SECONDS); }}
 192             public void remove() {throw new UnsupportedOperationException();}};
 193     }
 194 
 195     // Returns an infinite lazy list of all possible awaiter only combinations.
 196     private static Iterator<Awaiter> awaiterIterator(final Phaser phaser) {
 197         return new Iterator<Awaiter>() {
 198             int i = 0;
 199             public boolean hasNext() { return true; }
 200             public Awaiter next() {
 201                 switch ((i++)&7) {
 202                     case 1: case 4: case 7:
 203                         return awaiter(phaser);
 204                     case 2: case 5:
 205                         return awaiter(phaser, -1, SECONDS);
 206                     default:
 207                         return awaiter(phaser, 10, SECONDS); }}
 208             public void remove() {throw new UnsupportedOperationException();}};
 209     }
 210 
 211     private static void realMain(String[] args) throws Throwable {
 212 
 213         Thread.currentThread().setName("mainThread");
 214 
 215         //----------------------------------------------------------------
 216         // Normal use
 217         //----------------------------------------------------------------
 218         try {
 219             Phaser phaser = new Phaser(3);
 220             equal(phaser.getRegisteredParties(), 3);
 221             equal(phaser.getArrivedParties(), 0);
 222             equal(phaser.getPhase(), 0);
 223             check(phaser.getRoot().equals(phaser));
 224             equal(phaser.getParent(), null);
 225             check(!phaser.isTerminated());
 226 
 227             Iterator<Arriver> arrivers = arriverIterator(phaser);
 228             int phase = 0;
 229             for (int i = 0; i < 10; i++) {
 230                 equal(phaser.getPhase(), phase++);
 231                 Arriver a1 = arrivers.next(); a1.start();
 232                 Arriver a2 = arrivers.next(); a2.start();
 233                 toTheStartingGate();
 234                 phaser.arriveAndAwaitAdvance();
 235                 a1.join();
 236                 a2.join();
 237                 checkResult(a1, null);
 238                 checkResult(a2, null);
 239                 check(!phaser.isTerminated());
 240                 equal(phaser.getRegisteredParties(), 3);
 241                 equal(phaser.getArrivedParties(), 0);
 242             }
 243         } catch (Throwable t) { unexpected(t); }
 244 
 245         //----------------------------------------------------------------
 246         // One thread interrupted
 247         //----------------------------------------------------------------
 248         try {
 249             Phaser phaser = new Phaser(3);
 250             Iterator<Arriver> arrivers = arriverIterator(phaser);
 251             int phase = phaser.getPhase();
 252             for (int i = 0; i < 4; i++) {
 253                 check(phaser.getPhase() == phase);
 254                 Awaiter a1 = awaiter(phaser, 10, SECONDS); a1.start();
 255                 Arriver a2 = arrivers.next(); a2.start();
 256                 toTheStartingGate();
 257                 a1.interrupt();
 258                 a1.join();
 259                 phaser.arriveAndAwaitAdvance();
 260                 a2.join();
 261                 checkResult(a1, InterruptedException.class);
 262                 checkResult(a2, null);
 263                 check(!phaser.isTerminated());
 264                 equal(phaser.getRegisteredParties(), 3);
 265                 equal(phaser.getArrivedParties(), 0);
 266                 phase++;
 267             }
 268         } catch (Throwable t) { unexpected(t); }
 269 
 270         //----------------------------------------------------------------
 271         // Phaser is terminated while threads are waiting
 272         //----------------------------------------------------------------
 273         try {
 274             for (int i = 0; i < 4; i++) {
 275                 Phaser phaser = new Phaser(3);
 276                 Iterator<Awaiter> awaiters = awaiterIterator(phaser);
 277                 Arriver a1 = awaiters.next(); a1.start();
 278                 Arriver a2 = awaiters.next(); a2.start();
 279                 toTheStartingGate();
 280                 while (phaser.getArrivedParties() < 2) Thread.yield();
 281                 equal(0, phaser.getPhase());
 282                 phaser.forceTermination();
 283                 a1.join();
 284                 a2.join();
 285                 equal(0 + Integer.MIN_VALUE, a1.phase);
 286                 equal(0 + Integer.MIN_VALUE, a2.phase);
 287                 int arrivedParties = phaser.getArrivedParties();
 288                 checkTerminated(phaser);
 289                 equal(phaser.getArrivedParties(), arrivedParties);
 290             }
 291         } catch (Throwable t) { unexpected(t); }
 292 
 293         //----------------------------------------------------------------
 294         // Adds new unarrived parties to this phaser
 295         //----------------------------------------------------------------
 296         try {
 297             Phaser phaser = new Phaser(1);
 298             Iterator<Arriver> arrivers = arriverIterator(phaser);
 299             LinkedList<Arriver> arriverList = new LinkedList<Arriver>();
 300             int phase = phaser.getPhase();
 301             for (int i = 1; i < 5; i++) {
 302                 atTheStartingGate = new Phaser(1+(3*i));
 303                 check(phaser.getPhase() == phase);
 304                 // register 3 more
 305                 phaser.register(); phaser.register(); phaser.register();
 306                 for (int z=0; z<(3*i); z++) {
 307                    arriverList.add(arrivers.next());
 308                 }
 309                 for (Arriver arriver : arriverList)
 310                     arriver.start();
 311 
 312                 toTheStartingGate();
 313                 phaser.arriveAndAwaitAdvance();
 314 
 315                 for (Arriver arriver : arriverList) {
 316                     arriver.join();
 317                     checkResult(arriver, null);
 318                 }
 319                 equal(phaser.getRegisteredParties(), 1 + (3*i));
 320                 equal(phaser.getArrivedParties(), 0);
 321                 arriverList.clear();
 322                 phase++;
 323             }
 324             atTheStartingGate = new Phaser(3);
 325         } catch (Throwable t) { unexpected(t); }
 326 
 327         //----------------------------------------------------------------
 328         // One thread timed out
 329         //----------------------------------------------------------------
 330         try {
 331             Phaser phaser = new Phaser(3);
 332             Iterator<Arriver> arrivers = arriverIterator(phaser);
 333             for (long timeout : new long[] { 0L, 5L }) {
 334                 for (int i = 0; i < 2; i++) {
 335                     Awaiter a1 = awaiter(phaser, timeout, SECONDS); a1.start();
 336                     Arriver a2 = arrivers.next();                   a2.start();
 337                     toTheStartingGate();
 338                     a1.join();
 339                     checkResult(a1, TimeoutException.class);
 340                     phaser.arrive();
 341                     a2.join();
 342                     checkResult(a2, null);
 343                     check(!phaser.isTerminated());
 344                 }
 345             }
 346         } catch (Throwable t) { unexpected(t); }
 347 
 348         //----------------------------------------------------------------
 349         // Barrier action completed normally
 350         //----------------------------------------------------------------
 351         try {
 352             final AtomicInteger count = new AtomicInteger(0);
 353             final Phaser[] kludge = new Phaser[1];
 354             Phaser phaser = new Phaser(3) {
 355                 @Override
 356                 protected boolean onAdvance(int phase, int registeredParties) {
 357                     int countPhase = count.getAndIncrement();
 358                     equal(countPhase, phase);
 359                     equal(kludge[0].getPhase(), phase);
 360                     equal(kludge[0].getRegisteredParties(), registeredParties);
 361                     if (phase >= 3)
 362                         return true; // terminate
 363 
 364                     return false;
 365                 }
 366             };
 367             kludge[0] = phaser;
 368             equal(phaser.getRegisteredParties(), 3);
 369             Iterator<Awaiter> awaiters = awaiterIterator(phaser);
 370             for (int i = 0; i < 4; i++) {
 371                 Awaiter a1 = awaiters.next(); a1.start();
 372                 Awaiter a2 = awaiters.next(); a2.start();
 373                 toTheStartingGate();
 374                 while (phaser.getArrivedParties() < 2) Thread.yield();
 375                 phaser.arrive();
 376                 a1.join();
 377                 a2.join();
 378                 checkResult(a1, null);
 379                 checkResult(a2, null);
 380                 equal(count.get(), i+1);
 381                 if (i < 3) {
 382                     check(!phaser.isTerminated());
 383                     equal(phaser.getRegisteredParties(), 3);
 384                     equal(phaser.getArrivedParties(), 0);
 385                     equal(phaser.getUnarrivedParties(), 3);
 386                     equal(phaser.getPhase(), count.get());
 387                 } else
 388                     checkTerminated(phaser);
 389             }
 390         } catch (Throwable t) { unexpected(t); }
 391 
 392     }
 393 
 394     //--------------------- Infrastructure ---------------------------
 395     static volatile int passed = 0, failed = 0;
 396     static void pass() {passed++;}
 397     static void fail() {failed++; Thread.dumpStack();}
 398     static void fail(String msg) {System.out.println(msg); fail();}
 399     static void unexpected(Throwable t) {failed++; t.printStackTrace();}
 400     static void check(boolean cond) {if (cond) pass(); else fail();}
 401     static void equal(Object x, Object y) {
 402         if (x == null ? y == null : x.equals(y)) pass();
 403         else fail(x + " not equal to " + y);}
 404     public static void main(String[] args) throws Throwable {
 405         try {realMain(args);} catch (Throwable t) {unexpected(t);}
 406         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
 407         if (failed > 0) throw new AssertionError("Some tests failed");}
 408 }