package jdk.test;

import jdk.internal.misc.Unsafe;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Group;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.concurrent.TimeUnit;
import java.util.function.IntUnaryOperator;

/**
 * JMH benchmark comparing four shapes of a lock-free "getAndUpdate" loop.
 * Each loop is built directly on {@link Unsafe} volatile reads and CAS
 * operations against a raw {@code int} stored in off-heap memory.
 *
 * <p>Each {@code @Group} ("dflt", "martin", "shade", "strong") contains two
 * {@code @Benchmark} methods, which JMH runs concurrently as one group. The two
 * methods update two different, adjacent {@code int} slots, at byte offsets
 * {@code 0} and {@code Integer.BYTES} from {@link #BB_ADDRESS}, of a buffer
 * whose start is 64-byte aligned.
 *
 * <p>With JMH's default of one thread per group method, neither thread ever
 * changes the other's value. NOTE(review): presumably the intent is to measure
 * how each loop shape behaves when the two slots share a cache line, including
 * spurious weak-CAS failures. This assumes a 64-byte cache line; confirm with
 * the benchmark author.
 *
 * <p>The cost of each call to the update function is controlled by
 * {@link #updateFnCpu}.
 */
@BenchmarkMode(Mode.AverageTime)
// --add-exports lets this class use the non-exported jdk.internal.misc.Unsafe.
@Fork(value = 1, warmups = 0, jvmArgsAppend = {"--add-exports", "java.base/jdk.internal.misc=ALL-UNNAMED"})
@Warmup(iterations = 10)
@Measurement(iterations = 20)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
public class GetAndUpdateBench3 {

    // Allocate a direct buffer, then slice it so that its first byte sits on
    // a 64-byte boundary (the start of a cache line, assuming 64-byte lines).
    // 128 bytes are allocated so an aligned slice always fits. The byte order
    // is set to native, although the benchmarks below only access the memory
    // through Unsafe and BB_ADDRESS.
    static final ByteBuffer BB = ByteBuffer.allocateDirect(128)
        .alignedSlice(64)
        .order(ByteOrder.nativeOrder());

    static final Unsafe U = Unsafe.getUnsafe();

    /** Raw native address of the first byte of {@link #BB}. */
    static final long BB_ADDRESS;

    static {
        try {
            // Read the private java.nio.Buffer "address" field of BB via Unsafe.
            BB_ADDRESS = U.getLong(BB, U.objectFieldOffset(Buffer.class.getDeclaredField("address")));
        } catch (NoSuchFieldException e) {
            throw new Error(e);
        }
    }

    /**
     * Number of {@link Blackhole#consumeCPU(long)} tokens burned by each
     * invocation of {@link #updateFn}. This simulates an update function of
     * varying cost.
     */
    @Param({"1", "10", "20", "50", "100"})
    public long updateFnCpu;

    /** Update function shared by all variants; built in {@link #setup()}. */
    IntUnaryOperator updateFn;

    /**
     * Builds {@link #updateFn} once per trial. The function burns
     * {@code updateFnCpu} tokens of CPU and then returns its argument plus one.
     */
    @Setup(Level.Trial)
    public void setup() {
        // Copy the param into a local so the lambda captures the value,
        // not {@code this}.
        long tokens = updateFnCpu;
        updateFn = i -> {
            Blackhole.consumeCPU(tokens);
            return i + 1;
        };
    }

    /** Group "dflt": updates the int slot at offset 0. */
    @Benchmark
    @Group("dflt")
    public int getAndUpdate1_dflt() {
        return getAndUpdate_dflt(0);
    }

    /** Group "dflt": updates the adjacent int slot at offset {@code Integer.BYTES}. */
    @Benchmark
    @Group("dflt")
    public int getAndUpdate2_dflt() {
        return getAndUpdate_dflt(Integer.BYTES);
    }

