1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 /*
  27  * @test
  28  * @run testng StdLibTest
  29  */
  30 
  31 import java.foreign.memory.*;
  32 import java.lang.invoke.MethodHandles;
  33 import java.foreign.Libraries;
  34 import java.foreign.NativeTypes;
  35 import java.foreign.Scope;
  36 import java.foreign.annotations.NativeCallback;
  37 import java.foreign.annotations.NativeHeader;
  38 import java.foreign.annotations.NativeStruct;
  39 import java.time.Instant;
  40 import java.time.LocalDateTime;
  41 import java.time.ZoneId;
  42 import java.time.ZoneOffset;
  43 import java.time.ZonedDateTime;
  44 import java.util.ArrayList;
  45 import java.util.Arrays;
  46 import java.util.List;
  47 import java.util.Set;
  48 import java.util.function.Function;
  49 import java.util.stream.Collectors;
  50 import java.util.stream.Stream;
  51 
  52 import org.testng.annotations.*;
  53 
  54 import static org.testng.Assert.*;
  55 
  56 @Test
  57 public class StdLibTest {
  58 
  59     private StdLibHelper stdLibHelper = new StdLibHelper();
  60 
  61     @Test(dataProvider = "stringPairs")
  62     void test_strcat(String s1, String s2) {
  63         assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
  64     }
  65 
  66     @Test(dataProvider = "stringPairs")
  67     void test_strcmp(String s1, String s2) {
  68         assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
  69     }
  70 
  71     @Test(dataProvider = "strings")
  72     void test_puts(String s) {
  73         assertTrue(stdLibHelper.puts(s) >= 0);
  74     }
  75 
  76     @Test(dataProvider = "strings")
  77     void test_strlen(String s) {
  78         assertEquals(stdLibHelper.strlen(s), s.length());
  79     }
  80 
  81     @Test(dataProvider = "instants")
  82     void test_time(Instant instant) {
  83         try (Scope s = Scope.newNativeScope()) {
  84             StdLibHelper.Time time = s.allocateStruct(StdLibHelper.Time.class);
  85             time.setSeconds(instant.getEpochSecond());
  86             //numbers should be in the same ballpark
  87             assertEquals(time.seconds(), instant.getEpochSecond());
  88             @SuppressWarnings("unchecked")
  89             StdLibHelper.Tm tm = stdLibHelper.gmtime(time.ptr()).get();
  90             LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneOffset.UTC);
  91             assertEquals(tm.sec(), localTime.getSecond());
  92             assertEquals(tm.min(), localTime.getMinute());
  93             assertEquals(tm.hour(), localTime.getHour());
  94             //day pf year in Java has 1-offset
  95             assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
  96             assertEquals(tm.mday(), localTime.getDayOfMonth());
  97             //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
  98             assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
  99             //month in Java has 1-offset
 100             assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
 101             assertEquals(tm.isdst(), ZoneOffset.UTC.getRules()
 102                     .isDaylightSavings(Instant.ofEpochMilli(time.seconds() * 1000)));
 103         }
 104     }
 105 
 106     @Test(dataProvider = "ints")
 107     void test_qsort(List<Integer> ints) {
 108         if (ints.size() > 0) {
 109             int[] input = ints.stream().mapToInt(i -> i).toArray();
 110             int[] sorted = stdLibHelper.qsort(input);
 111             Arrays.sort(input);
 112             assertEquals(sorted, input);
 113         }
 114     }
 115 
 116     @Test
 117     void test_rand() {
 118         int val = stdLibHelper.rand();
 119         for (int i = 0 ; i < 100 ; i++) {
 120             int newVal = stdLibHelper.rand();
 121             if (newVal != val) {
 122                 return; //ok
 123             }
 124             val = newVal;
 125         }
 126         fail("All values are the same! " + val);
 127     }
 128 
 129     @Test(dataProvider = "printfArgs")
 130     void test_printf(List<PrintfArg> args) {
 131         try (Scope s = Scope.newNativeScope()) {
 132             String formatStr = args.stream()
 133                     .map(a -> a.format)
 134                     .collect(Collectors.joining(","));
 135 
 136             int formattedArgsLength = args.stream()
 137                     .mapToInt(a -> a.length)
 138                     .sum();
 139 
 140             Object[] argValues = args.stream()
 141                     .map(a -> a.valueFunc.apply(s))
 142                     .toArray();
 143 
 144             int delimCount = (args.size() > 0) ?
 145                     args.size() - 1 :
 146                     0;
 147 
 148             assertEquals(stdLibHelper.printf("hello(" + formatStr + ")\n", argValues), 8 + formattedArgsLength + delimCount);
 149         }
 150     }
 151 
 152     static class StdLibHelper {
 153         StdLib stdLib = Libraries.bind(MethodHandles.lookup(), StdLib.class);
 154 
 155         String strcat(String s1, String s2) {
 156             try (Scope s = Scope.newNativeScope()) {
 157                 Pointer<Byte> buf = s.allocate(NativeTypes.INT8, s1.length() + s2.length() + 1);
 158                 Pointer<Byte> base = buf;
 159                 for (char c : s1.toCharArray()) {
 160                     buf.set((byte)c);
 161                     buf = buf.offset(1);
 162                 }
 163                 buf.set((byte)'\0');
 164                 return Pointer.toString(stdLib.strcat(base, s.allocateCString(s2)));
 165             }
 166         }
 167 
 168         int strcmp(String s1, String s2) {
 169             try (Scope s = Scope.newNativeScope()) {
 170                 return stdLib.strcmp(s.allocateCString(s1), s.allocateCString(s2));
 171             }
 172         }
 173 
 174         int puts(String msg) {
 175             try (Scope s = Scope.newNativeScope()) {
 176                 return stdLib.puts(s.allocateCString(msg));
 177             }
 178         }
 179 
 180         int strlen(String msg) {
 181             try (Scope s = Scope.newNativeScope()) {
 182                 return stdLib.strlen(s.allocateCString(msg));
 183             }
 184         }
 185 
 186         Pointer<Tm> gmtime(Pointer<Time> arg) {
 187             return stdLib.gmtime(arg);
 188         }
 189 
 190         int[] qsort(int[] array) {
 191             try (Scope s = Scope.newNativeScope()) {
 192                 //allocate the array
 193                 Array<Integer> arr = s.allocateArray(NativeTypes.INT32, array);
 194                 
 195                 //call the function
 196                 stdLib.qsort(arr.elementPointer(), array.length, 4,
 197                         s.allocateCallback((u1, u2) -> {
 198                                 int i1 = u1.get();
 199                                 int i2 = u2.get();
 200                                 return i1 - i2;
 201                         }));
 202                 //get result
 203                 return arr.toArray(int[]::new);
 204             }
 205         }
 206 
 207         int rand() {
 208             return stdLib.rand();
 209         }
 210 
 211         int printf(String format, Object... args) {
 212             try (Scope sc = Scope.newNativeScope()) {
 213                 return stdLib.printf(sc.allocateCString(format), args);
 214             }
 215         }
 216 
 217         Pointer<Void> fopen(String filename, String mode) {
 218             try (Scope s = Scope.newNativeScope()) {
 219                 return stdLib.fopen(s.allocateCString(filename), s.allocateCString(mode));
 220             }
 221         }
 222 
 223         @NativeHeader(declarations =
 224                 "puts=(u64:u8)i32" +
 225                 "strcat=(u64:u8u64:i8)u64:u8" +
 226                 "strcmp=(u64:u8u64:i8)i32" +
 227                 "strlen=(u64:u8)i32" +
 228                 "time=(u64:$(Time))$(Time)" +
 229                 "gmtime=(u64:$(Time))u64:$(Tm)" +
 230                 "qsort=(u64:[0i32]i32i32u64:(u64:i32u64:i32)i32)v" +
 231                 "rand=()i32" +
 232                 "printf=(u64:u8*)i32" +
 233                 "fopen=(u64:u8u64:i8)u64:v")
 234         public interface StdLib {
 235             int puts(Pointer<Byte> str);
 236             Pointer<Byte> strcat(Pointer<Byte> s1, Pointer<Byte> s2);
 237             int strcmp(Pointer<Byte> s1, Pointer<Byte> s2);
 238             int strlen(Pointer<Byte> s2);
 239             Time time(Pointer<Time> arg);
 240             Pointer<Tm> gmtime(Pointer<Time> arg);
 241             void qsort(Pointer<Integer> base, int nitems, int size, Callback<QsortComparator> comparator);
 242             int rand();
 243             int printf(Pointer<Byte> format, Object... args);
 244             Pointer<Void> fopen(Pointer<Byte> filename, Pointer<Byte> mode);
 245 
 246             @NativeCallback("(u64:i32u64:i32)i32")
 247             interface QsortComparator {
 248                 int compare(Pointer<Integer> u1, Pointer<Integer> u2);
 249             }
 250         }
 251 
 252         @NativeStruct("[" +
 253                 "   i64(get=seconds)(set=setSeconds)" +
 254                 "](Time)")
 255         public interface Time extends Struct<Time> {
 256             long seconds();
 257             void setSeconds(long secs);
 258         }
 259 
 260         @NativeStruct("[" +
 261                 "   i32(get=sec)" +
 262                 "   i32(get=min)" +
 263                 "   i32(get=hour)" +
 264                 "   i32(get=mday)" +
 265                 "   i32(get=mon)" +
 266                 "   i32(get=year)" +
 267                 "   i32(get=wday)" +
 268                 "   i32(get=yday)" +
 269                 "   i8(get=isdst)" +
 270                 "](Tm)")
 271         public interface Tm extends Struct<Tm> {
 272             int sec();
 273             int min();
 274             int hour();
 275             int mday();
 276             int mon();
 277             int year();
 278             int wday();
 279             int yday();
 280             boolean isdst();
 281         }
 282     }
 283 
 284     /*** data providers ***/
 285 
 286     @DataProvider
 287     public static Object[][] ints() {
 288         return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
 289                 .map(l -> new Object[] { l })
 290                 .toArray(Object[][]::new);
 291     }
 292 
 293     @DataProvider
 294     public static Object[][] strings() {
 295         return perms(0, new String[] { "a", "b", "c" }).stream()
 296                 .map(l -> new Object[] { String.join("", l) })
 297                 .toArray(Object[][]::new);
 298     }
 299 
 300     @DataProvider
 301     public static Object[][] stringPairs() {
 302         Object[][] strings = strings();
 303         Object[][] stringPairs = new Object[strings.length * strings.length][];
 304         int pos = 0;
 305         for (Object[] s1 : strings) {
 306             for (Object[] s2 : strings) {
 307                 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
 308             }
 309         }
 310         return stringPairs;
 311     }
 312 
 313     @DataProvider
 314     public static Object[][] instants() {
 315         Instant start = ZonedDateTime.of(LocalDateTime.parse("2017-01-01T00:00:00"), ZoneOffset.UTC).toInstant();
 316         Instant end = ZonedDateTime.of(LocalDateTime.parse("2017-12-31T00:00:00"), ZoneOffset.UTC).toInstant();
 317         Object[][] instants = new Object[100][];
 318         for (int i = 0 ; i < instants.length ; i++) {
 319             Instant instant = start.plusSeconds((long)(Math.random() * (end.getEpochSecond() - start.getEpochSecond())));
 320             instants[i] = new Object[] { instant };
 321         }
 322         return instants;
 323     }
 324 
 325     @DataProvider
 326     public static Object[][] printfArgs() {
 327         return perms(0, PrintfArg.values()).stream()
 328                 .map(l -> new Object[] { l })
 329                 .toArray(Object[][]::new);
 330     }
 331 
 332     enum PrintfArg {
 333         INTEGRAL("%d", s -> 42, 2),
 334         STRING("%s", s -> s.allocateCString("str"), 3),
 335         CHAR("%c", s -> 'h', 1),
 336         FLOAT("%.2f", s -> 1.23d, 4);
 337 
 338         String format;
 339         Function<Scope, Object> valueFunc;
 340         int length;
 341 
 342         PrintfArg(String format, Function<Scope, Object> valueFunc, int length) {
 343             this.format = format;
 344             this.valueFunc = valueFunc;
 345             this.length = length;
 346         }
 347     }
 348 
 349     static <Z> Set<List<Z>> perms(int count, Z[] arr) {
 350         if (count == arr.length) {
 351             return Set.of(List.of());
 352         } else {
 353             return Arrays.stream(arr)
 354                     .flatMap(num -> {
 355                         Set<List<Z>> perms = perms(count + 1, arr);
 356                         return Stream.concat(
 357                                 //take n
 358                                 perms.stream().map(l -> {
 359                                     List<Z> li = new ArrayList<>(l);
 360                                     li.add(num);
 361                                     return li;
 362                                 }),
 363                                 //drop n
 364                                 perms.stream());
 365                     }).collect(Collectors.toSet());
 366         }
 367     }
 368 }