1 /*
   2  * Copyright (c) 2008, 2013, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package com.sun.scenario.effect.compiler.backend.sw.me;
  27 
  28 import java.util.Arrays;
  29 import java.util.HashMap;
  30 import java.util.List;
  31 import java.util.Map;
  32 import com.sun.scenario.effect.compiler.model.CoreSymbols;
  33 import com.sun.scenario.effect.compiler.model.FuncImpl;
  34 import com.sun.scenario.effect.compiler.model.Function;
  35 import com.sun.scenario.effect.compiler.model.Type;
  36 import com.sun.scenario.effect.compiler.tree.Expr;
  37 import com.sun.scenario.effect.compiler.tree.VariableExpr;
  38 
  39 import static com.sun.scenario.effect.compiler.backend.sw.me.MEBackend.*;
  40 import static com.sun.scenario.effect.compiler.model.Type.*;
  41 
  42 /**
  43  * Contains the C/fixed-point implementations for all core (built-in) functions.
  44  */
  45 class MEFuncImpls {
  46 
  47     private static Map<Function, FuncImpl> funcs = new HashMap<Function, FuncImpl>();
  48 
  49     static FuncImpl get(Function func) {
  50         return funcs.get(func);
  51     }
  52 
  53     static {
  54         // float4 sample(sampler s, float2 loc)
  55         declareFunctionSample(SAMPLER);
  56 
  57         // float4 sample(lsampler s, float2 loc)
  58         declareFunctionSample(LSAMPLER);
  59 
  60         // float4 sample(fsampler s, float2 loc)
  61         declareFunctionSample(FSAMPLER);
  62 
  63         // int intcast(float x)
  64         declareFunctionIntCast();
  65 
  66         // <ftype> min(<ftype> x, <ftype> y)
  67         // <ftype> min(<ftype> x, float y)
  68         declareOverloadsMinMax("min", "((x_tmp$1 < y_tmp$2) ? x_tmp$1 : y_tmp$2)");
  69 
  70         // <ftype> max(<ftype> x, <ftype> y)
  71         // <ftype> max(<ftype> x, float y)
  72         declareOverloadsMinMax("max", "((x_tmp$1 > y_tmp$2) ? x_tmp$1 : y_tmp$2)");
  73 
  74         // <ftype> clamp(<ftype> val, <ftype> min, <ftype> max)
  75         // <ftype> clamp(<ftype> val, float min, float max)
  76         declareOverloadsClamp();
  77 
  78         // <ftype> smoothstep(<ftype> min, <ftype> max, <ftype> val)
  79         // <ftype> smoothstep(float min, float max, <ftype> val)
  80         declareOverloadsSmoothstep();
  81 
  82         // <ftype> abs(<ftype> x)
  83         declareOverloadsSimple("abs", "abs(x_tmp$1)");
  84 
  85         // <ftype> floor(<ftype> x)
  86         declareOverloadsSimple("floor", "floor(x_tmp$1)");
  87 
  88         // <ftype> ceil(<ftype> x)
  89         declareOverloadsSimple("ceil", "ceil(x_tmp$1)");
  90 
  91         // <ftype> fract(<ftype> x)
  92         declareOverloadsSimple("fract", "(x_tmp$1 - floor(x_tmp$1))");
  93 
  94         // <ftype> sign(<ftype> x)
  95         declareOverloadsSimple("sign", "((x_tmp$1 < 0.f) ? -1.f : (x_tmp$1 > 0.f) ? 1.f : 0.f)");
  96 
  97         // <ftype> sqrt(<ftype> x)
  98         declareOverloadsSimple("sqrt", "sqrt(x_tmp$1)");
  99 
 100         // <ftype> sin(<ftype> x)
 101         declareOverloadsSimple("sin", "sin(x_tmp$1)");
 102 
 103         // <ftype> cos(<ftype> x)
 104         declareOverloadsSimple("cos", "cos(x_tmp$1)");
 105 
 106         // <ftype> tan(<ftype> x)
 107         declareOverloadsSimple("tan", "tan(x_tmp$1)");
 108 
 109         // <ftype> pow(<ftype> x, <ftype> y)
 110         declareOverloadsSimple2("pow", "pow(x_tmp$1, y_tmp$2)");
 111 
 112         // <ftype> mod(<ftype> x, <ftype> y)
 113         // <ftype> mod(<ftype> x, float y)
 114         declareOverloadsMinMax("mod", "(x_tmp$1 % y_tmp$2)");
 115 
 116         // float dot(<ftype> x, <ftype> y)
 117         declareOverloadsDot();
 118 
 119         // float distance(<ftype> x, <ftype> y)
 120         declareOverloadsDistance();
 121 
 122         // <ftype> mix(<ftype> x, <ftype> y, <ftype> a)
 123         // <ftype> mix(<ftype> x, <ftype> y, float a)
 124         declareOverloadsMix();
 125 
 126         // <ftype> normalize(<ftype> x)
 127         declareOverloadsNormalize();
 128 
 129         // <ftype> ddx(<ftype> p)
 130         declareOverloadsSimple("ddx", "<ddx() not implemented for sw backends>");
 131 
 132         // <ftype> ddy(<ftype> p)
 133         declareOverloadsSimple("ddy", "<ddy() not implemented for sw backends>");
 134     }
 135 
 136     private static void declareFunction(FuncImpl impl,
 137                                         String name, Type... ptypes)
 138     {
 139         Function f = CoreSymbols.getFunction(name, Arrays.asList(ptypes));
 140         if (f == null) {
 141             throw new InternalError("Core function not found (have you declared the function in CoreSymbols?)");
 142         }
 143         funcs.put(f, impl);
 144     }
 145 
 146     /**
 147      * Used to declare sample function:
 148      *   float4 sample([l,f]sampler s, float2 loc)
 149      */
 150     private static void declareFunctionSample(final Type type) {
 151         FuncImpl fimpl = new FuncImpl() {
 152             @Override
 153             public String getPreamble(List<Expr> params) {
 154                 String s = getSamplerName(params);
 155                 // TODO: this bounds checking is way too costly...
 156                 String p = getPosName(params);
 157                 if (type == LSAMPLER) {
 158                     return
 159                         "lsample(" + s + ", loc_tmp_x, loc_tmp_y,\n" +
 160                         "        " + p + "w, " + p + "h, " + p + "scan,\n" +
 161                         "        " + s + "_vals);\n";
 162                 } else if (type == FSAMPLER) {
 163                     return
 164                         "float *" + s + "_arr_tmp = NULL;\n" +
 165                         "int iloc_tmp = 0;\n" +
 166                         "if (loc_tmp_x >= 0 && loc_tmp_y >= 0) {\n" +
 167                         "    int iloc_tmp_x = (int)(loc_tmp_x*" + p + "w);\n" +
 168                         "    int iloc_tmp_y = (int)(loc_tmp_y*" + p + "h);\n" +
 169                         "    jboolean out =\n" +
 170                         "        iloc_tmp_x >= " + p + "w ||\n" +
 171                         "        iloc_tmp_y >= " + p + "h;\n" +
 172                         "    if (!out) {\n" +
 173                         "        "+ s + "_arr_tmp = " + s + ";\n" +
 174                         "        iloc_tmp = 4 * (iloc_tmp_y*" + p + "scan + iloc_tmp_x);\n" +
 175                         "    }\n" +
 176                         "}\n";
 177                 } else {
 178                     return
 179                         "int " + s + "_tmp;\n" +
 180                         "if (loc_tmp_x >= 0 && loc_tmp_y >= 0) {\n" +
 181                         "    int iloc_tmp_x = (int)(loc_tmp_x*" + p + "w);\n" +
 182                         "    int iloc_tmp_y = (int)(loc_tmp_y*" + p + "h);\n" +
 183                         "    jboolean out =\n" +
 184                         "        iloc_tmp_x >= " + p + "w ||\n" +
 185                         "        iloc_tmp_y >= " + p + "h;\n" +
 186                         "    " + s + "_tmp = out ? 0 :\n" +
 187                         "        " + s + "[iloc_tmp_y*" + p + "scan + iloc_tmp_x];\n" +
 188                         "} else {\n" +
 189                         "    " + s + "_tmp = 0;\n" +
 190                         "}\n";
 191                 }
 192             }
 193             public String toString(int i, List<Expr> params) {
 194                 String s = getSamplerName(params);
 195                 if (type == LSAMPLER) {
 196                     return (i < 0 || i > 3) ? null : s + "_vals[" + i + "]";
 197                 } else if (type == FSAMPLER) {
 198                     String arr = s + "_arr_tmp";
 199                     switch (i) {
 200                     case 0:
 201                         return arr + " == NULL ? 0.f : " + arr + "[iloc_tmp]";
 202                     case 1:
 203                         return arr + " == NULL ? 0.f : " + arr + "[iloc_tmp+1]";
 204                     case 2:
 205                         return arr + " == NULL ? 0.f : " + arr + "[iloc_tmp+2]";
 206                     case 3:
 207                         return arr + " == NULL ? 0.f : " + arr + "[iloc_tmp+3]";
 208                     default:
 209                         return null;
 210                     }
 211                 } else {
 212                     switch (i) {
 213                     case 0:
 214                         return "(((" + s + "_tmp >> 16) & 0xff) / 255.f)";
 215                     case 1:
 216                         return "(((" + s + "_tmp >>  8) & 0xff) / 255.f)";
 217                     case 2:
 218                         return "(((" + s + "_tmp      ) & 0xff) / 255.f)";
 219                     case 3:
 220                         return "(((" + s + "_tmp >> 24) & 0xff) / 255.f)";
 221                     default:
 222                         return null;
 223                     }
 224                 }
 225             }
 226             private String getSamplerName(List<Expr> params) {
 227                 VariableExpr e = (VariableExpr)params.get(0);
 228                 return e.getVariable().getName();
 229             }
 230             private String getPosName(List<Expr> params) {
 231                 VariableExpr e = (VariableExpr)params.get(0);
 232                 return "src" + e.getVariable().getReg();
 233             }
 234         };
 235         declareFunction(fimpl, "sample", type, FLOAT2);
 236     }
 237 
 238     /**
 239      * Used to declare intcast function:
 240      *   int intcast(float x)
 241      */
 242     private static void declareFunctionIntCast() {
 243         FuncImpl fimpl = new FuncImpl() {
 244             public String toString(int i, List<Expr> params) {
 245                 return "((int)x_tmp)";
 246             }
 247         };
 248         declareFunction(fimpl, "intcast", FLOAT);
 249     }
 250 
 251     /**
 252      * Used to declare simple functions of the following form:
 253      *   <ftype> name(<ftype> x)
 254      */
 255     private static void declareOverloadsSimple(String name, final String pattern) {
 256         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 257             final boolean useSuffix = (type != FLOAT);
 258             FuncImpl fimpl = new FuncImpl() {
 259                 public String toString(int i, List<Expr> params) {
 260                     String sfx = useSuffix ? getSuffix(i) : "";
 261                     String s = pattern;
 262                     s = s.replace("$1", sfx);
 263                     return s;
 264                 }
 265             };
 266             declareFunction(fimpl, name, type);
 267         }
 268     }
 269 
 270     /**
 271      * Used to declare simple two parameter functions of the following form:
 272      *   <ftype> name(<ftype> x, <ftype> y)
 273      */
 274     private static void declareOverloadsSimple2(String name, final String pattern) {
 275         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 276             // declare (vectype,vectype) variants
 277             final boolean useSuffix = (type != FLOAT);
 278             FuncImpl fimpl = new FuncImpl() {
 279                 public String toString(int i, List<Expr> params) {
 280                     String sfx = useSuffix ? getSuffix(i) : "";
 281                     String s = pattern;
 282                     s = s.replace("$1", sfx);
 283                     s = s.replace("$2", sfx);
 284                     return s;
 285                 }
 286             };
 287             declareFunction(fimpl, name, type, type);
 288         }
 289     }
 290 
 291     /**
 292      * Used to declare normalize functions of the following form:
 293      *   <ftype> normalize(<ftype> x)
 294      */
 295     private static void declareOverloadsNormalize() {
 296         final String name = "normalize";
 297         final String pattern = "x_tmp$1 / denom";
 298         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 299             int n = type.getNumFields();
 300             final String preamble;
 301             if (n == 1) {
 302                 preamble = "float denom = x_tmp;\n";
 303             } else {
 304                 String     s  =    "(x_tmp_x * x_tmp_x)";
 305                            s += "+\n(x_tmp_y * x_tmp_y)";
 306                 if (n > 2) s += "+\n(x_tmp_z * x_tmp_z)";
 307                 if (n > 3) s += "+\n(x_tmp_w * x_tmp_w)";
 308                 preamble = "float denom = sqrt(" + s + ");\n";
 309             }
 310 
 311             final boolean useSuffix = (type != FLOAT);
 312             FuncImpl fimpl = new FuncImpl() {
 313                 @Override
 314                 public String getPreamble(List<Expr> params) {
 315                     return preamble;
 316                 }
 317                 public String toString(int i, List<Expr> params) {
 318                     String sfx = useSuffix ? getSuffix(i) : "";
 319                     String s = pattern;
 320                     s = s.replace("$1", sfx);
 321                     return s;
 322                 }
 323             };
 324             declareFunction(fimpl, name, type);
 325         }
 326     }
 327 
 328     /**
 329      * Used to declare dot functions of the following form:
 330      *   float dot(<ftype> x, <ftype> y)
 331      */
 332     private static void declareOverloadsDot() {
 333         final String name = "dot";
 334         for (final Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 335             int n = type.getNumFields();
 336             String s;
 337             if (n == 1) {
 338                 s = "(x_tmp * y_tmp)";
 339             } else {
 340                            s  =    "(x_tmp_x * y_tmp_x)";
 341                            s += "+\n(x_tmp_y * y_tmp_y)";
 342                 if (n > 2) s += "+\n(x_tmp_z * y_tmp_z)";
 343                 if (n > 3) s += "+\n(x_tmp_w * y_tmp_w)";
 344             }
 345             final String str = s;
 346             FuncImpl fimpl = new FuncImpl() {
 347                 public String toString(int i, List<Expr> params) {
 348                     return str;
 349                 }
 350             };
 351             declareFunction(fimpl, name, type, type);
 352         }
 353     }
 354 
 355     /**
 356      * Used to declare distance functions of the following form:
 357      *   float distance(<ftype> x, <ftype> y)
 358      */
 359     private static void declareOverloadsDistance() {
 360         final String name = "distance";
 361         for (final Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 362             int n = type.getNumFields();
 363             String s;
 364             if (n == 1) {
 365                 s = "(x_tmp - y_tmp) * (x_tmp - y_tmp)";
 366             } else {
 367                            s  =    "((x_tmp_x - y_tmp_x) * (x_tmp_x - y_tmp_x))";
 368                            s += "+\n((x_tmp_y - y_tmp_y) * (x_tmp_y - y_tmp_y))";
 369                 if (n > 2) s += "+\n((x_tmp_z - y_tmp_z) * (x_tmp_z - y_tmp_z))";
 370                 if (n > 3) s += "+\n((x_tmp_w - y_tmp_w) * (x_tmp_w - y_tmp_w))";
 371             }
 372             final String str = "sqrt(" + s + ")";
 373             FuncImpl fimpl = new FuncImpl() {
 374                 public String toString(int i, List<Expr> params) {
 375                     return str;
 376                 }
 377             };
 378             declareFunction(fimpl, name, type, type);
 379         }
 380     }
 381 
 382     /**
 383      * Used to declare min/max functions of the following form:
 384      *   <ftype> name(<ftype> x, <ftype> y)
 385      *   <ftype> name(<ftype> x, float y)
 386      *
 387      * TODO: this is currently geared to simple functions like
 388      * min and max; we should make this more general...
 389      */
 390     private static void declareOverloadsMinMax(String name, final String pattern) {
 391         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 392             // declare (vectype,vectype) variants
 393             final boolean useSuffix = (type != FLOAT);
 394             FuncImpl fimpl = new FuncImpl() {
 395                 public String toString(int i, List<Expr> params) {
 396                     String sfx = useSuffix ? getSuffix(i) : "";
 397                     String s = pattern;
 398                     s = s.replace("$1", sfx);
 399                     s = s.replace("$2", sfx);
 400                     return s;
 401                 }
 402             };
 403             declareFunction(fimpl, name, type, type);
 404 
 405             if (type == FLOAT) {
 406                 continue;
 407             }
 408 
 409             // declare (vectype,float) variants
 410             fimpl = new FuncImpl() {
 411                 public String toString(int i, List<Expr> params) {
 412                     String sfx = getSuffix(i);
 413                     String s = pattern;
 414                     s = s.replace("$1", sfx);
 415                     s = s.replace("$2", "");
 416                     return s;
 417                 }
 418             };
 419             declareFunction(fimpl, name, type, FLOAT);
 420         }
 421     }
 422 
 423     /**
 424      * Used to declare clamp functions of the following form:
 425      *   <ftype> clamp(<ftype> val, <ftype> min, <ftype> max)
 426      *   <ftype> clamp(<ftype> val, float min, float max)
 427      */
 428     private static void declareOverloadsClamp() {
 429         final String name = "clamp";
 430         final String pattern =
 431             "(val_tmp$1 < min_tmp$2) ? min_tmp$2 : \n" +
 432             "(val_tmp$1 > max_tmp$2) ? max_tmp$2 : val_tmp$1";
 433 
 434         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 435             // declare (vectype,vectype,vectype) variants
 436             final boolean useSuffix = (type != FLOAT);
 437             FuncImpl fimpl = new FuncImpl() {
 438                 public String toString(int i, List<Expr> params) {
 439                     String sfx = useSuffix ? getSuffix(i) : "";
 440                     String s = pattern;
 441                     s = s.replace("$1", sfx);
 442                     s = s.replace("$2", sfx);
 443                     return s;
 444                 }
 445             };
 446             declareFunction(fimpl, name, type, type, type);
 447 
 448             if (type == FLOAT) {
 449                 continue;
 450             }
 451 
 452             // declare (vectype,float,float) variants
 453             fimpl = new FuncImpl() {
 454                 public String toString(int i, List<Expr> params) {
 455                     String sfx = getSuffix(i);
 456                     String s = pattern;
 457                     s = s.replace("$1", sfx);
 458                     s = s.replace("$2", "");
 459                     return s;
 460                 }
 461             };
 462             declareFunction(fimpl, name, type, FLOAT, FLOAT);
 463         }
 464     }
 465 
 466     /**
 467      * Used to declare smoothstep functions of the following form:
 468      *   <ftype> smoothstep(<ftype> min, <ftype> max, <ftype> val)
 469      *   <ftype> smoothstep(float min, float max, <ftype> val)
 470      */
 471     private static void declareOverloadsSmoothstep() {
 472         final String name = "smoothstep";
 473         // TODO - the smoothstep function is defined to use Hermite interpolation
 474         final String pattern =
 475             "(val_tmp$1 < min_tmp$2) ? 0.0f : \n" +
 476             "(val_tmp$1 > max_tmp$2) ? 1.0f : \n" +
 477             "(val_tmp$1 / (max_tmp$2 - min_tmp$2))";
 478 
 479         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 480             // declare (vectype,vectype,vectype) variants
 481             final boolean useSuffix = (type != FLOAT);
 482             FuncImpl fimpl = new FuncImpl() {
 483                 public String toString(int i, List<Expr> params) {
 484                     String sfx = useSuffix ? getSuffix(i) : "";
 485                     String s = pattern;
 486                     s = s.replace("$1", sfx);
 487                     s = s.replace("$2", sfx);
 488                     return s;
 489                 }
 490             };
 491             declareFunction(fimpl, name, type, type, type);
 492 
 493             if (type == FLOAT) {
 494                 continue;
 495             }
 496 
 497             // declare (float,float,vectype) variants
 498             fimpl = new FuncImpl() {
 499                 public String toString(int i, List<Expr> params) {
 500                     String sfx = getSuffix(i);
 501                     String s = pattern;
 502                     s = s.replace("$1", sfx);
 503                     s = s.replace("$2", "");
 504                     return s;
 505                 }
 506             };
 507             declareFunction(fimpl, name, FLOAT, FLOAT, type);
 508         }
 509     }
 510 
 511     /**
 512      * Used to declare mix functions of the following form:
 513      *   <ftype> mix(<ftype> x, <ftype> y, <ftype> a)
 514      *   <ftype> mix(<ftype> x, <ftype> y, float a)
 515      */
 516     private static void declareOverloadsMix() {
 517         final String name = "mix";
 518         final String pattern =
 519             "(x_tmp$1 * (1.0f - a_tmp$2) + y_tmp$1 * a_tmp$2)";
 520 
 521         for (Type type : new Type[] {FLOAT, FLOAT2, FLOAT3, FLOAT4}) {
 522             // declare (vectype,vectype,vectype) variants
 523             final boolean useSuffix = (type != FLOAT);
 524             FuncImpl fimpl = new FuncImpl() {
 525                 public String toString(int i, List<Expr> params) {
 526                     String sfx = useSuffix ? getSuffix(i) : "";
 527                     String s = pattern;
 528                     s = s.replace("$1", sfx);
 529                     s = s.replace("$2", sfx);
 530                     return s;
 531                 }
 532             };
 533             declareFunction(fimpl, name, type, type, type);
 534 
 535             if (type == FLOAT) {
 536                 continue;
 537             }
 538 
 539             // declare (vectype,vectype,float) variants
 540             fimpl = new FuncImpl() {
 541                 public String toString(int i, List<Expr> params) {
 542                     String sfx = getSuffix(i);
 543                     String s = pattern;
 544                     s = s.replace("$1", sfx);
 545                     s = s.replace("$2", "");
 546                     return s;
 547                 }
 548             };
 549             declareFunction(fimpl, name, type, type, FLOAT);
 550         }
 551     }
 552 }