    /**
     * getAndUpdate variant that keeps the computed {@code next} across a
     * failed weak CAS.
     *
     * <p>After a failure it re-reads the slot. If the value is unchanged, the
     * CAS is retried without calling {@link #updateFn} again. Otherwise
     * {@code next} is recomputed from the new value.
     *
     * <p>NOTE(review): this appears to mirror the loop shape of the JDK's
     * {@code AtomicInteger.getAndUpdate}; confirm against the target JDK.
     *
     * @param i byte offset of the int slot relative to {@link #BB_ADDRESS}
     * @return the slot's value immediately before the successful update
     */
    private int getAndUpdate_dflt(int i) {
        // next = 0 is only a placeholder: haveNext starts false, so the first
        // iteration always computes it.
        int prev = U.getIntVolatile(null, BB_ADDRESS + i), next = 0;
        for (boolean haveNext = false; ; ) {
            if (!haveNext)
                next = updateFn.applyAsInt(prev);
            // NOTE(review): later JDKs appear to name this
            // weakCompareAndSetInt; confirm against the target JDK.
            if (U.weakCompareAndSwapIntVolatile(null, BB_ADDRESS + i, prev, next))
                return prev;
            // The left operand is the old prev. The assignment then stores the
            // fresh read, so haveNext is true iff the value did not change.
            haveNext = (prev == (prev = U.getIntVolatile(null, BB_ADDRESS + i)));
        }
    }

    /** Group "martin": updates the int slot at offset 0. */
    @Benchmark
    @Group("martin")
    public int getAndUpdate1_martin() {
        return getAndUpdate_martin(0);
    }

    /** Group "martin": updates the adjacent int slot at offset {@code Integer.BYTES}. */
    @Benchmark
    @Group("martin")
    public int getAndUpdate2_martin() {
        return getAndUpdate_martin(Integer.BYTES);
    }

    /**
     * getAndUpdate variant built from two nested loops.
     *
     * <p>The outer loop calls {@link #updateFn} once per distinct observed
     * value. The inner loop retries the weak CAS for as long as re-reading the
     * slot yields the same value. It falls back to the outer loop (and
     * recomputes) only when the value has changed.
     *
     * @param i byte offset of the int slot relative to {@link #BB_ADDRESS}
     * @return the slot's value immediately before the successful update
     */
    private int getAndUpdate_martin(int i) {
        for (int prev = U.getIntVolatile(null, BB_ADDRESS + i); ; ) {
            int next = updateFn.applyAsInt(prev);
            do {
                if (U.weakCompareAndSwapIntVolatile(null, BB_ADDRESS + i, prev, next))
                    return prev;
                // Old prev is compared with the freshly re-read value, which
                // is also stored into prev for the next attempt.
            } while (prev == (prev = U.getIntVolatile(null, BB_ADDRESS + i)));
        }
    }

    /** Group "shade": updates the int slot at offset 0. */
    @Benchmark
    @Group("shade")
    public int getAndUpdate1_shade() {
        return getAndUpdate_shade(0);
    }

    /** Group "shade": updates the adjacent int slot at offset {@code Integer.BYTES}. */
    @Benchmark
    @Group("shade")
    public int getAndUpdate2_shade() {
        return getAndUpdate_shade(Integer.BYTES);
    }

    /**
     * Simplest weak-CAS getAndUpdate variant.
     *
     * <p>Every attempt re-reads the slot, calls {@link #updateFn} again and
     * then tries a weak CAS. Nothing is reused across failed attempts.
     *
     * @param i byte offset of the int slot relative to {@link #BB_ADDRESS}
     * @return the slot's value immediately before the successful update
     */
    private int getAndUpdate_shade(int i) {
        int prev, next;
        do {
            prev = U.getIntVolatile(null, BB_ADDRESS + i);
            next = updateFn.applyAsInt(prev);
        } while (!U.weakCompareAndSwapIntVolatile(null, BB_ADDRESS + i, prev, next));
        return prev;
    }

    /** Group "strong": updates the int slot at offset 0. */
    @Benchmark
    @Group("strong")
    public int getAndUpdate1_strong() {
        return getAndUpdate_strong(0);
    }

    /** Group "strong": updates the adjacent int slot at offset {@code Integer.BYTES}. */
    @Benchmark
    @Group("strong")
    public int getAndUpdate2_strong() {
        return getAndUpdate_strong(Integer.BYTES);
    }

    /**
     * getAndUpdate variant that uses the strong (non-weak) CAS.
     *
     * <p>Like the "shade" variant, every attempt re-reads the slot and calls
     * {@link #updateFn} again before trying the CAS.
     *
     * @param i byte offset of the int slot relative to {@link #BB_ADDRESS}
     * @return the slot's value immediately before the successful update
     */
    private int getAndUpdate_strong(int i) {
        for (; ; ) {
            int prev = U.getIntVolatile(null, BB_ADDRESS + i);
            int next = updateFn.applyAsInt(prev);
            // NOTE(review): later JDKs appear to name this compareAndSetInt;
            // confirm against the target JDK.
            if (U.compareAndSwapInt(null, BB_ADDRESS + i, prev, next))
                return prev;
        }
    }
}