1 /*
   2  * Copyright (c) 2012, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 import java.util.concurrent.atomic.AtomicInteger;
  27 import java.util.function.Supplier;
  28 import org.testng.annotations.Test;
  29 import static org.testng.Assert.*;
  30 
  31 /**
  32  * @test
  33  * @run testng ThreadLocalSupplierTest
  34  * @summary tests ThreadLocal.withInitial(<Supplier>).
  35  * Adapted from java.lang.Basic functional test of ThreadLocal
  36  *
  37  * @author Jim Gish <jim.gish@oracle.com>
  38  */
  39 @Test
  40 public class ThreadLocalSupplierTest {
  41 
  42     static final class IntegerSupplier implements Supplier<Integer> {
  43 
  44         private final AtomicInteger supply = new AtomicInteger(0);
  45 
  46         @Override
  47         public Integer get() {
  48             return supply.getAndIncrement();
  49         }
  50 
  51         public int numCalls() {
  52             return supply.intValue();
  53         }
  54     }
  55 
  56     static IntegerSupplier theSupply = new IntegerSupplier();
  57 
  58     static final class MyThreadLocal extends ThreadLocal<Integer> {
  59 
  60         private final ThreadLocal<Integer> delegate;
  61 
  62         public volatile boolean everCalled;
  63 
  64         public MyThreadLocal(Supplier<Integer> supplier) {
  65             delegate = ThreadLocal.<Integer>withInitial(supplier);
  66         }
  67 
  68         @Override
  69         public Integer get() {
  70             return delegate.get();
  71         }
  72 
  73         @Override
  74         protected synchronized Integer initialValue() {
  75             // this should never be called since we are using the factory instead
  76             everCalled = true;
  77             return null;
  78         }
  79     }
  80 
  81     /**
  82      * Our one and only ThreadLocal from which we get thread ids using a
  83      * supplier which simply increments a counter on each call of get().
  84      */
  85     static MyThreadLocal threadLocal = new MyThreadLocal(theSupply);
  86 
  87     public void testMultiThread() throws Exception {
  88         final int threadCount = 500;
  89         final Thread th[] = new Thread[threadCount];
  90         final boolean visited[] = new boolean[threadCount];
  91 
  92         // Create and start the threads
  93         for (int i = 0; i < threadCount; i++) {
  94             th[i] = new Thread() {
  95                 @Override
  96                 public void run() {
  97                     final int threadId = threadLocal.get();
  98                     assertFalse(visited[threadId], "visited[" + threadId + "]=" + visited[threadId]);
  99                     visited[threadId] = true;
 100                     // check the get() again
 101                     final int secondCheckThreadId = threadLocal.get();
 102                     assertEquals(secondCheckThreadId, threadId);
 103                 }
 104             };
 105             th[i].start();
 106         }
 107 
 108         // Wait for the threads to finish
 109         for (int i = 0; i < threadCount; i++) {
 110             th[i].join();
 111         }
 112 
 113         assertEquals(theSupply.numCalls(), threadCount);
 114         // make sure the provided initialValue() has not been called
 115         assertFalse(threadLocal.everCalled);
 116         // Check results
 117         for (int i = 0; i < threadCount; i++) {
 118             assertTrue(visited[i], "visited[" + i + "]=" + visited[i]);
 119         }
 120     }
 121 
 122     public void testSimple() {
 123         final String expected = "OneWithEverything";
 124         final ThreadLocal<String> threadLocal = ThreadLocal.<String>withInitial(() -> expected);
 125         assertEquals(expected, threadLocal.get());
 126     }
 127 }