1 /* 2 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @run testng TestSpliterator 27 */ 28 29 import jdk.incubator.foreign.MemoryAddress; 30 import jdk.incubator.foreign.MemoryLayout; 31 import jdk.incubator.foreign.MemoryLayouts; 32 import jdk.incubator.foreign.MemorySegment; 33 import jdk.incubator.foreign.SequenceLayout; 34 35 import java.lang.invoke.VarHandle; 36 import java.util.LinkedList; 37 import java.util.List; 38 import java.util.Spliterator; 39 import java.util.concurrent.CountedCompleter; 40 import java.util.concurrent.RecursiveTask; 41 import java.util.concurrent.atomic.AtomicLong; 42 import java.util.stream.LongStream; 43 import java.util.stream.StreamSupport; 44 45 import org.testng.annotations.*; 46 import static org.testng.Assert.*; 47 48 public class TestSpliterator { 49 50 static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT) 51 .varHandle(int.class, MemoryLayout.PathElement.sequenceElement()); 52 53 final static int CARRIER_SIZE = 4; 54 55 @Test(dataProvider = "splits") 56 public void testSum(int size, int threshold) { 57 SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT); 58 59 //setup 60 MemorySegment segment = MemorySegment.allocateNative(layout); 61 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 62 INT_HANDLE.set(segment.baseAddress(), (long) i, i); 63 } 64 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 65 //serial 66 long serial = sum(0, segment); 67 assertEquals(serial, expected); 68 //parallel counted completer 69 long parallelCounted = new SumSegmentCounted(null, MemorySegment.spliterator(segment, layout), threshold).invoke(); 70 assertEquals(parallelCounted, expected); 71 //parallel recursive action 72 long parallelRecursive = new SumSegmentRecursive(MemorySegment.spliterator(segment, layout), threshold).invoke(); 73 assertEquals(parallelRecursive, expected); 74 //parallel stream 75 long streamParallel = StreamSupport.stream(MemorySegment.spliterator(segment, layout), true) 76 .reduce(0L, TestSpliterator::sumSingle, Long::sum); 77 assertEquals(streamParallel, expected); 78 segment.close(); 79 } 80 81 public void testSumSameThread() { 82 SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT); 83 84 //setup 85 MemorySegment segment = MemorySegment.allocateNative(layout); 86 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 87 INT_HANDLE.set(segment.baseAddress(), (long) i, i); 88 } 89 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 90 91 //check that a segment w/o ACQUIRE access mode can still be used from same thread 92 AtomicLong spliteratorSum = new AtomicLong(); 93 MemorySegment.spliterator(segment.withAccessModes(MemorySegment.READ), layout) 94 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s))); 95 assertEquals(spliteratorSum.get(), expected); 96 } 97 98 static long sumSingle(long acc, MemorySegment segment) { 99 return acc + (int)INT_HANDLE.get(segment.baseAddress(), 0L); 100 } 101 102 static long sum(long start, MemorySegment segment) { 103 long sum = start; 104 MemoryAddress base = segment.baseAddress(); 105 int length = (int)segment.byteSize(); 106 for (int i = 0 ; i < length / CARRIER_SIZE ; i++) { 107 sum += (int)INT_HANDLE.get(base, (long)i); 108 } 109 return sum; 110 } 111 112 static class SumSegmentCounted extends CountedCompleter<Long> { 113 114 final long threshold; 115 long localSum = 0; 116 List<SumSegmentCounted> children = new LinkedList<>(); 117 118 private Spliterator<MemorySegment> segmentSplitter; 119 120 SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) { 121 super(parent); 122 this.segmentSplitter = segmentSplitter; 123 this.threshold = threshold; 124 } 125 126 @Override 127 public void compute() { 128 Spliterator<MemorySegment> sub; 129 while (segmentSplitter.estimateSize() > threshold && 130 (sub = segmentSplitter.trySplit()) != null) { 131 addToPendingCount(1); 132 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold); 133 children.add(child); 134 child.fork(); 135 } 136 segmentSplitter.forEachRemaining(slice -> { 137 localSum += sumSingle(0, slice); 138 }); 139 tryComplete(); 140 } 141 142 @Override 143 public Long getRawResult() { 144 long sum = localSum; 145 for (SumSegmentCounted c : children) { 146 sum += c.getRawResult(); 147 } 148 return sum; 149 } 150 } 151 152 static class SumSegmentRecursive extends RecursiveTask<Long> { 153 154 final long threshold; 155 private final Spliterator<MemorySegment> splitter; 156 private long result; 157 158 SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) { 159 this.splitter = splitter; 160 this.threshold = threshold; 161 } 162 163 @Override 164 protected Long compute() { 165 if (splitter.estimateSize() > threshold) { 166 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold); 167 sub.fork(); 168 return compute() + sub.join(); 169 } else { 170 splitter.forEachRemaining(slice -> { 171 result += sumSingle(0, slice); 172 }); 173 return result; 174 } 175 } 176 } 177 178 @DataProvider(name = "splits") 179 public Object[][] splits() { 180 return new Object[][] { 181 { 10, 1 }, 182 { 100, 1 }, 183 { 1000, 1 }, 184 { 10000, 1 }, 185 { 10, 10 }, 186 { 100, 10 }, 187 { 1000, 10 }, 188 { 10000, 10 }, 189 { 10, 100 }, 190 { 100, 100 }, 191 { 1000, 100 }, 192 { 10000, 100 }, 193 { 10, 1000 }, 194 { 100, 1000 }, 195 { 1000, 1000 }, 196 { 10000, 1000 }, 197 { 10, 10000 }, 198 { 100, 10000 }, 199 { 1000, 10000 }, 200 { 10000, 10000 }, 201 }; 202 } 203 }