1 /*
   2  * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug     6267833
  27  * @summary Tests for invokeAny, invokeAll
  28  * @author  Martin Buchholz
  29  */
  30 
  31 import static java.util.concurrent.TimeUnit.NANOSECONDS;
  32 import static java.util.concurrent.TimeUnit.SECONDS;
  33 
  34 import java.util.List;
  35 import java.util.stream.Collectors;
  36 import java.util.stream.IntStream;
  37 import java.util.concurrent.Callable;
  38 import java.util.concurrent.CountDownLatch;
  39 import java.util.concurrent.CyclicBarrier;
  40 import java.util.concurrent.ExecutorService;
  41 import java.util.concurrent.Executors;
  42 import java.util.concurrent.Future;
  43 import java.util.concurrent.ThreadLocalRandom;
  44 import java.util.concurrent.atomic.AtomicLong;
  45 
  46 public class Invoke {
  47     static volatile int passed = 0, failed = 0;
  48 
  49     static void fail(String msg) {
  50         failed++;
  51         new AssertionError(msg).printStackTrace();
  52     }
  53 
  54     static void pass() {
  55         passed++;
  56     }
  57 
  58     static void unexpected(Throwable t) {
  59         failed++;
  60         t.printStackTrace();
  61     }
  62 
  63     static void check(boolean condition, String msg) {
  64         if (condition) pass(); else fail(msg);
  65     }
  66 
  67     static void check(boolean condition) {
  68         check(condition, "Assertion failure");
  69     }
  70 
  71     static long secondsElapsedSince(long startTime) {
  72         return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
  73     }
  74 
  75     static void awaitInterrupt(long timeoutSeconds) {
  76         long startTime = System.nanoTime();
  77         try {
  78             Thread.sleep(SECONDS.toMillis(timeoutSeconds));
  79             fail("timed out waiting for interrupt");
  80         } catch (InterruptedException expected) {
  81             check(secondsElapsedSince(startTime) < timeoutSeconds);
  82         }
  83     }
  84 
  85     public static void main(String[] args) {
  86         try {
  87             testInvokeAll();
  88             testInvokeAny();
  89             testInvokeAny_cancellationInterrupt();
  90         } catch (Throwable t) {  unexpected(t); }
  91 
  92         if (failed > 0)
  93             throw new Error(
  94                     String.format("Passed = %d, failed = %d", passed, failed));
  95     }
  96 
  97     static final long timeoutSeconds = 10L;
  98 
  99     static void testInvokeAll() throws Throwable {
 100         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 101         final int nThreads = rnd.nextInt(2, 7);
 102         final boolean timed = rnd.nextBoolean();
 103         final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
 104         final AtomicLong count = new AtomicLong(0);
 105         class Task implements Callable<Long> {
 106             public Long call() throws Exception {
 107                 return count.incrementAndGet();
 108             }
 109         }
 110 
 111         try {
 112             final List<Task> tasks =
 113                 IntStream.range(0, nThreads)
 114                 .mapToObj(i -> new Task())
 115                 .collect(Collectors.toList());
 116 
 117             List<Future<Long>> futures;
 118             if (timed) {
 119                 long startTime = System.nanoTime();
 120                 futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
 121                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 122             }
 123             else
 124                 futures = pool.invokeAll(tasks);
 125             check(futures.size() == tasks.size());
 126             check(count.get() == tasks.size());
 127 
 128             long gauss = 0;
 129             for (Future<Long> future : futures) gauss += future.get();
 130             check(gauss == (tasks.size()+1)*tasks.size()/2);
 131 
 132             pool.shutdown();
 133             check(pool.awaitTermination(10L, SECONDS));
 134         } finally {
 135             pool.shutdownNow();
 136         }
 137     }
 138 
 139     static void testInvokeAny() throws Throwable {
 140         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 141         final boolean timed = rnd.nextBoolean();
 142         final ExecutorService pool = Executors.newSingleThreadExecutor();
 143         final AtomicLong count = new AtomicLong(0);
 144         final CountDownLatch invokeAnyDone = new CountDownLatch(1);
 145         class Task implements Callable<Long> {
 146             public Long call() throws Exception {
 147                 long x = count.incrementAndGet();
 148                 check(x <= 2);
 149                 if (x == 2) {
 150                     // wait for main thread to interrupt us ...
 151                     awaitInterrupt(timeoutSeconds);
 152                     // ... and then for invokeAny to return
 153                     check(invokeAnyDone.await(timeoutSeconds, SECONDS));
 154                 }
 155                 return x;
 156             }
 157         }
 158 
 159         try {
 160             final List<Task> tasks =
 161                 IntStream.range(0, rnd.nextInt(1, 7))
 162                 .mapToObj(i -> new Task())
 163                 .collect(Collectors.toList());
 164 
 165             long val;
 166             if (timed) {
 167                 long startTime = System.nanoTime();
 168                 val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
 169                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 170             }
 171             else
 172                 val = pool.invokeAny(tasks);
 173             check(val == 1);
 174             invokeAnyDone.countDown();
 175 
 176             // inherent race between main thread interrupt and
 177             // start of second task
 178             check(count.get() == 1 || count.get() == 2);
 179 
 180             pool.shutdown();
 181             check(pool.awaitTermination(timeoutSeconds, SECONDS));
 182         } finally {
 183             pool.shutdownNow();
 184         }
 185     }
 186 
 187     /**
 188      * Every remaining running task is sent an interrupt for cancellation.
 189      */
 190     static void testInvokeAny_cancellationInterrupt() throws Throwable {
 191         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 192         final int nThreads = rnd.nextInt(2, 7);
 193         final boolean timed = rnd.nextBoolean();
 194         final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
 195         final AtomicLong count = new AtomicLong(0);
 196         final AtomicLong interruptedCount = new AtomicLong(0);
 197         final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
 198         class Task implements Callable<Long> {
 199             public Long call() throws Exception {
 200                 allStarted.await();
 201                 long x = count.incrementAndGet();
 202                 if (x > 1)
 203                     // main thread will interrupt us
 204                     awaitInterrupt(timeoutSeconds);
 205                 return x;
 206             }
 207         }
 208 
 209         try {
 210             final List<Task> tasks =
 211                 IntStream.range(0, nThreads)
 212                 .mapToObj(i -> new Task())
 213                 .collect(Collectors.toList());
 214 
 215             long val;
 216             if (timed) {
 217                 long startTime = System.nanoTime();
 218                 val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
 219                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 220             }
 221             else
 222                 val = pool.invokeAny(tasks);
 223             check(val == 1);
 224 
 225             pool.shutdown();
 226             check(pool.awaitTermination(timeoutSeconds, SECONDS));
 227 
 228             // Check after shutdown to avoid race
 229             check(count.get() == nThreads);
 230         } finally {
 231             pool.shutdownNow();
 232         }
 233     }
 234 }