1 /* 2 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @run testng TestSpliterator 27 */ 28 29 import jdk.incubator.foreign.MemoryAddress; 30 import jdk.incubator.foreign.MemoryLayout; 31 import jdk.incubator.foreign.MemoryLayouts; 32 import jdk.incubator.foreign.MemorySegment; 33 import jdk.incubator.foreign.SequenceLayout; 34 35 import java.lang.invoke.VarHandle; 36 import java.util.LinkedList; 37 import java.util.List; 38 import java.util.Map; 39 import java.util.Spliterator; 40 import java.util.concurrent.CountedCompleter; 41 import java.util.concurrent.RecursiveTask; 42 import java.util.concurrent.atomic.AtomicLong; 43 import java.util.function.Consumer; 44 import java.util.function.Supplier; 45 import java.util.stream.LongStream; 46 import java.util.stream.StreamSupport; 47 48 import org.testng.annotations.*; 49 import static jdk.incubator.foreign.MemorySegment.*; 50 import static org.testng.Assert.*; 51 52 public class TestSpliterator { 53 54 static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT) 55 .varHandle(int.class, MemoryLayout.PathElement.sequenceElement()); 56 57 final static int CARRIER_SIZE = 4; 58 59 @Test(dataProvider = "splits") 60 public void testSum(int size, int threshold) { 61 SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT); 62 63 //setup 64 MemorySegment segment = MemorySegment.allocateNative(layout); 65 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 66 INT_HANDLE.set(segment.baseAddress(), (long) i, i); 67 } 68 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 69 //serial 70 long serial = sum(0, segment); 71 assertEquals(serial, expected); 72 //parallel counted completer 73 long parallelCounted = new SumSegmentCounted(null, MemorySegment.spliterator(segment, layout), threshold).invoke(); 74 assertEquals(parallelCounted, expected); 75 //parallel recursive action 76 long parallelRecursive = new SumSegmentRecursive(MemorySegment.spliterator(segment, layout), threshold).invoke(); 77 assertEquals(parallelRecursive, expected); 78 //parallel stream 79 long streamParallel = StreamSupport.stream(MemorySegment.spliterator(segment, layout), true) 80 .reduce(0L, TestSpliterator::sumSingle, Long::sum); 81 assertEquals(streamParallel, expected); 82 segment.close(); 83 } 84 85 public void testSumSameThread() { 86 SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT); 87 88 //setup 89 MemorySegment segment = MemorySegment.allocateNative(layout); 90 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 91 INT_HANDLE.set(segment.baseAddress(), (long) i, i); 92 } 93 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 94 95 //check that a segment w/o ACQUIRE access mode can still be used from same thread 96 AtomicLong spliteratorSum = new AtomicLong(); 97 spliterator(segment.withAccessModes(MemorySegment.READ), layout) 98 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s))); 99 assertEquals(spliteratorSum.get(), expected); 100 } 101 102 static long sumSingle(long acc, MemorySegment segment) { 103 return acc + (int)INT_HANDLE.get(segment.baseAddress(), 0L); 104 } 105 106 static long sum(long start, MemorySegment segment) { 107 long sum = start; 108 MemoryAddress base = segment.baseAddress(); 109 int length = (int)segment.byteSize(); 110 for (int i = 0 ; i < length / CARRIER_SIZE ; i++) { 111 sum += (int)INT_HANDLE.get(base, (long)i); 112 } 113 return sum; 114 } 115 116 static class SumSegmentCounted extends CountedCompleter<Long> { 117 118 final long threshold; 119 long localSum = 0; 120 List<SumSegmentCounted> children = new LinkedList<>(); 121 122 private Spliterator<MemorySegment> segmentSplitter; 123 124 SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) { 125 super(parent); 126 this.segmentSplitter = segmentSplitter; 127 this.threshold = threshold; 128 } 129 130 @Override 131 public void compute() { 132 Spliterator<MemorySegment> sub; 133 while (segmentSplitter.estimateSize() > threshold && 134 (sub = segmentSplitter.trySplit()) != null) { 135 addToPendingCount(1); 136 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold); 137 children.add(child); 138 child.fork(); 139 } 140 segmentSplitter.forEachRemaining(slice -> { 141 localSum += sumSingle(0, slice); 142 }); 143 tryComplete(); 144 } 145 146 @Override 147 public Long getRawResult() { 148 long sum = localSum; 149 for (SumSegmentCounted c : children) { 150 sum += c.getRawResult(); 151 } 152 return sum; 153 } 154 } 155 156 static class SumSegmentRecursive extends RecursiveTask<Long> { 157 158 final long threshold; 159 private final Spliterator<MemorySegment> splitter; 160 private long result; 161 162 SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) { 163 this.splitter = splitter; 164 this.threshold = threshold; 165 } 166 167 @Override 168 protected Long compute() { 169 if (splitter.estimateSize() > threshold) { 170 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold); 171 sub.fork(); 172 return compute() + sub.join(); 173 } else { 174 splitter.forEachRemaining(slice -> { 175 result += sumSingle(0, slice); 176 }); 177 return result; 178 } 179 } 180 } 181 182 @DataProvider(name = "splits") 183 public Object[][] splits() { 184 return new Object[][] { 185 { 10, 1 }, 186 { 100, 1 }, 187 { 1000, 1 }, 188 { 10000, 1 }, 189 { 10, 10 }, 190 { 100, 10 }, 191 { 1000, 10 }, 192 { 10000, 10 }, 193 { 10, 100 }, 194 { 100, 100 }, 195 { 1000, 100 }, 196 { 10000, 100 }, 197 { 10, 1000 }, 198 { 100, 1000 }, 199 { 1000, 1000 }, 200 { 10000, 1000 }, 201 { 10, 10000 }, 202 { 100, 10000 }, 203 { 1000, 10000 }, 204 { 10000, 10000 }, 205 }; 206 } 207 208 static final int ALL_ACCESS_MODES = READ | WRITE | CLOSE | ACQUIRE | HANDOFF; 209 210 @DataProvider(name = "accessScenarios") 211 public Object[][] accessScenarios() { 212 SequenceLayout layout = MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_INT); 213 var mallocSegment = MemorySegment.allocateNative(layout); 214 215 Map<Supplier<Spliterator<MemorySegment>>,Integer> l = Map.of( 216 () -> spliterator(mallocSegment.withAccessModes(ALL_ACCESS_MODES), layout), ALL_ACCESS_MODES, 217 () -> spliterator(mallocSegment.withAccessModes(0), layout), 0, 218 () -> spliterator(mallocSegment.withAccessModes(READ), layout), READ, 219 () -> spliterator(mallocSegment.withAccessModes(CLOSE), layout), 0, 220 () -> spliterator(mallocSegment.withAccessModes(READ|WRITE), layout), READ|WRITE, 221 () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE), layout), READ|WRITE|ACQUIRE, 222 () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE|HANDOFF), layout), READ|WRITE|ACQUIRE|HANDOFF 223 224 ); 225 return l.entrySet().stream().map(e -> new Object[] { e.getKey(), e.getValue() }).toArray(Object[][]::new); 226 } 227 228 static Consumer<MemorySegment> assertAccessModes(int accessModes) { 229 return segment -> { 230 assertTrue(segment.hasAccessModes(accessModes & ~CLOSE)); 231 assertEquals(segment.accessModes(), accessModes & ~CLOSE); 232 }; 233 } 234 235 @Test(dataProvider = "accessScenarios") 236 public void testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier, 237 int expectedAccessModes) { 238 Spliterator<MemorySegment> spliterator = spliteratorSupplier.get(); 239 spliterator.forEachRemaining(assertAccessModes(expectedAccessModes)); 240 241 spliterator = spliteratorSupplier.get(); 242 do { } while (spliterator.tryAdvance(assertAccessModes(expectedAccessModes))); 243 244 splitOrConsume(spliteratorSupplier.get(), assertAccessModes(expectedAccessModes)); 245 } 246 247 static void splitOrConsume(Spliterator<MemorySegment> spliterator, 248 Consumer<MemorySegment> consumer) { 249 var s1 = spliterator.trySplit(); 250 if (s1 != null) { 251 splitOrConsume(s1, consumer); 252 splitOrConsume(spliterator, consumer); 253 } else { 254 spliterator.forEachRemaining(consumer); 255 } 256 } 257 }