1 /*
   2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @run testng TestSpliterator
  27  */
  28 
  29 import jdk.incubator.foreign.MemoryAddress;
  30 import jdk.incubator.foreign.MemoryLayout;
  31 import jdk.incubator.foreign.MemoryLayouts;
  32 import jdk.incubator.foreign.MemorySegment;
  33 import jdk.incubator.foreign.SequenceLayout;
  34 
  35 import java.lang.invoke.VarHandle;
  36 import java.util.LinkedList;
  37 import java.util.List;
  38 import java.util.Spliterator;
  39 import java.util.concurrent.CountedCompleter;
  40 import java.util.concurrent.RecursiveTask;
  41 import java.util.concurrent.atomic.AtomicLong;
  42 import java.util.stream.LongStream;
  43 import java.util.stream.StreamSupport;
  44 
  45 import org.testng.annotations.*;
  46 import static org.testng.Assert.*;
  47 
  48 public class TestSpliterator {
  49 
  50     static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT)
  51             .varHandle(int.class, MemoryLayout.PathElement.sequenceElement());
  52 
  53     final static int CARRIER_SIZE = 4;
  54 
  55     @Test(dataProvider = "splits")
  56     public void testSum(int size, int threshold) {
  57         SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT);
  58 
  59         //setup
  60         MemorySegment segment = MemorySegment.allocateNative(layout);
  61         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
  62             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
  63         }
  64         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
  65         //serial
  66         long serial = sum(0, segment);
  67         assertEquals(serial, expected);
  68         //parallel counted completer
  69         long parallelCounted = new SumSegmentCounted(null, MemorySegment.spliterator(segment, layout), threshold).invoke();
  70         assertEquals(parallelCounted, expected);
  71         //parallel recursive action
  72         long parallelRecursive = new SumSegmentRecursive(MemorySegment.spliterator(segment, layout), threshold).invoke();
  73         assertEquals(parallelRecursive, expected);
  74         //parallel stream
  75         long streamParallel = StreamSupport.stream(MemorySegment.spliterator(segment, layout), true)
  76                 .reduce(0L, TestSpliterator::sumSingle, Long::sum);
  77         assertEquals(streamParallel, expected);
  78         segment.close();
  79     }
  80 
  81     public void testSumSameThread() {
  82         SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT);
  83 
  84         //setup
  85         MemorySegment segment = MemorySegment.allocateNative(layout);
  86         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
  87             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
  88         }
  89         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
  90 
  91         //check that a segment w/o ACQUIRE access mode can still be used from same thread
  92         AtomicLong spliteratorSum = new AtomicLong();
  93         MemorySegment.spliterator(segment.withAccessModes(MemorySegment.READ), layout)
  94                 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s)));
  95         assertEquals(spliteratorSum.get(), expected);
  96     }
  97 
  98     static long sumSingle(long acc, MemorySegment segment) {
  99         return acc + (int)INT_HANDLE.get(segment.baseAddress(), 0L);
 100     }
 101 
 102     static long sum(long start, MemorySegment segment) {
 103         long sum = start;
 104         MemoryAddress base = segment.baseAddress();
 105         int length = (int)segment.byteSize();
 106         for (int i = 0 ; i < length / CARRIER_SIZE ; i++) {
 107             sum += (int)INT_HANDLE.get(base, (long)i);
 108         }
 109         return sum;
 110     }
 111 
 112     static class SumSegmentCounted extends CountedCompleter<Long> {
 113 
 114         final long threshold;
 115         long localSum = 0;
 116         List<SumSegmentCounted> children = new LinkedList<>();
 117 
 118         private Spliterator<MemorySegment> segmentSplitter;
 119 
 120         SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) {
 121             super(parent);
 122             this.segmentSplitter = segmentSplitter;
 123             this.threshold = threshold;
 124         }
 125 
 126         @Override
 127         public void compute() {
 128             Spliterator<MemorySegment> sub;
 129             while (segmentSplitter.estimateSize() > threshold &&
 130                     (sub = segmentSplitter.trySplit()) != null) {
 131                 addToPendingCount(1);
 132                 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold);
 133                 children.add(child);
 134                 child.fork();
 135             }
 136             segmentSplitter.forEachRemaining(slice -> {
 137                 localSum += sumSingle(0, slice);
 138             });
 139             tryComplete();
 140         }
 141 
 142         @Override
 143         public Long getRawResult() {
 144             long sum = localSum;
 145             for (SumSegmentCounted c : children) {
 146                 sum += c.getRawResult();
 147             }
 148             return sum;
 149         }
 150      }
 151 
 152     static class SumSegmentRecursive extends RecursiveTask<Long> {
 153 
 154         final long threshold;
 155         private final Spliterator<MemorySegment> splitter;
 156         private long result;
 157 
 158         SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) {
 159             this.splitter = splitter;
 160             this.threshold = threshold;
 161         }
 162 
 163         @Override
 164         protected Long compute() {
 165             if (splitter.estimateSize() > threshold) {
 166                 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold);
 167                 sub.fork();
 168                 return compute() + sub.join();
 169             } else {
 170                 splitter.forEachRemaining(slice -> {
 171                     result += sumSingle(0, slice);
 172                 });
 173                 return result;
 174             }
 175         }
 176     }
 177 
 178     @DataProvider(name = "splits")
 179     public Object[][] splits() {
 180         return new Object[][] {
 181                 { 10, 1 },
 182                 { 100, 1 },
 183                 { 1000, 1 },
 184                 { 10000, 1 },
 185                 { 10, 10 },
 186                 { 100, 10 },
 187                 { 1000, 10 },
 188                 { 10000, 10 },
 189                 { 10, 100 },
 190                 { 100, 100 },
 191                 { 1000, 100 },
 192                 { 10000, 100 },
 193                 { 10, 1000 },
 194                 { 100, 1000 },
 195                 { 1000, 1000 },
 196                 { 10000, 1000 },
 197                 { 10, 10000 },
 198                 { 100, 10000 },
 199                 { 1000, 10000 },
 200                 { 10000, 10000 },
 201         };
 202     }
 203 }