/*
 * Copyright (c) 2005, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/*
 * @test
 * @bug 6267833
 * @summary Tests for invokeAny, invokeAll
 * @author Martin Buchholz
 */

import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Exercises {@code ExecutorService.invokeAll} and {@code ExecutorService.invokeAny},
 * in both their untimed and timed forms, against fixed thread pools of 1 to 6 threads.
 *
 * <p>Check results are tallied in {@link #passed} and {@link #failed}; {@link #main}
 * throws an {@link Error} if any check failed.
 */
public class Invoke {
    /**
     * Tallies of successful and failed checks. They are incremented from pool
     * threads as well as the main thread, so every write goes through a
     * {@code synchronized} method; {@code volatile} keeps plain reads current.
     */
    static volatile int passed = 0, failed = 0;

    /** Upper bound, in seconds, for every timed wait performed by this test. */
    static final long timeoutSeconds = 10L;

    /**
     * Records a failed check and prints a stack trace identifying the caller.
     *
     * @param msg description of the failure
     */
    static synchronized void fail(String msg) {
        failed++;
        new AssertionError(msg).printStackTrace();
    }

    /** Records a successful check. */
    static synchronized void pass() {
        passed++;
    }

    /**
     * Records an unexpected throwable as a failure and prints its stack trace.
     *
     * @param t the throwable that was not anticipated
     */
    static synchronized void unexpected(Throwable t) {
        failed++;
        t.printStackTrace();
    }

    /**
     * Records a pass when {@code condition} holds, otherwise a failure with {@code msg}.
     *
     * @param condition the outcome being verified
     * @param msg       failure description used when {@code condition} is false
     */
    static void check(boolean condition, String msg) {
        if (condition) {
            pass();
        } else {
            fail(msg);
        }
    }

    /**
     * Records a pass when {@code condition} holds, otherwise a generic failure.
     *
     * @param condition the outcome being verified
     */
    static void check(boolean condition) {
        check(condition, "Assertion failure");
    }

    /**
     * Returns the whole seconds elapsed since {@code startTime}.
     *
     * @param startTime an earlier reading of {@link System#nanoTime()}
     * @return elapsed time truncated to seconds
     */
    static long secondsElapsedSince(long startTime) {
        return NANOSECONDS.toSeconds(System.nanoTime() - startTime);
    }

    /**
     * Sleeps for up to {@code timeoutSeconds}, expecting to be interrupted.
     * A sleep that runs to completion is recorded as a failure; an interrupt
     * is recorded as a pass provided it arrived before the timeout elapsed.
     *
     * @param timeoutSeconds maximum time to wait for the interrupt, in seconds
     */
    static void awaitInterrupt(long timeoutSeconds) {
        long startTime = System.nanoTime();
        try {
            Thread.sleep(SECONDS.toMillis(timeoutSeconds));
            fail("timed out waiting for interrupt");
        } catch (InterruptedException expected) {
            // The interrupt is the awaited signal, so it is consumed here
            // rather than re-asserted on the thread.
            check(secondsElapsedSince(startTime) < timeoutSeconds);
        }
    }

    /**
     * Runs every scenario for pool sizes 1 through 6, first untimed and then timed.
     *
     * @param args ignored
     * @throws Error if at least one check failed
     */
    public static void main(String[] args) {
        try {
            for (int nThreads = 1; nThreads <= 6; ++nThreads) {
                // untimed
                testInvokeAll(nThreads, false);
                testInvokeAny(nThreads, false);
                testInvokeAny_cancellationInterrupt(nThreads, false);
                // timed
                testInvokeAll(nThreads, true);
                testInvokeAny(nThreads, true);
                testInvokeAny_cancellationInterrupt(nThreads, true);
            }
        } catch (Throwable t) {
            unexpected(t);
        }

        if (failed > 0) {
            throw new Error(
                String.format("Passed = %d, failed = %d", passed, failed));
        }
    }

    /**
     * Submits {@code nThreads} counting tasks through {@code invokeAll} and checks
     * that one future is returned per task, that every task ran, and that the
     * results are exactly the values 1..n (verified through their sum).
     *
     * @param nThreads pool size, also the number of tasks submitted
     * @param timed    whether to use the timed {@code invokeAll} overload
     * @throws Throwable if the executor or a future throws
     */
    static void testInvokeAll(int nThreads, boolean timed) throws Throwable {
        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        final AtomicLong count = new AtomicLong(0);
        class Task implements Callable<Long> {
            @Override
            public Long call() throws Exception {
                return count.incrementAndGet();
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, nThreads)
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            List<Future<Long>> futures;
            if (timed) {
                long startTime = System.nanoTime();
                futures = pool.invokeAll(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            } else {
                futures = pool.invokeAll(tasks);
            }
            check(futures.size() == tasks.size());
            check(count.get() == tasks.size());

            // Each task returned a distinct value in 1..n, so the results
            // must add up to n(n+1)/2.
            long gauss = 0;
            for (Future<Long> future : futures) {
                gauss += future.get();
            }
            check(gauss == (tasks.size() + 1) * tasks.size() / 2);

            pool.shutdown();
            check(pool.awaitTermination(timeoutSeconds, SECONDS));
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Submits a random number (1 to 6) of tasks through {@code invokeAny}.
     * Only the first task to start returns promptly; every later one waits to
     * be interrupted and then for {@code invokeAny} to return, so the result
     * must be 1.
     *
     * @param nThreads pool size
     * @param timed    whether to use the timed {@code invokeAny} overload
     * @throws Throwable if the executor throws
     */
    static void testInvokeAny(int nThreads, boolean timed) throws Throwable {
        final ThreadLocalRandom rnd = ThreadLocalRandom.current();
        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        final AtomicLong count = new AtomicLong(0);
        final CountDownLatch invokeAnyDone = new CountDownLatch(1);
        class Task implements Callable<Long> {
            @Override
            public Long call() throws Exception {
                long x = count.incrementAndGet();
                if (x > 1) {
                    // wait for main thread to interrupt us ...
                    awaitInterrupt(timeoutSeconds);
                    // ... and then for invokeAny to return
                    check(invokeAnyDone.await(timeoutSeconds, SECONDS));
                }
                return x;
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, rnd.nextInt(1, 7))
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            long val;
            if (timed) {
                long startTime = System.nanoTime();
                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            } else {
                val = pool.invokeAny(tasks);
            }
            check(val == 1);
            invokeAnyDone.countDown();

            pool.shutdown();
            check(pool.awaitTermination(timeoutSeconds, SECONDS));

            long c = count.get();
            check(c >= 1 && c <= tasks.size());

        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Every remaining running task is sent an interrupt for cancellation.
     *
     * <p>All {@code nThreads} tasks meet at a barrier before proceeding, so
     * every one of them is running when the task numbered 1 returns; the
     * others then wait in {@link #awaitInterrupt}, which records a failure if
     * no interrupt arrives.
     *
     * @param nThreads pool size, also the number of tasks submitted
     * @param timed    whether to use the timed {@code invokeAny} overload
     * @throws Throwable if the executor throws
     */
    static void testInvokeAny_cancellationInterrupt(int nThreads, boolean timed) throws Throwable {
        final ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        final AtomicLong count = new AtomicLong(0);
        final CyclicBarrier allStarted = new CyclicBarrier(nThreads);
        class Task implements Callable<Long> {
            @Override
            public Long call() throws Exception {
                long x = count.incrementAndGet();
                allStarted.await();
                if (x > 1) {
                    // main thread will interrupt us
                    awaitInterrupt(timeoutSeconds);
                }
                return x;
            }
        }

        try {
            final List<Task> tasks =
                IntStream.range(0, nThreads)
                .mapToObj(i -> new Task())
                .collect(Collectors.toList());

            long val;
            if (timed) {
                long startTime = System.nanoTime();
                val = pool.invokeAny(tasks, timeoutSeconds, SECONDS);
                check(secondsElapsedSince(startTime) < timeoutSeconds);
            } else {
                val = pool.invokeAny(tasks);
            }
            check(val == 1);

            pool.shutdown();
            check(pool.awaitTermination(timeoutSeconds, SECONDS));

            // Check after shutdown to avoid race
            check(count.get() == nThreads);
        } finally {
            pool.shutdownNow();
        }
    }
}