1 /*
   2  * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug     6267833
  27  * @summary Tests for invokeAny, invokeAll
  28  * @author  Martin Buchholz
  29  */
  30 
  31 import static java.util.concurrent.TimeUnit.NANOSECONDS;
  32 import static java.util.concurrent.TimeUnit.SECONDS;
  33 
  34 import java.util.List;
  35 import java.util.stream.Collectors;
  36 import java.util.stream.IntStream;
  37 import java.util.concurrent.Callable;
  38 import java.util.concurrent.CountDownLatch;
  39 import java.util.concurrent.CyclicBarrier;
  40 import java.util.concurrent.ExecutorService;
  41 import java.util.concurrent.Executors;
  42 import java.util.concurrent.Future;
  43 import java.util.concurrent.ThreadLocalRandom;
  44 import java.util.concurrent.atomic.AtomicLong;
  45 
  46 public class Invoke {
  47     static volatile int passed = 0, failed = 0;
  48 
  49     static void fail(String msg) {
  50         failed++;
  51         new AssertionError(msg).printStackTrace();
  52     }
  53 
  54     static void pass() {
  55         passed++;
  56     }
  57 
  58     static void unexpected(Throwable t) {
  59         failed++;
  60         t.printStackTrace();
  61     }
  62 
  63     static void check(boolean condition, String msg) {
  64         if (condition) pass(); else fail(msg);
  65     }
  66 
  67     static void check(boolean condition) {
  68         check(condition, "Assertion failure");
  69     }
  70 
  71     static long secondsElapsedSince(long startTime) {
  72         return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
  73     }
  74 
  75     static void awaitInterrupt(long timeoutSeconds) {
  76         long startTime = System.nanoTime();
  77         try {
  78             Thread.sleep(SECONDS.toMillis(timeoutSeconds));
  79             fail("timed out waiting for interrupt");
  80         } catch (InterruptedException expected) {
  81             check(secondsElapsedSince(startTime) < timeoutSeconds);
  82         }
  83     }
  84 
  85     public static void main(String[] args) {
  86         try {
  87             for (int nThreads = 1; nThreads <= 6; ++nThreads) {
  88                 // untimed
  89                 testInvokeAll(nThreads, false);
  90                 testInvokeAny(nThreads, false);
  91                 testInvokeAny_cancellationInterrupt(nThreads, false);
  92                 // timed
  93                 testInvokeAll(nThreads, true);
  94                 testInvokeAny(nThreads, true);
  95                 testInvokeAny_cancellationInterrupt(nThreads, true);
  96             }
  97         } catch (Throwable t) {  unexpected(t); }
  98 
  99         if (failed > 0)
 100             throw new Error(
 101                     String.format("Passed = %d, failed = %d", passed, failed));
 102     }
 103 
 104     static final long timeoutSeconds = 10L;
 105 
 106     static void testInvokeAll(int nThreads, boolean timed) throws Throwable {
 107         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 108         final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
 109         final AtomicLong count = new AtomicLong(0);
 110         class Task implements Callable<Long> {
 111             public Long call() throws Exception {
 112                 return count.incrementAndGet();
 113             }
 114         }
 115 
 116         try {
 117             final List<Task> tasks =
 118                 IntStream.range(0, nThreads)
 119                 .mapToObj(i -> new Task())
 120                 .collect(Collectors.toList());
 121 
 122             List<Future<Long>> futures;
 123             if (timed) {
 124                 long startTime = System.nanoTime();
 125                 futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
 126                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 127             }
 128             else
 129                 futures = pool.invokeAll(tasks);
 130             check(futures.size() == tasks.size());
 131             check(count.get() == tasks.size());
 132 
 133             long gauss = 0;
 134             for (Future<Long> future : futures) gauss += future.get();
 135             check(gauss == (tasks.size()+1)*tasks.size()/2);
 136 
 137             pool.shutdown();
 138             check(pool.awaitTermination(10L, SECONDS));
 139         } finally {
 140             pool.shutdownNow();
 141         }
 142     }
 143 
 144     static void testInvokeAny(int nThreads, boolean timed) throws Throwable {
 145         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 146         final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
 147         final AtomicLong count = new AtomicLong(0);
 148         final CountDownLatch invokeAnyDone = new CountDownLatch(1);
 149         class Task implements Callable<Long> {
 150             public Long call() throws Exception {
 151                 long x = count.incrementAndGet();
 152                 if (x > 1) {
 153                     // wait for main thread to interrupt us ...
 154                     awaitInterrupt(timeoutSeconds);
 155                     // ... and then for invokeAny to return
 156                     check(invokeAnyDone.await(timeoutSeconds, SECONDS));
 157                 }
 158                 return x;
 159             }
 160         }
 161 
 162         try {
 163             final List<Task> tasks =
 164                 IntStream.range(0, rnd.nextInt(1, 7))
 165                 .mapToObj(i -> new Task())
 166                 .collect(Collectors.toList());
 167 
 168             long val;
 169             if (timed) {
 170                 long startTime = System.nanoTime();
 171                 val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
 172                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 173             }
 174             else
 175                 val = pool.invokeAny(tasks);
 176             check(val == 1);
 177             invokeAnyDone.countDown();
 178 
 179             pool.shutdown();
 180             check(pool.awaitTermination(timeoutSeconds, SECONDS));
 181 
 182             long c = count.get();
 183             check(c >= 1 && c <= tasks.size());
 184 
 185         } finally {
 186             pool.shutdownNow();
 187         }
 188     }
 189 
 190     /**
 191      * Every remaining running task is sent an interrupt for cancellation.
 192      */
 193     static void testInvokeAny_cancellationInterrupt(int nThreads, boolean timed) throws Throwable {
 194         final ThreadLocalRandom rnd = ThreadLocalRandom.current();
 195         final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
 196         final AtomicLong count = new AtomicLong(0);
 197         final AtomicLong interruptedCount = new AtomicLong(0);
 198         final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
 199         class Task implements Callable<Long> {
 200             public Long call() throws Exception {
 201                 long x = count.incrementAndGet();
 202                 allStarted.await();
 203                 if (x > 1)
 204                     // main thread will interrupt us
 205                     awaitInterrupt(timeoutSeconds);
 206                 return x;
 207             }
 208         }
 209 
 210         try {
 211             final List<Task> tasks =
 212                 IntStream.range(0, nThreads)
 213                 .mapToObj(i -> new Task())
 214                 .collect(Collectors.toList());
 215 
 216             long val;
 217             if (timed) {
 218                 long startTime = System.nanoTime();
 219                 val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
 220                 check(secondsElapsedSince(startTime) < timeoutSeconds);
 221             }
 222             else
 223                 val = pool.invokeAny(tasks);
 224             check(val == 1);
 225 
 226             pool.shutdown();
 227             check(pool.awaitTermination(timeoutSeconds, SECONDS));
 228 
 229             // Check after shutdown to avoid race
 230             check(count.get() == nThreads);
 231         } finally {
 232             pool.shutdownNow();
 233         }
 234     }
 235 }