1 import jdk.incubator.vector.*; 2 import jdk.internal.vm.annotation.ForceInline; 3 import org.testng.Assert; 4 import org.testng.annotations.Test; 5 import org.testng.annotations.DataProvider; 6 7 import java.util.List; 8 import java.util.function.IntFunction; 9 10 /** 11 * @test 12 * @modules jdk.incubator.vector 13 * @modules java.base/jdk.internal.vm.annotation 14 * @run testng VectorReshapeTests 15 */ 16 17 @Test 18 public class VectorReshapeTests { 19 static final int NUM_ITER = 20; 20 21 static final IntVector.IntSpecies<Shapes.S64Bit> ispec64 = (IntVector.IntSpecies<Shapes.S64Bit>) Vector.speciesInstance(Integer.class, Shapes.S_64_BIT); 22 static final FloatVector.FloatSpecies<Shapes.S64Bit> fspec64 = (FloatVector.FloatSpecies<Shapes.S64Bit>) Vector.speciesInstance(Float.class, Shapes.S_64_BIT); 23 static final LongVector.LongSpecies<Shapes.S64Bit> lspec64 = (LongVector.LongSpecies<Shapes.S64Bit>) Vector.speciesInstance(Long.class, Shapes.S_64_BIT); 24 static final DoubleVector.DoubleSpecies<Shapes.S64Bit> dspec64 = (DoubleVector.DoubleSpecies<Shapes.S64Bit>) Vector.speciesInstance(Double.class, Shapes.S_64_BIT); 25 static final ByteVector.ByteSpecies<Shapes.S64Bit> bspec64 = (ByteVector.ByteSpecies<Shapes.S64Bit>) Vector.speciesInstance(Byte.class, Shapes.S_64_BIT); 26 static final ShortVector.ShortSpecies<Shapes.S64Bit> sspec64 = (ShortVector.ShortSpecies<Shapes.S64Bit>) Vector.speciesInstance(Short.class, Shapes.S_64_BIT); 27 28 static final IntVector.IntSpecies<Shapes.S128Bit> ispec128 = (IntVector.IntSpecies<Shapes.S128Bit>) Vector.speciesInstance(Integer.class, Shapes.S_128_BIT); 29 static final FloatVector.FloatSpecies<Shapes.S128Bit> fspec128 = (FloatVector.FloatSpecies<Shapes.S128Bit>) Vector.speciesInstance(Float.class, Shapes.S_128_BIT); 30 static final LongVector.LongSpecies<Shapes.S128Bit> lspec128 = (LongVector.LongSpecies<Shapes.S128Bit>) Vector.speciesInstance(Long.class, Shapes.S_128_BIT); 31 static final DoubleVector.DoubleSpecies<Shapes.S128Bit> dspec128 = (DoubleVector.DoubleSpecies<Shapes.S128Bit>) Vector.speciesInstance(Double.class, Shapes.S_128_BIT); 32 static final ByteVector.ByteSpecies<Shapes.S128Bit> bspec128 = (ByteVector.ByteSpecies<Shapes.S128Bit>) Vector.speciesInstance(Byte.class, Shapes.S_128_BIT); 33 static final ShortVector.ShortSpecies<Shapes.S128Bit> sspec128 = (ShortVector.ShortSpecies<Shapes.S128Bit>) Vector.speciesInstance(Short.class, Shapes.S_128_BIT); 34 35 static final IntVector.IntSpecies<Shapes.S256Bit> ispec256 = (IntVector.IntSpecies<Shapes.S256Bit>) Vector.speciesInstance(Integer.class, Shapes.S_256_BIT); 36 static final FloatVector.FloatSpecies<Shapes.S256Bit> fspec256 = (FloatVector.FloatSpecies<Shapes.S256Bit>) Vector.speciesInstance(Float.class, Shapes.S_256_BIT); 37 static final LongVector.LongSpecies<Shapes.S256Bit> lspec256 = (LongVector.LongSpecies<Shapes.S256Bit>) Vector.speciesInstance(Long.class, Shapes.S_256_BIT); 38 static final DoubleVector.DoubleSpecies<Shapes.S256Bit> dspec256 = (DoubleVector.DoubleSpecies<Shapes.S256Bit>) Vector.speciesInstance(Double.class, Shapes.S_256_BIT); 39 static final ByteVector.ByteSpecies<Shapes.S256Bit> bspec256 = (ByteVector.ByteSpecies<Shapes.S256Bit>) Vector.speciesInstance(Byte.class, Shapes.S_256_BIT); 40 static final ShortVector.ShortSpecies<Shapes.S256Bit> sspec256 = (ShortVector.ShortSpecies<Shapes.S256Bit>) Vector.speciesInstance(Short.class, Shapes.S_256_BIT); 41 42 static final IntVector.IntSpecies<Shapes.S512Bit> ispec512 = (IntVector.IntSpecies<Shapes.S512Bit>) Vector.speciesInstance(Integer.class, Shapes.S_512_BIT); 43 static final FloatVector.FloatSpecies<Shapes.S512Bit> fspec512 = (FloatVector.FloatSpecies<Shapes.S512Bit>) Vector.speciesInstance(Float.class, Shapes.S_512_BIT); 44 static final LongVector.LongSpecies<Shapes.S512Bit> lspec512 = (LongVector.LongSpecies<Shapes.S512Bit>) Vector.speciesInstance(Long.class, Shapes.S_512_BIT); 45 static final DoubleVector.DoubleSpecies<Shapes.S512Bit> dspec512 = (DoubleVector.DoubleSpecies<Shapes.S512Bit>) Vector.speciesInstance(Double.class, Shapes.S_512_BIT); 46 static final ByteVector.ByteSpecies<Shapes.S512Bit> bspec512 = (ByteVector.ByteSpecies<Shapes.S512Bit>) Vector.speciesInstance(Byte.class, Shapes.S_512_BIT); 47 static final ShortVector.ShortSpecies<Shapes.S512Bit> sspec512 = (ShortVector.ShortSpecies<Shapes.S512Bit>) Vector.speciesInstance(Short.class, Shapes.S_512_BIT); 48 49 static <T> IntFunction<T> withToString(String s, IntFunction<T> f) { 50 return new IntFunction<T>() { 51 @Override 52 public T apply(int v) { 53 return f.apply(v); 54 } 55 56 @Override 57 public String toString() { 58 return s; 59 } 60 }; 61 } 62 63 interface ToByteF { 64 byte apply(int i); 65 } 66 67 static byte[] fill(int s , ToByteF f) { 68 return fill(new byte[s], f); 69 } 70 71 static byte[] fill(byte[] a, ToByteF f) { 72 for (int i = 0; i < a.length; i++) { 73 a[i] = f.apply(i); 74 } 75 return a; 76 } 77 78 static final List<IntFunction<byte[]>> BYTE_GENERATORS = List.of( 79 withToString("byte[i * 5]", (int s) -> { 80 return fill(s * 1000, 81 i -> (byte)(i * 5)); 82 }), 83 withToString("byte[i + 1]", (int s) -> { 84 return fill(s * 1000, 85 i -> (((byte)(i + 1) == 0) ? 1 : (byte)(i + 1))); 86 }) 87 ); 88 89 @DataProvider 90 public Object[][] byteUnaryOpProvider() { 91 return BYTE_GENERATORS.stream(). 92 map(f -> new Object[]{f}). 93 toArray(Object[][]::new); 94 } 95 96 @ForceInline 97 static void testResize(Vector.Species a, Vector.Species b, byte[] input) { 98 int spec_a_num_bytes = a.bitSize() / Byte.SIZE; 99 int spec_b_num_bytes = b.bitSize() / Byte.SIZE; 100 101 // Set a loop bound so that a load will not go out of bounds based on vector size. 102 int loop_bound = Math.min(input.length - spec_a_num_bytes, input.length - spec_b_num_bytes); 103 104 // Create arrays being used for storing. Use largest size. 105 int max_num_bytes = Math.max(spec_a_num_bytes, spec_b_num_bytes); 106 int min_num_bytes = Math.min(spec_a_num_bytes, spec_b_num_bytes); 107 byte[] actual = new byte[max_num_bytes]; 108 byte[] expected = new byte[max_num_bytes]; 109 110 for (int i = 0; i < loop_bound; i++) { 111 Vector av = a.fromByteArray(input, i); 112 Vector bv = av.resize(b); 113 bv.intoByteArray(actual, 0); 114 115 // Compute expected. 116 for (int j = 0; j < max_num_bytes; j++) { 117 if (j < min_num_bytes) { 118 expected[j] = input[i+j]; 119 } else { 120 expected[j] = 0; 121 } 122 } 123 124 Assert.assertEquals(expected, actual); 125 } 126 } 127 128 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 129 static void testResizeByte(IntFunction<byte[]> fa) { 130 byte[] barr = fa.apply(bspec64.elementSize()); 131 for (int i = 0; i < NUM_ITER; i++) { 132 testResize(bspec64, bspec64, barr); 133 testResize(bspec64, bspec128, barr); 134 testResize(bspec64, bspec256, barr); 135 testResize(bspec64, bspec512, barr); 136 testResize(bspec128, bspec64, barr); 137 testResize(bspec128, bspec128, barr); 138 testResize(bspec128, bspec256, barr); 139 testResize(bspec128, bspec512, barr); 140 testResize(bspec256, bspec64, barr); 141 testResize(bspec256, bspec128, barr); 142 testResize(bspec256, bspec256, barr); 143 testResize(bspec256, bspec512, barr); 144 testResize(bspec512, bspec64, barr); 145 testResize(bspec512, bspec128, barr); 146 testResize(bspec512, bspec256, barr); 147 testResize(bspec512, bspec512, barr); 148 } 149 } 150 151 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 152 static void testResizeShort(IntFunction<byte[]> fa) { 153 byte[] barr = fa.apply(sspec64.elementSize()); 154 for (int i = 0; i < NUM_ITER; i++) { 155 testResize(sspec64, sspec64, barr); 156 testResize(sspec64, sspec128, barr); 157 testResize(sspec64, sspec256, barr); 158 testResize(sspec64, sspec512, barr); 159 testResize(sspec128, sspec64, barr); 160 testResize(sspec128, sspec128, barr); 161 testResize(sspec128, sspec256, barr); 162 testResize(sspec128, sspec512, barr); 163 testResize(sspec256, sspec64, barr); 164 testResize(sspec256, sspec128, barr); 165 testResize(sspec256, sspec256, barr); 166 testResize(sspec256, sspec512, barr); 167 testResize(sspec512, sspec64, barr); 168 testResize(sspec512, sspec128, barr); 169 testResize(sspec512, sspec256, barr); 170 testResize(sspec512, sspec512, barr); 171 } 172 } 173 174 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 175 static void testResizeInt(IntFunction<byte[]> fa) { 176 byte[] barr = fa.apply(ispec64.elementSize()); 177 for (int i = 0; i < NUM_ITER; i++) { 178 testResize(ispec64, ispec64, barr); 179 testResize(ispec64, ispec128, barr); 180 testResize(ispec64, ispec256, barr); 181 testResize(ispec64, ispec512, barr); 182 testResize(ispec128, ispec64, barr); 183 testResize(ispec128, ispec128, barr); 184 testResize(ispec128, ispec256, barr); 185 testResize(ispec128, ispec512, barr); 186 testResize(ispec256, ispec64, barr); 187 testResize(ispec256, ispec128, barr); 188 testResize(ispec256, ispec256, barr); 189 testResize(ispec256, ispec512, barr); 190 testResize(ispec512, ispec64, barr); 191 testResize(ispec512, ispec128, barr); 192 testResize(ispec512, ispec256, barr); 193 testResize(ispec512, ispec512, barr); 194 } 195 } 196 197 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 198 static void testResizeLong(IntFunction<byte[]> fa) { 199 byte[] barr = fa.apply(lspec64.elementSize()); 200 for (int i = 0; i < NUM_ITER; i++) { 201 testResize(lspec64, lspec64, barr); 202 testResize(lspec64, lspec128, barr); 203 testResize(lspec64, lspec256, barr); 204 testResize(lspec64, lspec512, barr); 205 testResize(lspec128, lspec64, barr); 206 testResize(lspec128, lspec128, barr); 207 testResize(lspec128, lspec256, barr); 208 testResize(lspec128, lspec512, barr); 209 testResize(lspec256, lspec64, barr); 210 testResize(lspec256, lspec128, barr); 211 testResize(lspec256, lspec256, barr); 212 testResize(lspec256, lspec512, barr); 213 testResize(lspec512, lspec64, barr); 214 testResize(lspec512, lspec128, barr); 215 testResize(lspec512, lspec256, barr); 216 testResize(lspec512, lspec512, barr); 217 } 218 } 219 220 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 221 static void testResizeFloat(IntFunction<byte[]> fa) { 222 byte[] barr = fa.apply(fspec64.elementSize()); 223 for (int i = 0; i < NUM_ITER; i++) { 224 testResize(fspec64, fspec64, barr); 225 testResize(fspec64, fspec128, barr); 226 testResize(fspec64, fspec256, barr); 227 testResize(fspec64, fspec512, barr); 228 testResize(fspec128, fspec64, barr); 229 testResize(fspec128, fspec128, barr); 230 testResize(fspec128, fspec256, barr); 231 testResize(fspec128, fspec512, barr); 232 testResize(fspec256, fspec64, barr); 233 testResize(fspec256, fspec128, barr); 234 testResize(fspec256, fspec256, barr); 235 testResize(fspec256, fspec512, barr); 236 testResize(fspec512, fspec64, barr); 237 testResize(fspec512, fspec128, barr); 238 testResize(fspec512, fspec256, barr); 239 testResize(fspec512, fspec512, barr); 240 } 241 } 242 243 @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2) 244 static void testResizeDouble(IntFunction<byte[]> fa) { 245 byte[] barr = fa.apply(dspec64.elementSize()); 246 for (int i = 0; i < NUM_ITER; i++) { 247 testResize(dspec64, dspec64, barr); 248 testResize(dspec64, dspec128, barr); 249 testResize(dspec64, dspec256, barr); 250 testResize(dspec64, dspec512, barr); 251 testResize(dspec128, dspec64, barr); 252 testResize(dspec128, dspec128, barr); 253 testResize(dspec128, dspec256, barr); 254 testResize(dspec128, dspec512, barr); 255 testResize(dspec256, dspec64, barr); 256 testResize(dspec256, dspec128, barr); 257 testResize(dspec256, dspec256, barr); 258 testResize(dspec256, dspec512, barr); 259 testResize(dspec512, dspec64, barr); 260 testResize(dspec512, dspec128, barr); 261 testResize(dspec512, dspec256, barr); 262 testResize(dspec512, dspec512, barr); 263 } 264 } 265 }