import jdk.incubator.vector.*; import jdk.internal.vm.annotation.ForceInline; import org.testng.Assert; import org.testng.annotations.Test; import org.testng.annotations.DataProvider; import java.util.List; import java.util.function.IntFunction; /** * @test * @modules jdk.incubator.vector * @modules java.base/jdk.internal.vm.annotation * @run testng VectorReshapeTests */ @Test public class VectorReshapeTests { static final int NUM_ITER = 20; static final IntVector.IntSpecies ispec64 = (IntVector.IntSpecies) Vector.speciesInstance(Integer.class, Shapes.S_64_BIT); static final FloatVector.FloatSpecies fspec64 = (FloatVector.FloatSpecies) Vector.speciesInstance(Float.class, Shapes.S_64_BIT); static final LongVector.LongSpecies lspec64 = (LongVector.LongSpecies) Vector.speciesInstance(Long.class, Shapes.S_64_BIT); static final DoubleVector.DoubleSpecies dspec64 = (DoubleVector.DoubleSpecies) Vector.speciesInstance(Double.class, Shapes.S_64_BIT); static final ByteVector.ByteSpecies bspec64 = (ByteVector.ByteSpecies) Vector.speciesInstance(Byte.class, Shapes.S_64_BIT); static final ShortVector.ShortSpecies sspec64 = (ShortVector.ShortSpecies) Vector.speciesInstance(Short.class, Shapes.S_64_BIT); static final IntVector.IntSpecies ispec128 = (IntVector.IntSpecies) Vector.speciesInstance(Integer.class, Shapes.S_128_BIT); static final FloatVector.FloatSpecies fspec128 = (FloatVector.FloatSpecies) Vector.speciesInstance(Float.class, Shapes.S_128_BIT); static final LongVector.LongSpecies lspec128 = (LongVector.LongSpecies) Vector.speciesInstance(Long.class, Shapes.S_128_BIT); static final DoubleVector.DoubleSpecies dspec128 = (DoubleVector.DoubleSpecies) Vector.speciesInstance(Double.class, Shapes.S_128_BIT); static final ByteVector.ByteSpecies bspec128 = (ByteVector.ByteSpecies) Vector.speciesInstance(Byte.class, Shapes.S_128_BIT); static final ShortVector.ShortSpecies sspec128 = (ShortVector.ShortSpecies) Vector.speciesInstance(Short.class, Shapes.S_128_BIT); static final IntVector.IntSpecies ispec256 = (IntVector.IntSpecies) Vector.speciesInstance(Integer.class, Shapes.S_256_BIT); static final FloatVector.FloatSpecies fspec256 = (FloatVector.FloatSpecies) Vector.speciesInstance(Float.class, Shapes.S_256_BIT); static final LongVector.LongSpecies lspec256 = (LongVector.LongSpecies) Vector.speciesInstance(Long.class, Shapes.S_256_BIT); static final DoubleVector.DoubleSpecies dspec256 = (DoubleVector.DoubleSpecies) Vector.speciesInstance(Double.class, Shapes.S_256_BIT); static final ByteVector.ByteSpecies bspec256 = (ByteVector.ByteSpecies) Vector.speciesInstance(Byte.class, Shapes.S_256_BIT); static final ShortVector.ShortSpecies sspec256 = (ShortVector.ShortSpecies) Vector.speciesInstance(Short.class, Shapes.S_256_BIT); static final IntVector.IntSpecies ispec512 = (IntVector.IntSpecies) Vector.speciesInstance(Integer.class, Shapes.S_512_BIT); static final FloatVector.FloatSpecies fspec512 = (FloatVector.FloatSpecies) Vector.speciesInstance(Float.class, Shapes.S_512_BIT); static final LongVector.LongSpecies lspec512 = (LongVector.LongSpecies) Vector.speciesInstance(Long.class, Shapes.S_512_BIT); static final DoubleVector.DoubleSpecies dspec512 = (DoubleVector.DoubleSpecies) Vector.speciesInstance(Double.class, Shapes.S_512_BIT); static final ByteVector.ByteSpecies bspec512 = (ByteVector.ByteSpecies) Vector.speciesInstance(Byte.class, Shapes.S_512_BIT); static final ShortVector.ShortSpecies sspec512 = (ShortVector.ShortSpecies) Vector.speciesInstance(Short.class, Shapes.S_512_BIT); static IntFunction withToString(String s, IntFunction f) { return new IntFunction() { @Override public T apply(int v) { return f.apply(v); } @Override public String toString() { return s; } }; } interface ToByteF { byte apply(int i); } static byte[] fill(int s , ToByteF f) { return fill(new byte[s], f); } static byte[] fill(byte[] a, ToByteF f) { for (int i = 0; i < a.length; i++) { a[i] = f.apply(i); } return a; } static final List> BYTE_GENERATORS = List.of( withToString("byte[i * 5]", (int s) -> { return fill(s * 1000, i -> (byte)(i * 5)); }), withToString("byte[i + 1]", (int s) -> { return fill(s * 1000, i -> (((byte)(i + 1) == 0) ? 1 : (byte)(i + 1))); }) ); @DataProvider public Object[][] byteUnaryOpProvider() { return BYTE_GENERATORS.stream(). map(f -> new Object[]{f}). toArray(Object[][]::new); } @ForceInline static void testResize(Vector.Species a, Vector.Species b, byte[] input) { int spec_a_num_bytes = a.bitSize() / Byte.SIZE; int spec_b_num_bytes = b.bitSize() / Byte.SIZE; // Set a loop bound so that a load will not go out of bounds based on vector size. int loop_bound = Math.min(input.length - spec_a_num_bytes, input.length - spec_b_num_bytes); // Create arrays being used for storing. Use largest size. int max_num_bytes = Math.max(spec_a_num_bytes, spec_b_num_bytes); int min_num_bytes = Math.min(spec_a_num_bytes, spec_b_num_bytes); byte[] actual = new byte[max_num_bytes]; byte[] expected = new byte[max_num_bytes]; for (int i = 0; i < loop_bound; i++) { Vector av = a.fromByteArray(input, i); Vector bv = av.resize(b); bv.intoByteArray(actual, 0); // Compute expected. for (int j = 0; j < max_num_bytes; j++) { if (j < min_num_bytes) { expected[j] = input[i+j]; } else { expected[j] = 0; } } Assert.assertEquals(expected, actual); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeByte(IntFunction fa) { byte[] barr = fa.apply(bspec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(bspec64, bspec64, barr); testResize(bspec64, bspec128, barr); testResize(bspec64, bspec256, barr); testResize(bspec64, bspec512, barr); testResize(bspec128, bspec64, barr); testResize(bspec128, bspec128, barr); testResize(bspec128, bspec256, barr); testResize(bspec128, bspec512, barr); testResize(bspec256, bspec64, barr); testResize(bspec256, bspec128, barr); testResize(bspec256, bspec256, barr); testResize(bspec256, bspec512, barr); testResize(bspec512, bspec64, barr); testResize(bspec512, bspec128, barr); testResize(bspec512, bspec256, barr); testResize(bspec512, bspec512, barr); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeShort(IntFunction fa) { byte[] barr = fa.apply(sspec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(sspec64, sspec64, barr); testResize(sspec64, sspec128, barr); testResize(sspec64, sspec256, barr); testResize(sspec64, sspec512, barr); testResize(sspec128, sspec64, barr); testResize(sspec128, sspec128, barr); testResize(sspec128, sspec256, barr); testResize(sspec128, sspec512, barr); testResize(sspec256, sspec64, barr); testResize(sspec256, sspec128, barr); testResize(sspec256, sspec256, barr); testResize(sspec256, sspec512, barr); testResize(sspec512, sspec64, barr); testResize(sspec512, sspec128, barr); testResize(sspec512, sspec256, barr); testResize(sspec512, sspec512, barr); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeInt(IntFunction fa) { byte[] barr = fa.apply(ispec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(ispec64, ispec64, barr); testResize(ispec64, ispec128, barr); testResize(ispec64, ispec256, barr); testResize(ispec64, ispec512, barr); testResize(ispec128, ispec64, barr); testResize(ispec128, ispec128, barr); testResize(ispec128, ispec256, barr); testResize(ispec128, ispec512, barr); testResize(ispec256, ispec64, barr); testResize(ispec256, ispec128, barr); testResize(ispec256, ispec256, barr); testResize(ispec256, ispec512, barr); testResize(ispec512, ispec64, barr); testResize(ispec512, ispec128, barr); testResize(ispec512, ispec256, barr); testResize(ispec512, ispec512, barr); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeLong(IntFunction fa) { byte[] barr = fa.apply(lspec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(lspec64, lspec64, barr); testResize(lspec64, lspec128, barr); testResize(lspec64, lspec256, barr); testResize(lspec64, lspec512, barr); testResize(lspec128, lspec64, barr); testResize(lspec128, lspec128, barr); testResize(lspec128, lspec256, barr); testResize(lspec128, lspec512, barr); testResize(lspec256, lspec64, barr); testResize(lspec256, lspec128, barr); testResize(lspec256, lspec256, barr); testResize(lspec256, lspec512, barr); testResize(lspec512, lspec64, barr); testResize(lspec512, lspec128, barr); testResize(lspec512, lspec256, barr); testResize(lspec512, lspec512, barr); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeFloat(IntFunction fa) { byte[] barr = fa.apply(fspec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(fspec64, fspec64, barr); testResize(fspec64, fspec128, barr); testResize(fspec64, fspec256, barr); testResize(fspec64, fspec512, barr); testResize(fspec128, fspec64, barr); testResize(fspec128, fspec128, barr); testResize(fspec128, fspec256, barr); testResize(fspec128, fspec512, barr); testResize(fspec256, fspec64, barr); testResize(fspec256, fspec128, barr); testResize(fspec256, fspec256, barr); testResize(fspec256, fspec512, barr); testResize(fspec512, fspec64, barr); testResize(fspec512, fspec128, barr); testResize(fspec512, fspec256, barr); testResize(fspec512, fspec512, barr); } } @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) static void testResizeDouble(IntFunction fa) { byte[] barr = fa.apply(dspec64.elementSize()); for (int i = 0; i < NUM_ITER; i++) { testResize(dspec64, dspec64, barr); testResize(dspec64, dspec128, barr); testResize(dspec64, dspec256, barr); testResize(dspec64, dspec512, barr); testResize(dspec128, dspec64, barr); testResize(dspec128, dspec128, barr); testResize(dspec128, dspec256, barr); testResize(dspec128, dspec512, barr); testResize(dspec256, dspec64, barr); testResize(dspec256, dspec128, barr); testResize(dspec256, dspec256, barr); testResize(dspec256, dspec512, barr); testResize(dspec512, dspec64, barr); testResize(dspec512, dspec128, barr); testResize(dspec512, dspec256, barr); testResize(dspec512, dspec512, barr); } } }