1 /*
   2  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   3  *
   4  * This code is free software; you can redistribute it and/or modify it
   5  * under the terms of the GNU General Public License version 2 only, as
   6  * published by the Free Software Foundation.
   7  *
   8  * This code is distributed in the hope that it will be useful, but WITHOUT
   9  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  10  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  11  * version 2 for more details (a copy is included in the LICENSE file that
  12  * accompanied this code).
  13  *
  14  * You should have received a copy of the GNU General Public License version
  15  * 2 along with this work; if not, write to the Free Software Foundation,
  16  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  17  *
  18  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  19  * or visit www.oracle.com if you need additional information or have any
  20  * questions.
  21  */
  22 
  23 /*
  24  * This file is available under and governed by the GNU General Public
  25  * License version 2 only, as published by the Free Software Foundation.
  26  * However, the following notice accompanied the original version of this
  27  * file:
  28  *
  29  * Written by Doug Lea with assistance from members of JCP JSR-166
  30  * Expert Group and released to the public domain, as explained at
  31  * http://creativecommons.org/licenses/publicdomain
  32  */
  33 
  34 /*
  35  * @test
  36  * @bug 6445158
  37  * @summary Basic tests for Phaser
  38  * @author Chris Hegarty
  39  */
  40 
  41 import java.util.Iterator;
  42 import java.util.LinkedList;
  43 import java.util.concurrent.Phaser;
  44 import java.util.concurrent.TimeUnit;
  45 import java.util.concurrent.TimeoutException;
  46 import java.util.concurrent.atomic.AtomicInteger;
  47 import static java.util.concurrent.TimeUnit.*;
  48 
  49 public class Basic {
  50 
  51     private static void checkTerminated(final Phaser phaser) {
  52         check(phaser.isTerminated());
  53         int unarriverParties = phaser.getUnarrivedParties();
  54         int registeredParties = phaser.getRegisteredParties();
  55         equal(phaser.arrive(), -1);
  56         equal(phaser.arriveAndDeregister(), -1);
  57         equal(phaser.arriveAndAwaitAdvance(), -1);
  58         equal(phaser.bulkRegister(10), -1);
  59         equal(phaser.getPhase(), -1);
  60         equal(phaser.register(), -1);
  61         try {
  62             equal(phaser.awaitAdvanceInterruptibly(0), -1);
  63             equal(phaser.awaitAdvanceInterruptibly(0, 10, SECONDS), -1);
  64         } catch (Exception ie) {
  65             unexpected(ie);
  66         }
  67         equal(phaser.getUnarrivedParties(), unarriverParties);
  68         equal(phaser.getRegisteredParties(), registeredParties);
  69     }
  70 
  71     private static void checkResult(Arriver a, Class<? extends Throwable> c) {
  72         Throwable t = a.result();
  73         if (! ((t == null && c == null) || (c != null && c.isInstance(t)))) {
  74             //      t.printStackTrace();
  75             fail("Mismatch in thread " +
  76                  a.getName() + ": " +
  77                  t + ", " +
  78                  (c == null ? "<null>" : c.getName()));
  79         } else {
  80             pass();
  81         }
  82     }
  83 
  84     //----------------------------------------------------------------
  85     // Mechanism to get all test threads into "running" mode.
  86     //----------------------------------------------------------------
  87     private static Phaser atTheStartingGate = new Phaser(3);
  88 
  89     private static void toTheStartingGate() {
  90         try {
  91             boolean expectNextPhase = false;
  92             if (atTheStartingGate.getUnarrivedParties() == 1) {
  93                 expectNextPhase = true;
  94             }
  95             int phase = atTheStartingGate.getPhase();
  96             equal(phase, atTheStartingGate.arrive());
  97             int AwaitPhase = atTheStartingGate.awaitAdvanceInterruptibly(phase,
  98                                                         10,
  99                                                         SECONDS);
 100             if (expectNextPhase) check(AwaitPhase == (phase + 1));
 101 
 102             pass();
 103         } catch (Throwable t) {
 104             unexpected(t);
 105            // reset(atTheStartingGate);
 106             throw new Error(t);
 107         }
 108     }
 109 
 110     //----------------------------------------------------------------
 111     // Convenience methods for creating threads that call arrive,
 112     // awaitAdvance, arriveAndAwaitAdvance, awaitAdvanceInterruptibly
 113     //----------------------------------------------------------------
 114     private static abstract class Arriver extends Thread {
 115         static AtomicInteger count = new AtomicInteger(1);
 116 
 117         Arriver() {
 118             this("Arriver");
 119         }
 120 
 121         Arriver(String name) {
 122             this.setName(name + ":" + count.getAndIncrement());
 123             this.setDaemon(true);
 124         }
 125 
 126         private volatile Throwable result;
 127         private volatile int phase;
 128         protected void result(Throwable result) { this.result = result; }
 129         public Throwable result() { return this.result; }
 130         protected void phase(int phase) { this.phase = phase; }
 131         public int phase() { return this.phase; }
 132     }
 133 
 134     private static abstract class Awaiter extends Arriver {
 135         Awaiter() { super("Awaiter"); }
 136         Awaiter(String name) { super(name); }
 137     }
 138 
 139     private static Arriver arriver(final Phaser phaser) {
 140         return new Arriver() { public void run() {
 141             toTheStartingGate();
 142 
 143             try { phase(phaser.arrive()); }
 144             catch (Throwable result) { result(result); }}};
 145     }
 146 
 147     private static AtomicInteger cycleArriveAwaitAdvance = new AtomicInteger(1);
 148 
 149     private static Awaiter awaiter(final Phaser phaser) {
 150         return new Awaiter() { public void run() {
 151             toTheStartingGate();
 152 
 153             try {
 154                 if (cycleArriveAwaitAdvance.getAndIncrement() % 2 == 0)
 155                     phase(phaser.awaitAdvance(phaser.arrive()));
 156                 else
 157                     phase(phaser.arriveAndAwaitAdvance());
 158             } catch (Throwable result) { result(result); }}};
 159     }
 160 
 161     private static Awaiter awaiter(final Phaser phaser,
 162                                    final long timeout,
 163                                    final TimeUnit unit) {
 164         return new Awaiter("InterruptibleWaiter") { public void run() {
 165             toTheStartingGate();
 166 
 167             try {
 168                 if (timeout < 0)
 169                     phase(phaser.awaitAdvanceInterruptibly(phaser.arrive()));
 170                 else
 171                     phase(phaser.awaitAdvanceInterruptibly(phaser.arrive(),
 172                                                      timeout,
 173                                                      unit));
 174             } catch (Throwable result) { result(result); }}};
 175     }
 176 
 177     // Returns an infinite lazy list of all possible arriver/awaiter combinations.
 178     private static Iterator<Arriver> arriverIterator(final Phaser phaser) {
 179         return new Iterator<Arriver>() {
 180             int i = 0;
 181             public boolean hasNext() { return true; }
 182             public Arriver next() {
 183                 switch ((i++)&7) {
 184                     case 0: case 4:
 185                         return arriver(phaser);
 186                     case 1: case 5:
 187                         return awaiter(phaser);
 188                     case 2: case 6: case 7:
 189                         return awaiter(phaser, -1, SECONDS);
 190                     default:
 191                         return awaiter(phaser, 10, SECONDS); }}
 192             public void remove() {throw new UnsupportedOperationException();}};
 193     }
 194 
 195     // Returns an infinite lazy list of all possible awaiter only combinations.
 196     private static Iterator<Awaiter> awaiterIterator(final Phaser phaser) {
 197         return new Iterator<Awaiter>() {
 198             int i = 0;
 199             public boolean hasNext() { return true; }
 200             public Awaiter next() {
 201                 switch ((i++)&7) {
 202                     case 1: case 4: case 7:
 203                         return awaiter(phaser);
 204                     case 2: case 5:
 205                         return awaiter(phaser, -1, SECONDS);
 206                     default:
 207                         return awaiter(phaser, 10, SECONDS); }}
 208             public void remove() {throw new UnsupportedOperationException();}};
 209     }
 210 
 211     private static void realMain(String[] args) throws Throwable {
 212 
 213         Thread.currentThread().setName("mainThread");
 214 
 215         //----------------------------------------------------------------
 216         // Normal use
 217         //----------------------------------------------------------------
 218         try {
 219             Phaser phaser = new Phaser(3);
 220             equal(phaser.getRegisteredParties(), 3);
 221             equal(phaser.getArrivedParties(), 0);
 222             equal(phaser.getPhase(), 0);
 223             check(phaser.getRoot().equals(phaser));
 224             equal(phaser.getParent(), null);
 225             check(!phaser.isTerminated());
 226 
 227             Iterator<Arriver> arrivers = arriverIterator(phaser);
 228             int phase = 0;
 229             for (int i = 0; i < 10; i++) {
 230                 equal(phaser.getPhase(), phase++);
 231                 Arriver a1 = arrivers.next(); a1.start();
 232                 Arriver a2 = arrivers.next(); a2.start();
 233                 toTheStartingGate();
 234                 phaser.arriveAndAwaitAdvance();
 235                 a1.join();
 236                 a2.join();
 237                 checkResult(a1, null);
 238                 checkResult(a2, null);
 239                 check(!phaser.isTerminated());
 240                 equal(phaser.getRegisteredParties(), 3);
 241                 equal(phaser.getArrivedParties(), 0);
 242             }
 243         } catch (Throwable t) { unexpected(t); }
 244 
 245         //----------------------------------------------------------------
 246         // One thread interrupted
 247         //----------------------------------------------------------------
 248         try {
 249             Phaser phaser = new Phaser(3);
 250             Iterator<Arriver> arrivers = arriverIterator(phaser);
 251             int phase = phaser.getPhase();
 252             for (int i = 0; i < 4; i++) {
 253                 check(phaser.getPhase() == phase);
 254                 Awaiter a1 = awaiter(phaser, 10, SECONDS); a1.start();
 255                 Arriver a2 = arrivers.next(); a2.start();
 256                 toTheStartingGate();
 257                 a1.interrupt();
 258                 a1.join();
 259                 phaser.arriveAndAwaitAdvance();
 260                 a2.join();
 261                 checkResult(a1, InterruptedException.class);
 262                 checkResult(a2, null);
 263                 check(!phaser.isTerminated());
 264                 equal(phaser.getRegisteredParties(), 3);
 265                 equal(phaser.getArrivedParties(), 0);
 266                 phase++;
 267             }
 268         } catch (Throwable t) { unexpected(t); }
 269 
 270         //----------------------------------------------------------------
 271         // Phaser is terminated while threads are waiting
 272         //----------------------------------------------------------------
 273         try {
 274             Phaser phaser = new Phaser(3);
 275             Iterator<Awaiter> awaiters = awaiterIterator(phaser);
 276             for (int i = 0; i < 4; i++) {
 277                 Arriver a1 = awaiters.next(); a1.start();
 278                 Arriver a2 = awaiters.next(); a2.start();
 279                 toTheStartingGate();
 280                 while (phaser.getArrivedParties() < 2) Thread.yield();
 281                 phaser.forceTermination();
 282                 a1.join();
 283                 a2.join();
 284                 check(a1.phase == -1);
 285                 check(a2.phase == -1);
 286                 int arrivedParties = phaser.getArrivedParties();
 287                 checkTerminated(phaser);
 288                 equal(phaser.getArrivedParties(), arrivedParties);
 289             }
 290         } catch (Throwable t) { unexpected(t); }
 291 
 292         //----------------------------------------------------------------
 293         // Adds new unarrived parties to this phaser
 294         //----------------------------------------------------------------
 295         try {
 296             Phaser phaser = new Phaser(1);
 297             Iterator<Arriver> arrivers = arriverIterator(phaser);
 298             LinkedList<Arriver> arriverList = new LinkedList<Arriver>();
 299             int phase = phaser.getPhase();
 300             for (int i = 1; i < 5; i++) {
 301                 atTheStartingGate = new Phaser(1+(3*i));
 302                 check(phaser.getPhase() == phase);
 303                 // register 3 more
 304                 phaser.register(); phaser.register(); phaser.register();
 305                 for (int z=0; z<(3*i); z++) {
 306                    arriverList.add(arrivers.next());
 307                 }
 308                 for (Arriver arriver : arriverList)
 309                     arriver.start();
 310 
 311                 toTheStartingGate();
 312                 phaser.arriveAndAwaitAdvance();
 313 
 314                 for (Arriver arriver : arriverList) {
 315                     arriver.join();
 316                     checkResult(arriver, null);
 317                 }
 318                 equal(phaser.getRegisteredParties(), 1 + (3*i));
 319                 equal(phaser.getArrivedParties(), 0);
 320                 arriverList.clear();
 321                 phase++;
 322             }
 323             atTheStartingGate = new Phaser(3);
 324         } catch (Throwable t) { unexpected(t); }
 325 
 326         //----------------------------------------------------------------
 327         // One thread timed out
 328         //----------------------------------------------------------------
 329         try {
 330             Phaser phaser = new Phaser(3);
 331             Iterator<Arriver> arrivers = arriverIterator(phaser);
 332             for (long timeout : new long[] { 0L, 5L }) {
 333                 for (int i = 0; i < 2; i++) {
 334                     Awaiter a1 = awaiter(phaser, timeout, SECONDS); a1.start();
 335                     Arriver a2 = arrivers.next();                   a2.start();
 336                     toTheStartingGate();
 337                     a1.join();
 338                     checkResult(a1, TimeoutException.class);
 339                     phaser.arrive();
 340                     a2.join();
 341                     checkResult(a2, null);
 342                     check(!phaser.isTerminated());
 343                 }
 344             }
 345         } catch (Throwable t) { unexpected(t); }
 346 
 347         //----------------------------------------------------------------
 348         // Barrier action completed normally
 349         //----------------------------------------------------------------
 350         try {
 351             final AtomicInteger count = new AtomicInteger(0);
 352             final Phaser[] kludge = new Phaser[1];
 353             Phaser phaser = new Phaser(3) {
 354                 @Override
 355                 protected boolean onAdvance(int phase, int registeredParties) {
 356                     int countPhase = count.getAndIncrement();
 357                     equal(countPhase, phase);
 358                     equal(kludge[0].getPhase(), phase);
 359                     equal(kludge[0].getRegisteredParties(), registeredParties);
 360                     if (phase >= 3)
 361                         return true; // terminate
 362 
 363                     return false;
 364                 }
 365             };
 366             kludge[0] = phaser;
 367             equal(phaser.getRegisteredParties(), 3);
 368             Iterator<Awaiter> awaiters = awaiterIterator(phaser);
 369             for (int i = 0; i < 4; i++) {
 370                 Awaiter a1 = awaiters.next(); a1.start();
 371                 Awaiter a2 = awaiters.next(); a2.start();
 372                 toTheStartingGate();
 373                 while (phaser.getArrivedParties() < 2) Thread.yield();
 374                 phaser.arrive();
 375                 a1.join();
 376                 a2.join();
 377                 checkResult(a1, null);
 378                 checkResult(a2, null);
 379                 equal(count.get(), i+1);
 380                 if (i < 3) {
 381                     check(!phaser.isTerminated());
 382                     equal(phaser.getRegisteredParties(), 3);
 383                     equal(phaser.getArrivedParties(), 0);
 384                     equal(phaser.getUnarrivedParties(), 3);
 385                     equal(phaser.getPhase(), count.get());
 386                 } else
 387                     checkTerminated(phaser);
 388             }
 389         } catch (Throwable t) { unexpected(t); }
 390 
 391     }
 392 
 393     //--------------------- Infrastructure ---------------------------
 394     static volatile int passed = 0, failed = 0;
 395     static void pass() {passed++;}
 396     static void fail() {failed++; Thread.dumpStack();}
 397     static void fail(String msg) {System.out.println(msg); fail();}
 398     static void unexpected(Throwable t) {failed++; t.printStackTrace();}
 399     static void check(boolean cond) {if (cond) pass(); else fail();}
 400     static void equal(Object x, Object y) {
 401         if (x == null ? y == null : x.equals(y)) pass();
 402         else fail(x + " not equal to " + y);}
 403     public static void main(String[] args) throws Throwable {
 404         try {realMain(args);} catch (Throwable t) {unexpected(t);}
 405         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
 406         if (failed > 0) throw new AssertionError("Some tests failed");}
 407 }