1 /*
   2  * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 import jdk.experimental.bytecode.MacroCodeBuilder.CondKind;
  27 import jdk.experimental.bytecode.TypeTag;
  28 import jdk.experimental.value.MethodHandleBuilder;
  29 import jdk.experimental.value.ValueType;
  30 
  31 import java.lang.invoke.MethodHandle;
  32 import java.lang.invoke.MethodHandles;
  33 import java.lang.invoke.MethodType;
  34 import java.util.Random;
  35 import java.util.stream.Stream;
  36 
  37 import org.testng.annotations.*;
  38 import static org.testng.Assert.assertEquals;
  39 import static org.testng.Assert.assertTrue;
  40 
  41 /*
  42  * @test
  43  * @run testng/othervm -Xverify:none -XX:+EnableMVT TestPoint
  44  */
  45 
  46 @Test
  47 public class TestPoint {
  48 
  49     @DataProvider(name = "createPoints")
  50     public Object[][] createPoints() {
  51         Object[][] data = new Object[10 * 10][];
  52         int n = 0;
  53         for (int i = 0 ; i < 10 ; i++) {
  54             for (int j = 0 ; j < 10 ; j++) {
  55                 data[n++] = new Object[] { new Point(i, j) };
  56             }
  57         }
  58         return data;
  59     }
  60 
  61     @DataProvider(name = "createPointArrays")
  62     public Object[][] createPointArrays() throws Throwable {
  63         ValueType vt = ValueType.forClass(Point.class);
  64         Object[][] data = new Object[10][];
  65         Random rand = new Random();
  66         for (int n = 0 ; n < 10 ; n++) {
  67             int length = rand.nextInt(10);
  68             Point[] boxes = new Point[length];
  69             Object arr = MethodHandles.arrayConstructor(vt.arrayValueClass()).invoke(length);
  70             for (int i = 0; i < length; i++) {
  71                  Point p = new Point(i, i);
  72                  boxes[i] = p;
  73                  MethodHandles.arrayElementSetter(vt.arrayValueClass()).invoke(arr, i, p);
  74             }
  75             data[n] = new Object[] { boxes, arr };
  76         }
  77         return data;
  78     }
  79 
  80     @Test(dataProvider = "createPoints")
  81     public void testMakePoint(Point p) throws Throwable {
  82         assertEquals(p, makePoint().invoke(p.x, p.y));
  83     }
  84 
  85     static MethodHandle makePoint() throws Throwable {
  86         ValueType vt = ValueType.forClass(Point.class);
  87         MethodHandles.Lookup lookup = MethodHandles.lookup();
  88         lookup = MethodHandles.privateLookupIn(Point.class, lookup);
  89         MethodHandle impl = vt.findWither(lookup, "y", int.class); // (QPoint,int)->QPoint
  90         impl = MethodHandles.collectArguments(impl, 0, vt.findWither(lookup, "x", int.class)); // (QPoint,int,int)->QPoint
  91         impl = MethodHandles.collectArguments(impl, 0, vt.defaultValueConstant()); // (int,int)->QPoint
  92         return impl;
  93     }
  94 
  95     @Test(dataProvider = "createPoints")
  96     public void testMakePoint_bytecode(Point p) throws Throwable {
  97         assertEquals(p, makePoint_bytecode().invoke(p.x, p.y));
  98     }
  99 
 100     static MethodHandle makePoint_bytecode() throws Throwable {
 101         ValueType vt = ValueType.forClass(Point.class);
 102         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.privateLookupIn(vt.valueClass(), MethodHandles.lookup()),
 103                 "makePoint", MethodType.methodType(vt.valueClass(), int.class, int.class),
 104                 C -> {
 105                     C.vdefault(vt.valueClass())
 106                             .iload_0()
 107                             .vwithfield(vt.valueClass(), "x", "I")
 108                             .iload_1()
 109                             .vwithfield(vt.valueClass(), "y", "I")
 110                             .vreturn();
 111                 });
 112         return mh;
 113     }
 114 
 115     @Test(dataProvider = "createPoints")
 116     public void testNorm(Point p) throws Throwable {
 117         assertEquals(p.norm(), norm().invoke(p));
 118     }
 119 
 120     static MethodHandle norm() throws Throwable {
 121         ValueType vt = ValueType.forClass(Point.class);
 122         MethodHandles.Lookup lookup = MethodHandles.lookup();
 123         MethodHandle id = vt.identity();
 124         MethodHandle norm = lookup.findVirtual(Point.class, "norm", MethodType.methodType(double.class));
 125         MethodHandle impl = MethodHandles.filterReturnValue(id, vt.box());
 126         impl = MethodHandles.filterReturnValue(impl, norm);
 127         return impl;
 128     }
 129 
 130     @Test(dataProvider = "createPoints")
 131     public void testNorm_bytecode(Point p) throws Throwable {
 132         assertEquals(p.norm(), norm_bytecode().invoke(p));
 133     }
 134 
 135     static MethodHandle norm_bytecode() throws Throwable {
 136         ValueType vt = ValueType.forClass(Point.class);
 137         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.lookup(), "norm", MethodType.methodType(double.class, vt.valueClass()),
 138                 C -> {
 139                     C.vload(0)
 140                             .vbox(Point.class)
 141                             .invokevirtual(Point.class, "norm", "()D", false)
 142                             .dreturn();
 143                 });
 144         return mh;
 145 
 146     }
 147 
 148     @Test(dataProvider = "createPointArrays")
 149     public void testTotalNorm(Point[] boxes, Object arr) throws Throwable {
 150         assertTrue(Math.abs(totalNorm(boxes) - (double)totalNorm().invoke(arr)) < 0.0001d);
 151     }
 152 
 153     private double totalNorm(Point[] parr) {
 154         return Stream.of(parr).mapToDouble(Point::norm).sum();
 155     }
 156 
 157     private static MethodHandle totalNorm() throws Throwable {
 158         ValueType vt = ValueType.forClass(Point.class);
 159         MethodHandles.Lookup lookup = MethodHandles.lookup();
 160         MethodType loopLocals = MethodType.methodType(double.class, vt.arrayValueClass());
 161         MethodType loopParams = loopLocals.insertParameterTypes(0, double.class, int.class); // (int,double,T)->double
 162         MethodHandle init = MethodHandles.permuteArguments(MethodHandles.constant(double.class, 0.0), loopLocals); // (T)->double
 163         MethodHandle body = lookup.findStatic(Double.class, "sum", MethodType.methodType(double.class, double.class, double.class)); // (double,double)->double
 164         body = MethodHandles.collectArguments(body, 1, norm()); // (double,QPoint)->double
 165         body = MethodHandles.collectArguments(body, 1, vt.arrayGetter()); // (double,T,int)->double
 166         body = MethodHandles.permuteArguments(body, loopParams, 0, 2, 1); // (int,double,T)->double
 167         return MethodHandles.countedLoop(vt.arrayLength(), init, body);
 168     }
 169 
 170     @Test(dataProvider = "createPointArrays")
 171     public void testTotalNorm_bytecode(Point[] boxes, Object arr) throws Throwable {
 172         assertTrue(Math.abs(totalNorm(boxes) - (double)totalNorm_bytecode().invoke(arr)) < 0.0001d);
 173     }
 174 
 175     static MethodHandle totalNorm_bytecode() throws Throwable {
 176         ValueType vt = ValueType.forClass(Point.class);
 177         MethodHandle mh = MethodHandleBuilder.loadCode(MethodHandles.privateLookupIn(vt.valueClass(), MethodHandles.lookup()), "totalNorm", MethodType.methodType(double.class, vt.arrayValueClass()),
 178                 C -> {
 179                     C.iconst_0()
 180                             .istore_1()
 181                             .dconst_0()
 182                             .dstore_2()
 183                             .label("loop")
 184                             .aload_0()
 185                             .arraylength()
 186                             .iload_1()
 187                             .ifcmp(TypeTag.I, CondKind.LE, "end")
 188                             .aload_0()
 189                             .iload_1()
 190                             .vaload()
 191                             .vbox(Point.class)
 192                             .invokevirtual(Point.class, "norm", "()D", false)
 193                             .dload_2()
 194                             .dadd()
 195                             .dstore_2()
 196                             .iinc(1, 1)
 197                             .goto_("loop")
 198                             .label("end")
 199                             .dload_2()
 200                             .dreturn();
 201                 });
 202         return mh;
 203     }
 204 
 205     public void guardWithTest() throws Throwable {
 206         MethodHandle point1 = MethodHandles.insertArguments(makePoint(), 0, 1, 1);
 207         MethodHandle point2 = MethodHandles.insertArguments(makePoint(), 0, 2, 2);
 208         MethodHandle predicate_T = MethodHandles.constant(boolean.class, true);
 209         MethodHandle predicate_F = MethodHandles.constant(boolean.class, false);
 210         MethodHandle gwt_T = MethodHandles.guardWithTest(predicate_T, point1, point2);
 211         MethodHandle gwt_F = MethodHandles.guardWithTest(predicate_F, point1, point2);
 212         assertEquals(gwt_T.invoke(), new Point(1, 1));
 213         assertEquals(gwt_F.invoke(), new Point(2, 2));
 214     }
 215 }