/*
 * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/*
 * @test
 * @bug 6267833
 * @summary Tests for invokeAny, invokeAll
 * @author Martin Buchholz
 */

import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;

public class Invoke {
    /*
     * Pass/fail tallies. check() is also called from pool threads (inside
     * tasks), so every increment goes through a static synchronized method:
     * volatile makes the final read in main() see the latest values, but it
     * does not make ++ atomic.
     */
    static volatile int passed = 0, failed = 0;

    /** Records a failure and prints a stack trace that locates the caller. */
    static synchronized void fail(String msg) {
        failed++;
        new AssertionError(msg).printStackTrace();
    }

    /** Records a passing check. */
    static synchronized void pass() {
        passed++;
    }

    /** Records an unexpected throwable as a failure and prints it. */
    static synchronized void unexpected(Throwable t) {
        failed++;
        t.printStackTrace();
    }

    /** Records a pass if {@code condition} holds, otherwise a failure with {@code msg}. */
    static void check(boolean condition, String msg) {
        if (condition) pass(); else fail(msg);
    }

    /** Records a pass if {@code condition} holds, otherwise a generic failure. */
    static void check(boolean condition) {
        check(condition, "Assertion failure");
    }

    /**
     * Returns the whole seconds (truncated) elapsed since {@code startTime},
     * which must be a value previously obtained from {@link System#nanoTime()}.
     */
    static long secondsElapsedSince(long startTime) {
        return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
    }

    /**
     * Sleeps for up to {@code timeoutSeconds}, expecting the sleep to be cut
     * short by an interrupt. Records a failure if no interrupt arrives.
     * Catching the InterruptedException clears the thread's interrupt status.
     *
     * @return true if the thread was interrupted, false if the wait timed out
     */
    static boolean awaitInterrupt(long timeoutSeconds) {
        long startTime = System.nanoTime();
        try {
            Thread.sleep(SECONDS.toMillis(timeoutSeconds));
            fail("timed out waiting for interrupt");
            return false;
        } catch (InterruptedException expected) {
            check(secondsElapsedSince(startTime) < timeoutSeconds);
            return true;
        }
    }

    /**
     * Runs all scenarios. Throws an Error summarizing the counts if any
     * check failed or any scenario threw.
     */
    public static void main(String[] args) {
        try {
            testInvokeAll();
            testInvokeAny();
            testInvokeAny_cancellationInterrupt();
        } catch (Throwable t) { unexpected(t); }

        if (failed > 0)
            throw new Error(
                String.format("Passed = %d, failed = %d", passed, failed));
    }

    /** Upper bound, in seconds, used for every blocking wait in this test. */
    static final long timeoutSeconds = 10L;

    /**
     * invokeAll (randomly timed or untimed) must run every task to completion.
     * Each task returns a distinct value from a shared counter (1..n), so the
     * results must sum to n(n+1)/2.
     */
    static void testInvokeAll() throws Throwable {
        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
        final int nThreads = rnd.nextInt(2, 7);
        final boolean timed = rnd.nextBoolean();
        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        final AtomicLong count = new AtomicLong(0);
        class Task implements Callable<Long> {
            public Long call() throws Exception {
                return count.incrementAndGet();
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, nThreads)
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            List<Future<Long>> futures;
            if (timed) {
                long startTime = System.nanoTime();
                futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            }
            else
                futures = pool.invokeAll(tasks);
            check(futures.size() == tasks.size());
            check(count.get() == tasks.size());

            long gauss = 0;
            for (Future<Long> future : futures) gauss += future.get();
            check(gauss == (tasks.size()+1)*tasks.size()/2);

            pool.shutdown();
            // Use the shared timeout rather than a separate hard-coded value.
            check(pool.awaitTermination(timeoutSeconds, SECONDS));
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * invokeAny on a single-thread pool must return the first task's result (1).
     * A second task, if it starts before invokeAny cancels it, must be
     * interrupted by that cancellation. At most two tasks may ever run.
     */
    static void testInvokeAny() throws Throwable {
        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
        final boolean timed = rnd.nextBoolean();
        final ExecutorService pool = Executors.newSingleThreadExecutor();
        final AtomicLong count = new AtomicLong(0);
        final CountDownLatch invokeAnyDone = new CountDownLatch(1);
        class Task implements Callable<Long> {
            public Long call() throws Exception {
                long x = count.incrementAndGet();
                check(x <= 2);
                if (x == 2) {
                    // wait for main thread to interrupt us ...
                    awaitInterrupt(timeoutSeconds);
                    // ... and then for invokeAny to return
                    check(invokeAnyDone.await(timeoutSeconds, SECONDS));
                }
                return x;
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, rnd.nextInt(1, 7))
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            long val;
            if (timed) {
                long startTime = System.nanoTime();
                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            }
            else
                val = pool.invokeAny(tasks);
            check(val == 1);
            invokeAnyDone.countDown();

            // inherent race between main thread interrupt and
            // start of second task; sample the counter once
            final long started = count.get();
            check(started == 1 || started == 2);

            pool.shutdown();
            check(pool.awaitTermination(timeoutSeconds, SECONDS));
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Every remaining running task is sent an interrupt for cancellation.
     * All tasks rendezvous on a barrier, so every task is running when the
     * first one completes. The other nThreads - 1 tasks block until
     * invokeAny's cancellation interrupts them, and each interrupt is counted.
     */
    static void testInvokeAny_cancellationInterrupt() throws Throwable {
        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
        final int nThreads = rnd.nextInt(2, 7);
        final boolean timed = rnd.nextBoolean();
        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        final AtomicLong count = new AtomicLong(0);
        final AtomicLong interruptedCount = new AtomicLong(0);
        final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
        class Task implements Callable<Long> {
            public Long call() throws Exception {
                allStarted.await();
                long x = count.incrementAndGet();
                if (x > 1) {
                    // main thread will interrupt us
                    if (awaitInterrupt(timeoutSeconds))
                        interruptedCount.incrementAndGet();
                }
                return x;
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, nThreads)
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            long val;
            if (timed) {
                long startTime = System.nanoTime();
                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            }
            else
                val = pool.invokeAny(tasks);
            check(val == 1);

            pool.shutdown();
            check(pool.awaitTermination(timeoutSeconds, SECONDS));

            // Check after shutdown to avoid race
            check(count.get() == nThreads);
            // every task other than the winner must have been interrupted
            check(interruptedCount.get() == nThreads - 1);
        } finally {
            pool.shutdownNow();
        }
    }
}