1 import jdk.incubator.vector.*;
   2 import jdk.internal.vm.annotation.ForceInline;
   3 import org.testng.Assert;
   4 import org.testng.annotations.Test;
   5 import org.testng.annotations.DataProvider;
   6 
   7 import java.util.List;
   8 import java.util.function.IntFunction;
   9 
  10 /**
  11  * @test
  12  * @modules jdk.incubator.vector
  13  * @modules java.base/jdk.internal.vm.annotation
  14  * @run testng VectorReshapeTests
  15  */
  16  
  17 @Test
  18 public class VectorReshapeTests {
  19     static final int NUM_ITER = 20;
  20 
  21     static final IntVector.IntSpecies<Shapes.S64Bit> ispec64 = (IntVector.IntSpecies<Shapes.S64Bit>) Vector.speciesInstance(Integer.class, Shapes.S_64_BIT);
  22     static final FloatVector.FloatSpecies<Shapes.S64Bit> fspec64 = (FloatVector.FloatSpecies<Shapes.S64Bit>) Vector.speciesInstance(Float.class, Shapes.S_64_BIT);
  23     static final LongVector.LongSpecies<Shapes.S64Bit> lspec64 = (LongVector.LongSpecies<Shapes.S64Bit>) Vector.speciesInstance(Long.class, Shapes.S_64_BIT);
  24     static final DoubleVector.DoubleSpecies<Shapes.S64Bit> dspec64 = (DoubleVector.DoubleSpecies<Shapes.S64Bit>) Vector.speciesInstance(Double.class, Shapes.S_64_BIT);
  25     static final ByteVector.ByteSpecies<Shapes.S64Bit> bspec64 = (ByteVector.ByteSpecies<Shapes.S64Bit>) Vector.speciesInstance(Byte.class, Shapes.S_64_BIT);
  26     static final ShortVector.ShortSpecies<Shapes.S64Bit> sspec64 = (ShortVector.ShortSpecies<Shapes.S64Bit>) Vector.speciesInstance(Short.class, Shapes.S_64_BIT);
  27 
  28     static final IntVector.IntSpecies<Shapes.S128Bit> ispec128 = (IntVector.IntSpecies<Shapes.S128Bit>) Vector.speciesInstance(Integer.class, Shapes.S_128_BIT);
  29     static final FloatVector.FloatSpecies<Shapes.S128Bit> fspec128 = (FloatVector.FloatSpecies<Shapes.S128Bit>) Vector.speciesInstance(Float.class, Shapes.S_128_BIT);
  30     static final LongVector.LongSpecies<Shapes.S128Bit> lspec128 = (LongVector.LongSpecies<Shapes.S128Bit>) Vector.speciesInstance(Long.class, Shapes.S_128_BIT);
  31     static final DoubleVector.DoubleSpecies<Shapes.S128Bit> dspec128 = (DoubleVector.DoubleSpecies<Shapes.S128Bit>) Vector.speciesInstance(Double.class, Shapes.S_128_BIT);
  32     static final ByteVector.ByteSpecies<Shapes.S128Bit> bspec128 = (ByteVector.ByteSpecies<Shapes.S128Bit>) Vector.speciesInstance(Byte.class, Shapes.S_128_BIT);
  33     static final ShortVector.ShortSpecies<Shapes.S128Bit> sspec128 = (ShortVector.ShortSpecies<Shapes.S128Bit>) Vector.speciesInstance(Short.class, Shapes.S_128_BIT);
  34 
  35     static final IntVector.IntSpecies<Shapes.S256Bit> ispec256 = (IntVector.IntSpecies<Shapes.S256Bit>) Vector.speciesInstance(Integer.class, Shapes.S_256_BIT);
  36     static final FloatVector.FloatSpecies<Shapes.S256Bit> fspec256 = (FloatVector.FloatSpecies<Shapes.S256Bit>) Vector.speciesInstance(Float.class, Shapes.S_256_BIT);
  37     static final LongVector.LongSpecies<Shapes.S256Bit> lspec256 = (LongVector.LongSpecies<Shapes.S256Bit>) Vector.speciesInstance(Long.class, Shapes.S_256_BIT);
  38     static final DoubleVector.DoubleSpecies<Shapes.S256Bit> dspec256 = (DoubleVector.DoubleSpecies<Shapes.S256Bit>) Vector.speciesInstance(Double.class, Shapes.S_256_BIT);
  39     static final ByteVector.ByteSpecies<Shapes.S256Bit> bspec256 = (ByteVector.ByteSpecies<Shapes.S256Bit>) Vector.speciesInstance(Byte.class, Shapes.S_256_BIT);
  40     static final ShortVector.ShortSpecies<Shapes.S256Bit> sspec256 = (ShortVector.ShortSpecies<Shapes.S256Bit>) Vector.speciesInstance(Short.class, Shapes.S_256_BIT);
  41 
  42     static final IntVector.IntSpecies<Shapes.S512Bit> ispec512 = (IntVector.IntSpecies<Shapes.S512Bit>) Vector.speciesInstance(Integer.class, Shapes.S_512_BIT);
  43     static final FloatVector.FloatSpecies<Shapes.S512Bit> fspec512 = (FloatVector.FloatSpecies<Shapes.S512Bit>) Vector.speciesInstance(Float.class, Shapes.S_512_BIT);
  44     static final LongVector.LongSpecies<Shapes.S512Bit> lspec512 = (LongVector.LongSpecies<Shapes.S512Bit>) Vector.speciesInstance(Long.class, Shapes.S_512_BIT);
  45     static final DoubleVector.DoubleSpecies<Shapes.S512Bit> dspec512 = (DoubleVector.DoubleSpecies<Shapes.S512Bit>) Vector.speciesInstance(Double.class, Shapes.S_512_BIT);
  46     static final ByteVector.ByteSpecies<Shapes.S512Bit> bspec512 = (ByteVector.ByteSpecies<Shapes.S512Bit>) Vector.speciesInstance(Byte.class, Shapes.S_512_BIT);
  47     static final ShortVector.ShortSpecies<Shapes.S512Bit> sspec512 = (ShortVector.ShortSpecies<Shapes.S512Bit>) Vector.speciesInstance(Short.class, Shapes.S_512_BIT);
  48 
  49     static <T> IntFunction<T> withToString(String s, IntFunction<T> f) {
  50         return new IntFunction<T>() {
  51             @Override
  52             public T apply(int v) {
  53                 return f.apply(v);
  54             }
  55 
  56             @Override
  57             public String toString() {
  58                 return s;
  59             }
  60         };
  61     }
  62 
  63     interface ToByteF {
  64         byte apply(int i);
  65     }
  66 
  67     static byte[] fill(int s , ToByteF f) {
  68         return fill(new byte[s], f);
  69     }
  70 
  71     static byte[] fill(byte[] a, ToByteF f) {
  72         for (int i = 0; i < a.length; i++) {
  73             a[i] = f.apply(i);
  74         }
  75         return a;
  76     }
  77 
  78     static final List<IntFunction<byte[]>> BYTE_GENERATORS = List.of(
  79             withToString("byte[i * 5]", (int s) -> {
  80                 return fill(s * 1000,
  81                         i -> (byte)(i * 5));
  82             }),
  83             withToString("byte[i + 1]", (int s) -> {
  84                 return fill(s * 1000,
  85                         i -> (((byte)(i + 1) == 0) ? 1 : (byte)(i + 1)));
  86             })
  87     );
  88 
  89     @DataProvider
  90     public Object[][] byteUnaryOpProvider() {
  91         return BYTE_GENERATORS.stream().
  92                 map(f -> new Object[]{f}).
  93                 toArray(Object[][]::new);
  94     }
  95 
  96     @ForceInline
  97     static void testResize(Vector.Species a, Vector.Species b, byte[] input) {
  98         int spec_a_num_bytes = a.bitSize() / Byte.SIZE;
  99         int spec_b_num_bytes = b.bitSize() / Byte.SIZE;
 100 
 101         // Set a loop bound so that a load will not go out of bounds based on vector size.
 102         int loop_bound = Math.min(input.length - spec_a_num_bytes, input.length - spec_b_num_bytes);
 103 
 104         // Create arrays being used for storing. Use largest size.
 105         int max_num_bytes = Math.max(spec_a_num_bytes, spec_b_num_bytes);
 106         int min_num_bytes = Math.min(spec_a_num_bytes, spec_b_num_bytes);
 107         byte[] actual = new byte[max_num_bytes];
 108         byte[] expected = new byte[max_num_bytes];
 109 
 110         for (int i = 0; i < loop_bound; i++) {
 111             Vector av = a.fromByteArray(input, i);
 112             Vector bv = av.resize(b);
 113             bv.intoByteArray(actual, 0);
 114 
 115             // Compute expected.
 116             for (int j = 0; j < max_num_bytes; j++) {
 117                 if (j < min_num_bytes) {
 118                     expected[j] = input[i+j];
 119                 } else {
 120                     expected[j] = 0;
 121                 }
 122             }
 123 
 124             Assert.assertEquals(expected, actual);
 125         }
 126     }
 127 
 128     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 129     static void testResizeByte(IntFunction<byte[]> fa) {
 130         byte[] barr = fa.apply(bspec64.elementSize());
 131         for (int i = 0; i < NUM_ITER; i++) {
 132             testResize(bspec64, bspec64, barr);
 133             testResize(bspec64, bspec128, barr);
 134             testResize(bspec64, bspec256, barr);
 135             testResize(bspec64, bspec512, barr);
 136             testResize(bspec128, bspec64, barr);
 137             testResize(bspec128, bspec128, barr);
 138             testResize(bspec128, bspec256, barr);
 139             testResize(bspec128, bspec512, barr);
 140             testResize(bspec256, bspec64, barr);
 141             testResize(bspec256, bspec128, barr);
 142             testResize(bspec256, bspec256, barr);
 143             testResize(bspec256, bspec512, barr);
 144             testResize(bspec512, bspec64, barr);
 145             testResize(bspec512, bspec128, barr);
 146             testResize(bspec512, bspec256, barr);
 147             testResize(bspec512, bspec512, barr);
 148         }
 149     }
 150 
 151     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 152     static void testResizeShort(IntFunction<byte[]> fa) {
 153         byte[] barr = fa.apply(sspec64.elementSize());
 154         for (int i = 0; i < NUM_ITER; i++) {
 155             testResize(sspec64, sspec64, barr);
 156             testResize(sspec64, sspec128, barr);
 157             testResize(sspec64, sspec256, barr);
 158             testResize(sspec64, sspec512, barr);
 159             testResize(sspec128, sspec64, barr);
 160             testResize(sspec128, sspec128, barr);
 161             testResize(sspec128, sspec256, barr);
 162             testResize(sspec128, sspec512, barr);
 163             testResize(sspec256, sspec64, barr);
 164             testResize(sspec256, sspec128, barr);
 165             testResize(sspec256, sspec256, barr);
 166             testResize(sspec256, sspec512, barr);
 167             testResize(sspec512, sspec64, barr);
 168             testResize(sspec512, sspec128, barr);
 169             testResize(sspec512, sspec256, barr);
 170             testResize(sspec512, sspec512, barr);
 171         }
 172     }
 173 
 174     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 175     static void testResizeInt(IntFunction<byte[]> fa) {
 176         byte[] barr = fa.apply(ispec64.elementSize());
 177         for (int i = 0; i < NUM_ITER; i++) {
 178             testResize(ispec64, ispec64, barr);
 179             testResize(ispec64, ispec128, barr);
 180             testResize(ispec64, ispec256, barr);
 181             testResize(ispec64, ispec512, barr);
 182             testResize(ispec128, ispec64, barr);
 183             testResize(ispec128, ispec128, barr);
 184             testResize(ispec128, ispec256, barr);
 185             testResize(ispec128, ispec512, barr);
 186             testResize(ispec256, ispec64, barr);
 187             testResize(ispec256, ispec128, barr);
 188             testResize(ispec256, ispec256, barr);
 189             testResize(ispec256, ispec512, barr);
 190             testResize(ispec512, ispec64, barr);
 191             testResize(ispec512, ispec128, barr);
 192             testResize(ispec512, ispec256, barr);
 193             testResize(ispec512, ispec512, barr);
 194         }
 195     }
 196 
 197     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 198     static void testResizeLong(IntFunction<byte[]> fa) {
 199         byte[] barr = fa.apply(lspec64.elementSize());
 200         for (int i = 0; i < NUM_ITER; i++) {
 201             testResize(lspec64, lspec64, barr);
 202             testResize(lspec64, lspec128, barr);
 203             testResize(lspec64, lspec256, barr);
 204             testResize(lspec64, lspec512, barr);
 205             testResize(lspec128, lspec64, barr);
 206             testResize(lspec128, lspec128, barr);
 207             testResize(lspec128, lspec256, barr);
 208             testResize(lspec128, lspec512, barr);
 209             testResize(lspec256, lspec64, barr);
 210             testResize(lspec256, lspec128, barr);
 211             testResize(lspec256, lspec256, barr);
 212             testResize(lspec256, lspec512, barr);
 213             testResize(lspec512, lspec64, barr);
 214             testResize(lspec512, lspec128, barr);
 215             testResize(lspec512, lspec256, barr);
 216             testResize(lspec512, lspec512, barr);
 217         }
 218     }
 219 
 220     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 221     static void testResizeFloat(IntFunction<byte[]> fa) {
 222         byte[] barr = fa.apply(fspec64.elementSize());
 223         for (int i = 0; i < NUM_ITER; i++) {
 224             testResize(fspec64, fspec64, barr);
 225             testResize(fspec64, fspec128, barr);
 226             testResize(fspec64, fspec256, barr);
 227             testResize(fspec64, fspec512, barr);
 228             testResize(fspec128, fspec64, barr);
 229             testResize(fspec128, fspec128, barr);
 230             testResize(fspec128, fspec256, barr);
 231             testResize(fspec128, fspec512, barr);
 232             testResize(fspec256, fspec64, barr);
 233             testResize(fspec256, fspec128, barr);
 234             testResize(fspec256, fspec256, barr);
 235             testResize(fspec256, fspec512, barr);
 236             testResize(fspec512, fspec64, barr);
 237             testResize(fspec512, fspec128, barr);
 238             testResize(fspec512, fspec256, barr);
 239             testResize(fspec512, fspec512, barr);
 240         }
 241     }
 242 
 243     @Test(dataProvider = "byteUnaryOpProvider", invocationCount = 2)
 244     static void testResizeDouble(IntFunction<byte[]> fa) {
 245         byte[] barr = fa.apply(dspec64.elementSize());
 246         for (int i = 0; i < NUM_ITER; i++) {
 247             testResize(dspec64, dspec64, barr);
 248             testResize(dspec64, dspec128, barr);
 249             testResize(dspec64, dspec256, barr);
 250             testResize(dspec64, dspec512, barr);
 251             testResize(dspec128, dspec64, barr);
 252             testResize(dspec128, dspec128, barr);
 253             testResize(dspec128, dspec256, barr);
 254             testResize(dspec128, dspec512, barr);
 255             testResize(dspec256, dspec64, barr);
 256             testResize(dspec256, dspec128, barr);
 257             testResize(dspec256, dspec256, barr);
 258             testResize(dspec256, dspec512, barr);
 259             testResize(dspec512, dspec64, barr);
 260             testResize(dspec512, dspec128, barr);
 261             testResize(dspec512, dspec256, barr);
 262             testResize(dspec512, dspec512, barr);
 263         }
 264     }
 265 }