1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 /*
  27  * @test
  28  * @run testng StdLibTest
  29  */
  30 
  31 import java.foreign.memory.*;
  32 import java.lang.invoke.MethodHandles;
  33 import java.foreign.Libraries;
  34 import java.foreign.NativeTypes;
  35 import java.foreign.Scope;
  36 import java.foreign.annotations.NativeCallback;
  37 import java.foreign.annotations.NativeHeader;
  38 import java.foreign.annotations.NativeStruct;
  39 import java.sql.Timestamp;
  40 import java.time.Instant;
  41 import java.time.LocalDateTime;
  42 import java.time.ZoneId;
  43 import java.util.ArrayList;
  44 import java.util.Arrays;
  45 import java.util.List;
  46 import java.util.Set;
  47 import java.util.function.Function;
  48 import java.util.stream.Collectors;
  49 import java.util.stream.Stream;
  50 
  51 import org.testng.annotations.*;
  52 
  53 import static org.testng.Assert.*;
  54 
  55 @Test
  56 public class StdLibTest {
  57 
  58     private StdLibHelper stdLibHelper = new StdLibHelper();
  59 
  60     @Test(dataProvider = "stringPairs")
  61     void test_strcat(String s1, String s2) {
  62         assertEquals(stdLibHelper.strcat(s1, s2), s1 + s2);
  63     }
  64 
  65     @Test(dataProvider = "stringPairs")
  66     void test_strcmp(String s1, String s2) {
  67         assertEquals(Math.signum(stdLibHelper.strcmp(s1, s2)), Math.signum(s1.compareTo(s2)));
  68     }
  69 
  70     @Test(dataProvider = "strings")
  71     void test_puts(String s) {
  72         assertTrue(stdLibHelper.puts(s) > 0);
  73     }
  74 
  75     @Test(dataProvider = "strings")
  76     void test_strlen(String s) {
  77         assertEquals(stdLibHelper.strlen(s), s.length());
  78     }
  79 
  80     @Test(dataProvider = "instants")
  81     void test_time(Instant instant) {
  82         try (Scope s = Scope.newNativeScope()) {
  83             StdLibHelper.Time time = s.allocateStruct(StdLibHelper.Time.class);
  84             time.setSeconds(instant.getEpochSecond());
  85             //numbers should be in the same ballpark
  86             assertEquals(time.seconds(), instant.getEpochSecond());
  87             @SuppressWarnings("unchecked")
  88             StdLibHelper.Tm tm = stdLibHelper.localtime(time.ptr()).get();
  89             LocalDateTime localTime = LocalDateTime.ofInstant(instant, ZoneId.systemDefault());
  90             assertEquals(tm.sec(), localTime.getSecond());
  91             assertEquals(tm.min(), localTime.getMinute());
  92             assertEquals(tm.hour(), localTime.getHour());
  93             //day pf year in Java has 1-offset
  94             assertEquals(tm.yday(), localTime.getDayOfYear() - 1);
  95             assertEquals(tm.mday(), localTime.getDayOfMonth());
  96             //days of week starts from Sunday in C, but on Monday in Java, also account for 1-offset
  97             assertEquals((tm.wday() + 6) % 7, localTime.getDayOfWeek().getValue() - 1);
  98             //month in Java has 1-offset
  99             assertEquals(tm.mon(), localTime.getMonth().getValue() - 1);
 100             assertEquals(tm.isdst(), ZoneId.systemDefault().getRules()
 101                     .isDaylightSavings(Instant.ofEpochMilli(time.seconds() * 1000)));
 102         }
 103     }
 104 
 105     @Test(dataProvider = "ints")
 106     void test_qsort(List<Integer> ints) {
 107         if (ints.size() > 0) {
 108             int[] input = ints.stream().mapToInt(i -> i).toArray();
 109             int[] sorted = stdLibHelper.qsort(input);
 110             Arrays.sort(input);
 111             assertEquals(sorted, input);
 112         }
 113     }
 114 
 115     @Test
 116     void test_rand() {
 117         int val = stdLibHelper.rand();
 118         for (int i = 0 ; i < 100 ; i++) {
 119             int newVal = stdLibHelper.rand();
 120             if (newVal != val) {
 121                 return; //ok
 122             }
 123             val = newVal;
 124         }
 125         fail("All values are the same! " + val);
 126     }
 127 
 128     @Test(dataProvider = "printfArgs")
 129     void test_printf(List<PrintfArg> args) {
 130         try (Scope s = Scope.newNativeScope()) {
 131             String formatStr = args.stream()
 132                     .map(a -> a.format)
 133                     .collect(Collectors.joining(","));
 134 
 135             int formattedArgsLength = args.stream()
 136                     .mapToInt(a -> a.length)
 137                     .sum();
 138 
 139             Object[] argValues = args.stream()
 140                     .map(a -> a.valueFunc.apply(s))
 141                     .toArray();
 142 
 143             int delimCount = (args.size() > 0) ?
 144                     args.size() - 1 :
 145                     0;
 146 
 147             assertEquals(stdLibHelper.printf("hello(" + formatStr + ")\n", argValues), 8 + formattedArgsLength + delimCount);
 148         }
 149     }
 150 
 151     static class StdLibHelper {
 152         StdLib stdLib = Libraries.bind(MethodHandles.lookup(), StdLib.class);
 153 
 154         String strcat(String s1, String s2) {
 155             try (Scope s = Scope.newNativeScope()) {
 156                 Pointer<Byte> buf = s.allocate(NativeTypes.INT8, s1.length() + s2.length() + 1);
 157                 Pointer<Byte> base = buf;
 158                 for (char c : s1.toCharArray()) {
 159                     buf.set((byte)c);
 160                     buf = buf.offset(1);
 161                 }
 162                 buf.set((byte)'\0');
 163                 return Pointer.toString(stdLib.strcat(base, s.allocateCString(s2)));
 164             }
 165         }
 166 
 167         int strcmp(String s1, String s2) {
 168             try (Scope s = Scope.newNativeScope()) {
 169                 return stdLib.strcmp(s.allocateCString(s1), s.allocateCString(s2));
 170             }
 171         }
 172 
 173         int puts(String msg) {
 174             try (Scope s = Scope.newNativeScope()) {
 175                 return stdLib.puts(s.allocateCString(msg));
 176             }
 177         }
 178 
 179         int strlen(String msg) {
 180             try (Scope s = Scope.newNativeScope()) {
 181                 return stdLib.strlen(s.allocateCString(msg));
 182             }
 183         }
 184 
 185         Pointer<Tm> localtime(Pointer<Time> arg) {
 186             return stdLib.localtime(arg);
 187         }
 188 
 189         int[] qsort(int[] array) {
 190             try (Scope s = Scope.newNativeScope()) {
 191                 //allocate the array
 192                 Array<Integer> arr = s.allocateArray(NativeTypes.INT32, array);
 193                 
 194                 //call the function
 195                 stdLib.qsort(arr.elementPointer(), array.length, 4,
 196                         s.allocateCallback((u1, u2) -> {
 197                                 int i1 = u1.get();
 198                                 int i2 = u2.get();
 199                                 return i1 - i2;
 200                         }));
 201                 //get result
 202                 return arr.toArray(int[]::new);
 203             }
 204         }
 205 
 206         int rand() {
 207             return stdLib.rand();
 208         }
 209 
 210         int printf(String format, Object... args) {
 211             try (Scope sc = Scope.newNativeScope()) {
 212                 return stdLib.printf(sc.allocateCString(format), args);
 213             }
 214         }
 215 
 216         Pointer<Void> fopen(String filename, String mode) {
 217             try (Scope s = Scope.newNativeScope()) {
 218                 return stdLib.fopen(s.allocateCString(filename), s.allocateCString(mode));
 219             }
 220         }
 221 
 222         @NativeHeader(declarations =
 223                 "puts=(u64:u8)i32" +
 224                 "strcat=(u64:u8u64:i8)u64:u8" +
 225                 "strcmp=(u64:u8u64:i8)i32" +
 226                 "strlen=(u64:u8)i32" +
 227                 "time=(u64:$(Time))$(Time)" +
 228                 "localtime=(u64:$(Time))u64:$(Tm)" +
 229                 "qsort=(u64:[0i32]i32i32u64:(u64:i32u64:i32)i32)v" +
 230                 "rand=()i32" +
 231                 "printf=(u64:u8*)i32" +
 232                 "fopen=(u64:u8u64:i8)u64:v")
 233         public interface StdLib {
 234             int puts(Pointer<Byte> str);
 235             Pointer<Byte> strcat(Pointer<Byte> s1, Pointer<Byte> s2);
 236             int strcmp(Pointer<Byte> s1, Pointer<Byte> s2);
 237             int strlen(Pointer<Byte> s2);
 238             Time time(Pointer<Time> arg);
 239             Pointer<Tm> localtime(Pointer<Time> arg);
 240             void qsort(Pointer<Integer> base, int nitems, int size, Callback<QsortComparator> comparator);
 241             int rand();
 242             int printf(Pointer<Byte> format, Object... args);
 243             Pointer<Void> fopen(Pointer<Byte> filename, Pointer<Byte> mode);
 244 
 245             @NativeCallback("(u64:i32u64:i32)i32")
 246             interface QsortComparator {
 247                 int compare(Pointer<Integer> u1, Pointer<Integer> u2);
 248             }
 249         }
 250 
 251         @NativeStruct("[" +
 252                 "   i64(get=seconds)(set=setSeconds)" +
 253                 "](Time)")
 254         public interface Time extends Struct<Time> {
 255             long seconds();
 256             void setSeconds(long secs);
 257         }
 258 
 259         @NativeStruct("[" +
 260                 "   i32(get=sec)" +
 261                 "   i32(get=min)" +
 262                 "   i32(get=hour)" +
 263                 "   i32(get=mday)" +
 264                 "   i32(get=mon)" +
 265                 "   i32(get=year)" +
 266                 "   i32(get=wday)" +
 267                 "   i32(get=yday)" +
 268                 "   i8(get=isdst)" +
 269                 "](Tm)")
 270         public interface Tm extends Struct<Tm> {
 271             int sec();
 272             int min();
 273             int hour();
 274             int mday();
 275             int mon();
 276             int year();
 277             int wday();
 278             int yday();
 279             boolean isdst();
 280         }
 281     }
 282 
 283     /*** data providers ***/
 284 
 285     @DataProvider
 286     public static Object[][] ints() {
 287         return perms(0, new Integer[] { 0, 1, 2, 3, 4 }).stream()
 288                 .map(l -> new Object[] { l })
 289                 .toArray(Object[][]::new);
 290     }
 291 
 292     @DataProvider
 293     public static Object[][] strings() {
 294         return perms(0, new String[] { "a", "b", "c" }).stream()
 295                 .map(l -> new Object[] { String.join("", l) })
 296                 .toArray(Object[][]::new);
 297     }
 298 
 299     @DataProvider
 300     public static Object[][] stringPairs() {
 301         Object[][] strings = strings();
 302         Object[][] stringPairs = new Object[strings.length * strings.length][];
 303         int pos = 0;
 304         for (Object[] s1 : strings) {
 305             for (Object[] s2 : strings) {
 306                 stringPairs[pos++] = new Object[] { s1[0], s2[0] };
 307             }
 308         }
 309         return stringPairs;
 310     }
 311 
 312     @DataProvider
 313     public static Object[][] instants() {
 314         long start = Timestamp.valueOf("2017-01-01 00:00:00").getTime();
 315         long end = Timestamp.valueOf("2017-12-31 00:00:00").getTime();
 316         Object[][] instants = new Object[100][];
 317         for (int i = 0 ; i < instants.length ; i++) {
 318             long instant = start + (long)(Math.random() * (end - start));
 319             instants[i] = new Object[] { Instant.ofEpochSecond(instant)};
 320         }
 321         return instants;
 322     }
 323 
 324     @DataProvider
 325     public static Object[][] printfArgs() {
 326         return perms(0, PrintfArg.values()).stream()
 327                 .map(l -> new Object[] { l })
 328                 .toArray(Object[][]::new);
 329     }
 330 
 331     enum PrintfArg {
 332         INTEGRAL("%d", s -> 42, 2),
 333         STRING("%s", s -> s.allocateCString("str"), 3),
 334         CHAR("%c", s -> 'h', 1),
 335         FLOAT("%.2f", s -> 1.23d, 4);
 336 
 337         String format;
 338         Function<Scope, Object> valueFunc;
 339         int length;
 340 
 341         PrintfArg(String format, Function<Scope, Object> valueFunc, int length) {
 342             this.format = format;
 343             this.valueFunc = valueFunc;
 344             this.length = length;
 345         }
 346     }
 347 
 348     static <Z> Set<List<Z>> perms(int count, Z[] arr) {
 349         if (count == arr.length) {
 350             return Set.of(List.of());
 351         } else {
 352             return Arrays.stream(arr)
 353                     .flatMap(num -> {
 354                         Set<List<Z>> perms = perms(count + 1, arr);
 355                         return Stream.concat(
 356                                 //take n
 357                                 perms.stream().map(l -> {
 358                                     List<Z> li = new ArrayList<>(l);
 359                                     li.add(num);
 360                                     return li;
 361                                 }),
 362                                 //drop n
 363                                 perms.stream());
 364                     }).collect(Collectors.toSet());
 365         }
 366     }
 367 }