1 /*
   2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @run testng TestSpliterator
  27  */
  28 
  29 import jdk.incubator.foreign.MemoryAddress;
  30 import jdk.incubator.foreign.MemoryLayout;
  31 import jdk.incubator.foreign.MemoryLayouts;
  32 import jdk.incubator.foreign.MemorySegment;
  33 import jdk.incubator.foreign.SequenceLayout;
  34 
  35 import java.lang.invoke.VarHandle;
  36 import java.util.LinkedList;
  37 import java.util.List;
  38 import java.util.Map;
  39 import java.util.Spliterator;
  40 import java.util.concurrent.CountedCompleter;
  41 import java.util.concurrent.RecursiveTask;
  42 import java.util.concurrent.atomic.AtomicLong;
  43 import java.util.function.Consumer;
  44 import java.util.function.Supplier;
  45 import java.util.stream.LongStream;
  46 import java.util.stream.StreamSupport;
  47 
  48 import org.testng.annotations.*;
  49 import static jdk.incubator.foreign.MemorySegment.*;
  50 import static org.testng.Assert.*;
  51 
  52 public class TestSpliterator {
  53 
  54     static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT)
  55             .varHandle(int.class, MemoryLayout.PathElement.sequenceElement());
  56 
  57     final static int CARRIER_SIZE = 4;
  58 
  59     @Test(dataProvider = "splits")
  60     public void testSum(int size, int threshold) {
  61         SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT);
  62 
  63         //setup
  64         MemorySegment segment = MemorySegment.allocateNative(layout);
  65         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
  66             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
  67         }
  68         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
  69         //serial
  70         long serial = sum(0, segment);
  71         assertEquals(serial, expected);
  72         //parallel counted completer
  73         long parallelCounted = new SumSegmentCounted(null, MemorySegment.spliterator(segment, layout), threshold).invoke();
  74         assertEquals(parallelCounted, expected);
  75         //parallel recursive action
  76         long parallelRecursive = new SumSegmentRecursive(MemorySegment.spliterator(segment, layout), threshold).invoke();
  77         assertEquals(parallelRecursive, expected);
  78         //parallel stream
  79         long streamParallel = StreamSupport.stream(MemorySegment.spliterator(segment, layout), true)
  80                 .reduce(0L, TestSpliterator::sumSingle, Long::sum);
  81         assertEquals(streamParallel, expected);
  82         segment.close();
  83     }
  84 
  85     public void testSumSameThread() {
  86         SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT);
  87 
  88         //setup
  89         MemorySegment segment = MemorySegment.allocateNative(layout);
  90         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
  91             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
  92         }
  93         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
  94 
  95         //check that a segment w/o ACQUIRE access mode can still be used from same thread
  96         AtomicLong spliteratorSum = new AtomicLong();
  97         spliterator(segment.withAccessModes(MemorySegment.READ), layout)
  98                 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s)));
  99         assertEquals(spliteratorSum.get(), expected);
 100     }
 101 
 102     static long sumSingle(long acc, MemorySegment segment) {
 103         return acc + (int)INT_HANDLE.get(segment.baseAddress(), 0L);
 104     }
 105 
 106     static long sum(long start, MemorySegment segment) {
 107         long sum = start;
 108         MemoryAddress base = segment.baseAddress();
 109         int length = (int)segment.byteSize();
 110         for (int i = 0 ; i < length / CARRIER_SIZE ; i++) {
 111             sum += (int)INT_HANDLE.get(base, (long)i);
 112         }
 113         return sum;
 114     }
 115 
 116     static class SumSegmentCounted extends CountedCompleter<Long> {
 117 
 118         final long threshold;
 119         long localSum = 0;
 120         List<SumSegmentCounted> children = new LinkedList<>();
 121 
 122         private Spliterator<MemorySegment> segmentSplitter;
 123 
 124         SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) {
 125             super(parent);
 126             this.segmentSplitter = segmentSplitter;
 127             this.threshold = threshold;
 128         }
 129 
 130         @Override
 131         public void compute() {
 132             Spliterator<MemorySegment> sub;
 133             while (segmentSplitter.estimateSize() > threshold &&
 134                     (sub = segmentSplitter.trySplit()) != null) {
 135                 addToPendingCount(1);
 136                 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold);
 137                 children.add(child);
 138                 child.fork();
 139             }
 140             segmentSplitter.forEachRemaining(slice -> {
 141                 localSum += sumSingle(0, slice);
 142             });
 143             tryComplete();
 144         }
 145 
 146         @Override
 147         public Long getRawResult() {
 148             long sum = localSum;
 149             for (SumSegmentCounted c : children) {
 150                 sum += c.getRawResult();
 151             }
 152             return sum;
 153         }
 154      }
 155 
 156     static class SumSegmentRecursive extends RecursiveTask<Long> {
 157 
 158         final long threshold;
 159         private final Spliterator<MemorySegment> splitter;
 160         private long result;
 161 
 162         SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) {
 163             this.splitter = splitter;
 164             this.threshold = threshold;
 165         }
 166 
 167         @Override
 168         protected Long compute() {
 169             if (splitter.estimateSize() > threshold) {
 170                 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold);
 171                 sub.fork();
 172                 return compute() + sub.join();
 173             } else {
 174                 splitter.forEachRemaining(slice -> {
 175                     result += sumSingle(0, slice);
 176                 });
 177                 return result;
 178             }
 179         }
 180     }
 181 
 182     @DataProvider(name = "splits")
 183     public Object[][] splits() {
 184         return new Object[][] {
 185                 { 10, 1 },
 186                 { 100, 1 },
 187                 { 1000, 1 },
 188                 { 10000, 1 },
 189                 { 10, 10 },
 190                 { 100, 10 },
 191                 { 1000, 10 },
 192                 { 10000, 10 },
 193                 { 10, 100 },
 194                 { 100, 100 },
 195                 { 1000, 100 },
 196                 { 10000, 100 },
 197                 { 10, 1000 },
 198                 { 100, 1000 },
 199                 { 1000, 1000 },
 200                 { 10000, 1000 },
 201                 { 10, 10000 },
 202                 { 100, 10000 },
 203                 { 1000, 10000 },
 204                 { 10000, 10000 },
 205         };
 206     }
 207 
 208     static final int ALL_ACCESS_MODES = READ | WRITE | CLOSE | ACQUIRE | HANDOFF;
 209 
 210     @DataProvider(name = "accessScenarios")
 211     public Object[][] accessScenarios() {
 212         SequenceLayout layout = MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_INT);
 213         var mallocSegment = MemorySegment.allocateNative(layout);
 214 
 215         Map<Supplier<Spliterator<MemorySegment>>,Integer> l = Map.of(
 216             () -> spliterator(mallocSegment.withAccessModes(ALL_ACCESS_MODES), layout), ALL_ACCESS_MODES,
 217             () -> spliterator(mallocSegment.withAccessModes(0), layout), 0,
 218             () -> spliterator(mallocSegment.withAccessModes(READ), layout), READ,
 219             () -> spliterator(mallocSegment.withAccessModes(CLOSE), layout), 0,
 220             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE), layout), READ|WRITE,
 221             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE), layout), READ|WRITE|ACQUIRE,
 222             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE|HANDOFF), layout), READ|WRITE|ACQUIRE|HANDOFF
 223 
 224         );
 225         return l.entrySet().stream().map(e -> new Object[] { e.getKey(), e.getValue() }).toArray(Object[][]::new);
 226     }
 227 
 228     static Consumer<MemorySegment> assertAccessModes(int accessModes) {
 229         return segment -> {
 230             assertTrue(segment.hasAccessModes(accessModes & ~CLOSE));
 231             assertEquals(segment.accessModes(), accessModes & ~CLOSE);
 232         };
 233     }
 234 
 235     @Test(dataProvider = "accessScenarios")
 236     public void testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier,
 237                                 int expectedAccessModes) {
 238         Spliterator<MemorySegment> spliterator = spliteratorSupplier.get();
 239         spliterator.forEachRemaining(assertAccessModes(expectedAccessModes));
 240 
 241         spliterator = spliteratorSupplier.get();
 242         do { } while (spliterator.tryAdvance(assertAccessModes(expectedAccessModes)));
 243 
 244         splitOrConsume(spliteratorSupplier.get(), assertAccessModes(expectedAccessModes));
 245     }
 246 
 247     static void splitOrConsume(Spliterator<MemorySegment> spliterator,
 248                                Consumer<MemorySegment> consumer) {
 249         var s1 = spliterator.trySplit();
 250         if (s1 != null) {
 251             splitOrConsume(s1, consumer);
 252             splitOrConsume(spliterator, consumer);
 253         } else {
 254             spliterator.forEachRemaining(consumer);
 255         }
 256     }
 257 }