1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 import jdk.experimental.bytecode.MacroCodeBuilder.CondKind;
  27 import jdk.experimental.bytecode.TypeTag;
  28 import jdk.experimental.value.MethodHandleBuilder;
  29 import jdk.incubator.mvt.ValueType;
  30 
  31 import java.lang.invoke.MethodHandle;
  32 import java.lang.invoke.MethodHandles;
  33 import java.lang.invoke.MethodType;
  34 import java.util.Random;
  35 import java.util.stream.Stream;
  36 
  37 import org.testng.annotations.*;
  38 import static org.testng.Assert.assertEquals;
  39 import static org.testng.Assert.assertTrue;
  40 
  41 /*
  42  * @test
  43  * @modules java.base/jdk.experimental.bytecode
  44  *          java.base/jdk.experimental.value
  45  *          jdk.incubator.mvt
  46  * @run testng/othervm -XX:+EnableMVT TestPoint
  47  */
  48 
  49 @Test
  50 public class TestPoint {
  51 
  52     @DataProvider(name = "createPoints")
  53     public Object[][] createPoints() {
  54         Object[][] data = new Object[10 * 10][];
  55         int n = 0;
  56         for (int i = 0 ; i < 10 ; i++) {
  57             for (int j = 0 ; j < 10 ; j++) {
  58                 data[n++] = new Object[] { new Point(i, j) };
  59             }
  60         }
  61         return data;
  62     }
  63 
  64     @DataProvider(name = "createPointArrays")
  65     public Object[][] createPointArrays() throws Throwable {
  66         ValueType vt = ValueType.forClass(Point.class);
  67         Object[][] data = new Object[10][];
  68         Random rand = new Random();
  69         for (int n = 0 ; n < 10 ; n++) {
  70             int length = rand.nextInt(10);
  71             Point[] boxes = new Point[length];
  72             Object arr = MethodHandles.arrayConstructor(vt.arrayValueClass()).invoke(length);
  73             for (int i = 0; i < length; i++) {
  74                  Point p = new Point(i, i);
  75                  boxes[i] = p;
  76                  MethodHandles.arrayElementSetter(vt.arrayValueClass()).invoke(arr, i, p);
  77             }
  78             data[n] = new Object[] { boxes, arr };
  79         }
  80         return data;
  81     }
  82 
  83     @Test(dataProvider = "createPoints")
  84     public void testMakePoint(Point p) throws Throwable {
  85         assertEquals(p, makePoint().invoke(p.x, p.y));
  86     }
  87 
  88     static MethodHandle makePoint() throws Throwable {
  89         ValueType vt = ValueType.forClass(Point.class);
  90         MethodHandles.Lookup lookup = MethodHandles.lookup();
  91         lookup = MethodHandles.privateLookupIn(Point.class, lookup);
  92         MethodHandle impl = vt.findWither(lookup, "y", int.class); // (QPoint,int)->QPoint
  93         impl = MethodHandles.collectArguments(impl, 0, vt.findWither(lookup, "x", int.class)); // (QPoint,int,int)->QPoint
  94         impl = MethodHandles.collectArguments(impl, 0, vt.defaultValueConstant()); // (int,int)->QPoint
  95         return impl;
  96     }
  97 
  98     @Test(dataProvider = "createPoints")
  99     public void testMakePoint_bytecode(Point p) throws Throwable {
 100         assertEquals(p, makePoint_bytecode().invoke(p.x, p.y));
 101     }
 102 
 103     static MethodHandle makePoint_bytecode() throws Throwable {
 104         ValueType vt = ValueType.forClass(Point.class);
 105         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.privateLookupIn(vt.valueClass(), MethodHandles.lookup()),
 106                 "makePoint", MethodType.methodType(vt.valueClass(), int.class, int.class),
 107                 C -> {
 108                     C.vdefault(vt.valueClass())
 109                             .iload_0()
 110                             .vwithfield(vt.valueClass(), "x", "I")
 111                             .iload_1()
 112                             .vwithfield(vt.valueClass(), "y", "I")
 113                             .vreturn();
 114                 });
 115         return mh;
 116     }
 117 
 118     @Test(dataProvider = "createPoints")
 119     public void testNorm(Point p) throws Throwable {
 120         assertEquals(p.norm(), norm().invoke(p));
 121     }
 122 
 123     static MethodHandle norm() throws Throwable {
 124         ValueType vt = ValueType.forClass(Point.class);
 125         MethodHandles.Lookup lookup = MethodHandles.lookup();
 126         MethodHandle id = vt.identity();
 127         MethodHandle norm = lookup.findVirtual(Point.class, "norm", MethodType.methodType(double.class));
 128         MethodHandle impl = MethodHandles.filterReturnValue(id, vt.box());
 129         impl = MethodHandles.filterReturnValue(impl, norm);
 130         return impl;
 131     }
 132 
 133     @Test(dataProvider = "createPoints")
 134     public void testNorm_bytecode(Point p) throws Throwable {
 135         assertEquals(p.norm(), norm_bytecode().invoke(p));
 136     }
 137 
 138     static MethodHandle norm_bytecode() throws Throwable {
 139         ValueType vt = ValueType.forClass(Point.class);
 140         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.lookup(), "norm", MethodType.methodType(double.class, vt.valueClass()),
 141                 C -> {
 142                     C.vload(0)
 143                             .vbox(Point.class)
 144                             .invokevirtual(Point.class, "norm", "()D", false)
 145                             .dreturn();
 146                 });
 147         return mh;
 148 
 149     }
 150 
 151     @Test(dataProvider = "createPointArrays")
 152     public void testTotalNorm(Point[] boxes, Object arr) throws Throwable {
 153         assertTrue(Math.abs(totalNorm(boxes) - (double)totalNorm().invoke(arr)) < 0.0001d);
 154     }
 155 
 156     private double totalNorm(Point[] parr) {
 157         return Stream.of(parr).mapToDouble(Point::norm).sum();
 158     }
 159 
 160     private static MethodHandle totalNorm() throws Throwable {
 161         ValueType vt = ValueType.forClass(Point.class);
 162         MethodHandles.Lookup lookup = MethodHandles.lookup();
 163         MethodType loopLocals = MethodType.methodType(double.class, vt.arrayValueClass());
 164         MethodType loopParams = loopLocals.insertParameterTypes(0, double.class, int.class); // (int,double,T)->double
 165         MethodHandle init = MethodHandles.permuteArguments(MethodHandles.constant(double.class, 0.0), loopLocals); // (T)->double
 166         MethodHandle body = lookup.findStatic(Double.class, "sum", MethodType.methodType(double.class, double.class, double.class)); // (double,double)->double
 167         body = MethodHandles.collectArguments(body, 1, norm()); // (double,QPoint)->double
 168         body = MethodHandles.collectArguments(body, 1, vt.arrayGetter()); // (double,T,int)->double
 169         body = MethodHandles.permuteArguments(body, loopParams, 0, 2, 1); // (int,double,T)->double
 170         return MethodHandles.countedLoop(vt.arrayLength(), init, body);
 171     }
 172 
 173     @Test(dataProvider = "createPointArrays")
 174     public void testTotalNorm_bytecode(Point[] boxes, Object arr) throws Throwable {
 175         assertTrue(Math.abs(totalNorm(boxes) - (double)totalNorm_bytecode().invoke(arr)) < 0.0001d);
 176     }
 177 
 178     static MethodHandle totalNorm_bytecode() throws Throwable {
 179         ValueType vt = ValueType.forClass(Point.class);
 180         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.privateLookupIn(vt.valueClass(), MethodHandles.lookup()), "totalNorm", MethodType.methodType(double.class, vt.arrayValueClass()),
 181                 C -> {
 182                     C.iconst_0()
 183                             .istore_1()
 184                             .dconst_0()
 185                             .dstore_2()
 186                             .label("loop")
 187                             .aload_0()
 188                             .arraylength()
 189                             .iload_1()
 190                             .ifcmp(TypeTag.I, CondKind.LE, "end")
 191                             .aload_0()
 192                             .iload_1()
 193                             .vaload()
 194                             .vbox(Point.class)
 195                             .invokevirtual(Point.class, "norm", "()D", false)
 196                             .dload_2()
 197                             .dadd()
 198                             .dstore_2()
 199                             .iinc(1, 1)
 200                             .goto_("loop")
 201                             .label("end")
 202                             .dload_2()
 203                             .dreturn();
 204                 });
 205         return mh;
 206     }
 207 
 208     public void guardWithTest() throws Throwable {
 209         MethodHandle point1 = MethodHandles.insertArguments(makePoint(), 0, 1, 1);
 210         MethodHandle point2 = MethodHandles.insertArguments(makePoint(), 0, 2, 2);
 211         MethodHandle predicate_T = MethodHandles.constant(boolean.class, true);
 212         MethodHandle predicate_F = MethodHandles.constant(boolean.class, false);
 213         MethodHandle gwt_T = MethodHandles.guardWithTest(predicate_T, point1, point2);
 214         MethodHandle gwt_F = MethodHandles.guardWithTest(predicate_F, point1, point2);
 215         assertEquals(gwt_T.invoke(), new Point(1, 1));
 216         assertEquals(gwt_F.invoke(), new Point(2, 2));
 217     }
 218 }