1 /* 2 * Copyright (c) 2018, 2019, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 25 /* 26 * This file is used to generated optimized finite field implementations. 27 */ 28 package build.tools.intpoly; 29 30 import java.io.*; 31 import java.math.BigInteger; 32 import java.nio.file.Files; 33 import java.nio.file.Path; 34 import java.util.*; 35 36 public class FieldGen { 37 38 static FieldParams Curve25519 = new FieldParams( 39 "IntegerPolynomial25519", 26, 10, 1, 255, 40 Arrays.asList( 41 new Term(0, -19) 42 ), 43 Curve25519CrSequence(), simpleSmallCrSequence(10) 44 ); 45 46 private static List<CarryReduce> Curve25519CrSequence() { 47 List<CarryReduce> result = new ArrayList<CarryReduce>(); 48 49 // reduce(7,2) 50 result.add(new Reduce(17)); 51 result.add(new Reduce(18)); 52 53 // carry(8,2) 54 result.add(new Carry(8)); 55 result.add(new Carry(9)); 56 57 // reduce(0,7) 58 for (int i = 10; i < 17; i++) { 59 result.add(new Reduce(i)); 60 } 61 62 // carry(0,9) 63 result.addAll(fullCarry(10)); 64 65 return result; 66 } 67 68 static FieldParams Curve448 = new FieldParams( 69 "IntegerPolynomial448", 28, 16, 1, 448, 70 Arrays.asList( 71 new Term(224, -1), 72 new Term(0, -1) 73 ), 74 Curve448CrSequence(), simpleSmallCrSequence(16) 75 ); 76 77 private static List<CarryReduce> Curve448CrSequence() { 78 List<CarryReduce> result = new ArrayList<CarryReduce>(); 79 80 // reduce(8, 7) 81 for (int i = 24; i < 31; i++) { 82 result.add(new Reduce(i)); 83 } 84 // reduce(4, 4) 85 for (int i = 20; i < 24; i++) { 86 result.add(new Reduce(i)); 87 } 88 89 //carry(14, 2) 90 result.add(new Carry(14)); 91 result.add(new Carry(15)); 92 93 // reduce(0, 4) 94 for (int i = 16; i < 20; i++) { 95 result.add(new Reduce(i)); 96 } 97 98 // carry(0, 15) 99 result.addAll(fullCarry(16)); 100 101 return result; 102 } 103 104 static FieldParams P256 = new FieldParams( 105 "IntegerPolynomialP256", 26, 10, 2, 256, 106 Arrays.asList( 107 new Term(224, -1), 108 new Term(192, 1), 109 new Term(96, 1), 110 new Term(0, -1) 111 ), 112 P256CrSequence(), simpleSmallCrSequence(10) 113 ); 114 115 private static List<CarryReduce> P256CrSequence() { 116 List<CarryReduce> result = new ArrayList<CarryReduce>(); 117 result.addAll(fullReduce(10)); 118 result.addAll(simpleSmallCrSequence(10)); 119 return result; 120 } 121 122 static FieldParams P384 = new FieldParams( 123 "IntegerPolynomialP384", 28, 14, 2, 384, 124 Arrays.asList( 125 new Term(128, -1), 126 new Term(96, -1), 127 new Term(32, 1), 128 new Term(0, -1) 129 ), 130 P384CrSequence(), simpleSmallCrSequence(14) 131 ); 132 133 private static List<CarryReduce> P384CrSequence() { 134 List<CarryReduce> result = new ArrayList<CarryReduce>(); 135 result.addAll(fullReduce(14)); 136 result.addAll(simpleSmallCrSequence(14)); 137 return result; 138 } 139 140 static FieldParams P521 = new FieldParams( 141 "IntegerPolynomialP521", 28, 19, 2, 521, 142 Arrays.asList( 143 new Term(0, -1) 144 ), 145 P521CrSequence(), simpleSmallCrSequence(19) 146 ); 147 148 private static List<CarryReduce> P521CrSequence() { 149 List<CarryReduce> result = new ArrayList<CarryReduce>(); 150 result.addAll(fullReduce(19)); 151 result.addAll(simpleSmallCrSequence(19)); 152 return result; 153 } 154 155 static FieldParams O256 = new FieldParams( 156 "P256OrderField", 26, 10, 1, 256, 157 "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551", 158 orderFieldCrSequence(10), orderFieldSmallCrSequence(10) 159 ); 160 161 static FieldParams O384 = new FieldParams( 162 "P384OrderField", 28, 14, 1, 384, 163 "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFC7634D81F4372DDF581A0DB248B0A77AECEC196ACCC52973", 164 orderFieldCrSequence(14), orderFieldSmallCrSequence(14) 165 ); 166 167 static FieldParams O521 = new FieldParams( 168 "P521OrderField", 28, 19, 1, 521, 169 "01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFA51868783BF2F966B7FCC0148F709A5D03BB5C9B8899C47AEBB6FB71E91386409", 170 o521crSequence(19), orderFieldSmallCrSequence(19) 171 ); 172 173 private static List<CarryReduce> o521crSequence(int numLimbs) { 174 175 // split the full reduce in half, with a carry in between 176 List<CarryReduce> result = new ArrayList<CarryReduce>(); 177 result.addAll(fullCarry(2 * numLimbs)); 178 for (int i = 2 * numLimbs - 1; i >= numLimbs + numLimbs / 2; i--) { 179 result.add(new Reduce(i)); 180 } 181 // carry 182 for (int i = numLimbs; i < numLimbs + numLimbs / 2 - 1; i++) { 183 result.add(new Carry(i)); 184 } 185 // rest of reduce 186 for (int i = numLimbs + numLimbs / 2 - 1; i >= numLimbs; i--) { 187 result.add(new Reduce(i)); 188 } 189 result.addAll(orderFieldSmallCrSequence(numLimbs)); 190 191 return result; 192 } 193 194 private static List<CarryReduce> orderFieldCrSequence(int numLimbs) { 195 List<CarryReduce> result = new ArrayList<CarryReduce>(); 196 result.addAll(fullCarry(2 * numLimbs)); 197 result.add(new Reduce(2 * numLimbs - 1)); 198 result.addAll(fullReduce(numLimbs)); 199 result.addAll(fullCarry(numLimbs + 1)); 200 result.add(new Reduce(numLimbs)); 201 result.addAll(fullCarry(numLimbs)); 202 203 return result; 204 } 205 206 private static List<CarryReduce> orderFieldSmallCrSequence(int numLimbs) { 207 List<CarryReduce> result = new ArrayList<CarryReduce>(); 208 result.addAll(fullCarry(numLimbs + 1)); 209 result.add(new Reduce(numLimbs)); 210 result.addAll(fullCarry(numLimbs)); 211 return result; 212 } 213 214 static final FieldParams[] ALL_FIELDS = { 215 P256, P384, P521, O256, O384, O521, 216 }; 217 218 public static class Term { 219 private final int power; 220 private final int coefficient; 221 222 public Term(int power, int coefficient) { 223 this.power = power; 224 this.coefficient = coefficient; 225 } 226 227 public int getPower() { 228 return power; 229 } 230 231 public int getCoefficient() { 232 return coefficient; 233 } 234 235 public BigInteger getValue() { 236 return BigInteger.valueOf(2).pow(power) 237 .multiply(BigInteger.valueOf(coefficient)); 238 } 239 240 public String toString() { 241 return "2^" + power + " * " + coefficient; 242 } 243 } 244 245 static abstract class CarryReduce { 246 private final int index; 247 248 protected CarryReduce(int index) { 249 this.index = index; 250 } 251 252 public int getIndex() { 253 return index; 254 } 255 256 public abstract void write(CodeBuffer out, FieldParams params, 257 String prefix, Iterable<CarryReduce> remaining); 258 } 259 260 static class Carry extends CarryReduce { 261 public Carry(int index) { 262 super(index); 263 } 264 265 public void write(CodeBuffer out, FieldParams params, String prefix, 266 Iterable<CarryReduce> remaining) { 267 carry(out, params, prefix, getIndex()); 268 } 269 } 270 271 static class Reduce extends CarryReduce { 272 public Reduce(int index) { 273 super(index); 274 } 275 276 public void write(CodeBuffer out, FieldParams params, String prefix, 277 Iterable<CarryReduce> remaining) { 278 reduce(out, params, prefix, getIndex(), remaining); 279 } 280 } 281 282 static class FieldParams { 283 private final String className; 284 private final int bitsPerLimb; 285 private final int numLimbs; 286 private final int maxAdds; 287 private final int power; 288 private final Iterable<Term> terms; 289 private final List<CarryReduce> crSequence; 290 private final List<CarryReduce> smallCrSequence; 291 292 public FieldParams(String className, int bitsPerLimb, int numLimbs, 293 int maxAdds, int power, 294 Iterable<Term> terms, List<CarryReduce> crSequence, 295 List<CarryReduce> smallCrSequence) { 296 this.className = className; 297 this.bitsPerLimb = bitsPerLimb; 298 this.numLimbs = numLimbs; 299 this.maxAdds = maxAdds; 300 this.power = power; 301 this.terms = terms; 302 this.crSequence = crSequence; 303 this.smallCrSequence = smallCrSequence; 304 } 305 306 public FieldParams(String className, int bitsPerLimb, int numLimbs, 307 int maxAdds, int power, 308 String term, List<CarryReduce> crSequence, 309 List<CarryReduce> smallCrSequence) { 310 this.className = className; 311 this.bitsPerLimb = bitsPerLimb; 312 this.numLimbs = numLimbs; 313 this.maxAdds = maxAdds; 314 this.power = power; 315 this.crSequence = crSequence; 316 this.smallCrSequence = smallCrSequence; 317 318 terms = buildTerms(BigInteger.ONE.shiftLeft(power) 319 .subtract(new BigInteger(term, 16))); 320 } 321 322 private Iterable<Term> buildTerms(BigInteger sub) { 323 // split a large subtrahend into smaller terms 324 // that are aligned with limbs 325 List<Term> result = new ArrayList<Term>(); 326 BigInteger mod = BigInteger.valueOf(1 << bitsPerLimb); 327 int termIndex = 0; 328 while (!sub.equals(BigInteger.ZERO)) { 329 int coef = sub.mod(mod).intValue(); 330 boolean plusOne = false; 331 if (coef > (1 << (bitsPerLimb - 1))) { 332 coef = coef - (1 << bitsPerLimb); 333 plusOne = true; 334 } 335 if (coef != 0) { 336 int pow = termIndex * bitsPerLimb; 337 result.add(new Term(pow, -coef)); 338 } 339 sub = sub.shiftRight(bitsPerLimb); 340 if (plusOne) { 341 sub = sub.add(BigInteger.ONE); 342 } 343 ++termIndex; 344 } 345 return result; 346 } 347 348 public String getClassName() { 349 return className; 350 } 351 352 public int getBitsPerLimb() { 353 return bitsPerLimb; 354 } 355 356 public int getNumLimbs() { 357 return numLimbs; 358 } 359 360 public int getMaxAdds() { 361 return maxAdds; 362 } 363 364 public int getPower() { 365 return power; 366 } 367 368 public Iterable<Term> getTerms() { 369 return terms; 370 } 371 372 public List<CarryReduce> getCrSequence() { 373 return crSequence; 374 } 375 376 public List<CarryReduce> getSmallCrSequence() { 377 return smallCrSequence; 378 } 379 } 380 381 static Collection<Carry> fullCarry(int numLimbs) { 382 List<Carry> result = new ArrayList<Carry>(); 383 for (int i = 0; i < numLimbs - 1; i++) { 384 result.add(new Carry(i)); 385 } 386 return result; 387 } 388 389 static Collection<Reduce> fullReduce(int numLimbs) { 390 List<Reduce> result = new ArrayList<Reduce>(); 391 for (int i = numLimbs - 2; i >= 0; i--) { 392 result.add(new Reduce(i + numLimbs)); 393 } 394 return result; 395 } 396 397 static List<CarryReduce> simpleCrSequence(int numLimbs) { 398 List<CarryReduce> result = new ArrayList<CarryReduce>(); 399 for (int i = 0; i < 4; i++) { 400 result.addAll(fullCarry(2 * numLimbs - 1)); 401 result.addAll(fullReduce(numLimbs)); 402 } 403 404 return result; 405 } 406 407 static List<CarryReduce> simpleSmallCrSequence(int numLimbs) { 408 List<CarryReduce> result = new ArrayList<CarryReduce>(); 409 // carry a few positions at the end 410 for (int i = numLimbs - 2; i < numLimbs; i++) { 411 result.add(new Carry(i)); 412 } 413 // this carries out a single value that must be reduced back in 414 result.add(new Reduce(numLimbs)); 415 // finish with a full carry 416 result.addAll(fullCarry(numLimbs)); 417 return result; 418 } 419 420 private final String packageName; 421 private final String parentName; 422 423 private final Path headerPath; 424 private final Path destPath; 425 426 public FieldGen(String packageName, String parentName, 427 Path headerPath, Path destRoot) throws IOException { 428 this.packageName = packageName; 429 this.parentName = parentName; 430 this.headerPath = headerPath; 431 this.destPath = destRoot.resolve(packageName.replace(".", "/")); 432 Files.createDirectories(destPath); 433 } 434 435 // args: header.txt destpath 436 public static void main(String[] args) throws Exception { 437 438 FieldGen gen = new FieldGen( 439 "sun.security.util.math.intpoly", 440 "IntegerPolynomial", 441 Path.of(args[0]), 442 Path.of(args[1])); 443 for (FieldParams p : ALL_FIELDS) { 444 System.out.println(p.className); 445 System.out.println(p.terms); 446 System.out.println(); 447 gen.generateFile(p); 448 } 449 } 450 451 private void generateFile(FieldParams params) throws IOException { 452 String text = generate(params); 453 String fileName = params.getClassName() + ".java"; 454 PrintWriter out = new PrintWriter(Files.newBufferedWriter( 455 destPath.resolve(fileName))); 456 out.println(text); 457 out.close(); 458 } 459 460 static class CodeBuffer { 461 462 private int nextTemporary = 0; 463 private Set<String> temporaries = new HashSet<String>(); 464 private StringBuffer buffer = new StringBuffer(); 465 private int indent = 0; 466 private Class<?> lastCR; 467 private int lastCrCount = 0; 468 private int crMethodBreakCount = 0; 469 private int crNumLimbs = 0; 470 471 public void incrIndent() { 472 indent++; 473 } 474 475 public void decrIndent() { 476 indent--; 477 } 478 479 public void newTempScope() { 480 nextTemporary = 0; 481 temporaries.clear(); 482 } 483 484 public void appendLine(String s) { 485 appendIndent(); 486 buffer.append(s + "\n"); 487 } 488 489 public void appendLine() { 490 buffer.append("\n"); 491 } 492 493 public String toString() { 494 return buffer.toString(); 495 } 496 497 public void startCrSequence(int numLimbs) { 498 this.crNumLimbs = numLimbs; 499 lastCrCount = 0; 500 crMethodBreakCount = 0; 501 lastCR = null; 502 } 503 504 /* 505 * Record a carry/reduce of the specified type. This method is used to 506 * break up large carry/reduce sequences into multiple methods to make 507 * JIT/optimization easier 508 */ 509 public void record(Class<?> type) { 510 if (type == lastCR) { 511 lastCrCount++; 512 } else { 513 514 if (lastCrCount >= 8) { 515 insertCrMethodBreak(); 516 } 517 518 lastCR = type; 519 lastCrCount = 0; 520 } 521 } 522 523 private void insertCrMethodBreak() { 524 525 appendLine(); 526 527 // call the new method 528 appendIndent(); 529 append("carryReduce" + crMethodBreakCount + "(r"); 530 for (int i = 0; i < crNumLimbs; i++) { 531 append(", c" + i); 532 } 533 // temporaries are not live between operations, no need to send 534 append(");\n"); 535 536 decrIndent(); 537 appendLine("}"); 538 539 // make the method 540 appendIndent(); 541 append("void carryReduce" + crMethodBreakCount + "(long[] r"); 542 for (int i = 0; i < crNumLimbs; i++) { 543 append(", long c" + i); 544 } 545 append(") {\n"); 546 incrIndent(); 547 // declare temporaries 548 for (String temp : temporaries) { 549 appendLine("long " + temp + ";"); 550 } 551 append("\n"); 552 553 crMethodBreakCount++; 554 } 555 556 public String getTemporary(String type, String value) { 557 Iterator<String> iter = temporaries.iterator(); 558 if (iter.hasNext()) { 559 String result = iter.next(); 560 iter.remove(); 561 appendLine(result + " = " + value + ";"); 562 return result; 563 } else { 564 String result = "t" + (nextTemporary++); 565 appendLine(type + " " + result + " = " + value + ";"); 566 return result; 567 } 568 } 569 570 public void freeTemporary(String temp) { 571 temporaries.add(temp); 572 } 573 574 public void appendIndent() { 575 for (int i = 0; i < indent; i++) { 576 buffer.append(" "); 577 } 578 } 579 580 public void append(String s) { 581 buffer.append(s); 582 } 583 } 584 585 private String generate(FieldParams params) throws IOException { 586 CodeBuffer result = new CodeBuffer(); 587 String header = readHeader(); 588 result.appendLine(header); 589 590 if (packageName != null) { 591 result.appendLine("package " + packageName + ";"); 592 result.appendLine(); 593 } 594 result.appendLine("import java.math.BigInteger;"); 595 596 result.appendLine("public class " + params.getClassName() 597 + " extends " + this.parentName + " {"); 598 result.incrIndent(); 599 600 result.appendLine("private static final int BITS_PER_LIMB = " 601 + params.getBitsPerLimb() + ";"); 602 result.appendLine("private static final int NUM_LIMBS = " 603 + params.getNumLimbs() + ";"); 604 result.appendLine("private static final int MAX_ADDS = " 605 + params.getMaxAdds() + ";"); 606 result.appendLine( 607 "public static final BigInteger MODULUS = evaluateModulus();"); 608 result.appendLine("private static final long CARRY_ADD = 1 << " 609 + (params.getBitsPerLimb() - 1) + ";"); 610 if (params.getBitsPerLimb() * params.getNumLimbs() != params.getPower()) { 611 result.appendLine("private static final int LIMB_MASK = -1 " 612 + ">>> (64 - BITS_PER_LIMB);"); 613 } 614 int termIndex = 0; 615 616 result.appendLine("public " + params.getClassName() + "() {"); 617 result.appendLine(); 618 result.appendLine(" super(BITS_PER_LIMB, NUM_LIMBS, MAX_ADDS, MODULUS);"); 619 result.appendLine(); 620 result.appendLine("}"); 621 622 result.appendLine("private static BigInteger evaluateModulus() {"); 623 result.incrIndent(); 624 result.appendLine("BigInteger result = BigInteger.valueOf(2).pow(" 625 + params.getPower() + ");"); 626 for (Term t : params.getTerms()) { 627 boolean subtract = false; 628 int coefValue = t.getCoefficient(); 629 if (coefValue < 0) { 630 coefValue = 0 - coefValue; 631 subtract = true; 632 } 633 String coefExpr = "BigInteger.valueOf(" + coefValue + ")"; 634 String powExpr = "BigInteger.valueOf(2).pow(" + t.getPower() + ")"; 635 String termExpr = "ERROR"; 636 if (t.getPower() == 0) { 637 termExpr = coefExpr; 638 } else if (coefValue == 1) { 639 termExpr = powExpr; 640 } else { 641 termExpr = powExpr + ".multiply(" + coefExpr + ")"; 642 } 643 if (subtract) { 644 result.appendLine("result = result.subtract(" + termExpr + ");"); 645 } else { 646 result.appendLine("result = result.add(" + termExpr + ");"); 647 } 648 } 649 result.appendLine("return result;"); 650 result.decrIndent(); 651 result.appendLine("}"); 652 653 result.appendLine("@Override"); 654 result.appendLine("protected void finalCarryReduceLast(long[] limbs) {"); 655 result.incrIndent(); 656 int extraBits = params.getBitsPerLimb() * params.getNumLimbs() 657 - params.getPower(); 658 int highBits = params.getBitsPerLimb() - extraBits; 659 result.appendLine("long c = limbs[" + (params.getNumLimbs() - 1) 660 + "] >> " + highBits + ";"); 661 result.appendLine("limbs[" + (params.getNumLimbs() - 1) + "] -= c << " 662 + highBits + ";"); 663 for (Term t : params.getTerms()) { 664 int reduceBits = params.getPower() + extraBits - t.getPower(); 665 int negatedCoefficient = -1 * t.getCoefficient(); 666 modReduceInBits(result, params, true, "limbs", params.getNumLimbs(), 667 reduceBits, negatedCoefficient, "c"); 668 } 669 result.decrIndent(); 670 result.appendLine("}"); 671 672 // full carry/reduce sequence 673 result.appendIndent(); 674 result.append("private void carryReduce(long[] r, "); 675 for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) { 676 result.append("long c" + i); 677 if (i < 2 * params.getNumLimbs() - 2) { 678 result.append(", "); 679 } 680 } 681 result.append(") {\n"); 682 result.newTempScope(); 683 result.incrIndent(); 684 result.appendLine("long c" + (2 * params.getNumLimbs() - 1) + " = 0;"); 685 write(result, params.getCrSequence(), params, "c", 686 2 * params.getNumLimbs()); 687 result.appendLine(); 688 for (int i = 0; i < params.getNumLimbs(); i++) { 689 result.appendLine("r[" + i + "] = c" + i + ";"); 690 } 691 result.decrIndent(); 692 result.appendLine("}"); 693 694 // small carry/reduce sequence 695 result.appendIndent(); 696 result.append("private void carryReduce(long[] r, "); 697 for (int i = 0; i < params.getNumLimbs(); i++) { 698 result.append("long c" + i); 699 if (i < params.getNumLimbs() - 1) { 700 result.append(", "); 701 } 702 } 703 result.append(") {\n"); 704 result.newTempScope(); 705 result.incrIndent(); 706 result.appendLine("long c" + params.getNumLimbs() + " = 0;"); 707 write(result, params.getSmallCrSequence(), params, 708 "c", params.getNumLimbs() + 1); 709 result.appendLine(); 710 for (int i = 0; i < params.getNumLimbs(); i++) { 711 result.appendLine("r[" + i + "] = c" + i + ";"); 712 } 713 result.decrIndent(); 714 result.appendLine("}"); 715 716 result.appendLine("@Override"); 717 result.appendLine("protected void mult(long[] a, long[] b, long[] r) {"); 718 result.incrIndent(); 719 for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) { 720 result.appendIndent(); 721 result.append("long c" + i + " = "); 722 int startJ = Math.max(i + 1 - params.getNumLimbs(), 0); 723 int endJ = Math.min(params.getNumLimbs(), i + 1); 724 for (int j = startJ; j < endJ; j++) { 725 int bIndex = i - j; 726 result.append("(a[" + j + "] * b[" + bIndex + "])"); 727 if (j < endJ - 1) { 728 result.append(" + "); 729 } 730 } 731 result.append(";\n"); 732 } 733 result.appendLine(); 734 result.appendIndent(); 735 result.append("carryReduce(r, "); 736 for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) { 737 result.append("c" + i); 738 if (i < 2 * params.getNumLimbs() - 2) { 739 result.append(", "); 740 } 741 } 742 result.append(");\n"); 743 result.decrIndent(); 744 result.appendLine("}"); 745 746 result.appendLine("@Override"); 747 result.appendLine("protected void reduce(long[] a) {"); 748 result.incrIndent(); 749 result.appendIndent(); 750 result.append("carryReduce(a, "); 751 for (int i = 0; i < params.getNumLimbs(); i++) { 752 result.append("a[" + i + "]"); 753 if (i < params.getNumLimbs() - 1) { 754 result.append(", "); 755 } 756 } 757 result.append(");\n"); 758 result.decrIndent(); 759 result.appendLine("}"); 760 761 result.appendLine("@Override"); 762 result.appendLine("protected void square(long[] a, long[] r) {"); 763 result.incrIndent(); 764 for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) { 765 result.appendIndent(); 766 result.append("long c" + i + " = "); 767 int startJ = Math.max(i + 1 - params.getNumLimbs(), 0); 768 int endJ = Math.min(params.getNumLimbs(), i + 1); 769 int jDiff = endJ - startJ; 770 if (jDiff > 1) { 771 result.append("2 * ("); 772 } 773 for (int j = 0; j < jDiff / 2; j++) { 774 int aIndex = j + startJ; 775 int bIndex = i - aIndex; 776 result.append("(a[" + aIndex + "] * a[" + bIndex + "])"); 777 if (j < (jDiff / 2) - 1) { 778 result.append(" + "); 779 } 780 } 781 if (jDiff > 1) { 782 result.append(")"); 783 } 784 if (jDiff % 2 == 1) { 785 int aIndex = i / 2; 786 if (jDiff > 1) { 787 result.append(" + "); 788 } 789 result.append("(a[" + aIndex + "] * a[" + aIndex + "])"); 790 } 791 result.append(";\n"); 792 } 793 result.appendLine(); 794 result.appendIndent(); 795 result.append("carryReduce(r, "); 796 for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) { 797 result.append("c" + i); 798 if (i < 2 * params.getNumLimbs() - 2) { 799 result.append(", "); 800 } 801 } 802 result.append(");\n"); 803 result.decrIndent(); 804 result.appendLine("}"); 805 806 result.decrIndent(); 807 result.appendLine("}"); // end class 808 809 return result.toString(); 810 } 811 812 private static void write(CodeBuffer out, List<CarryReduce> sequence, 813 FieldParams params, String prefix, int numLimbs) { 814 815 out.startCrSequence(numLimbs); 816 for (int i = 0; i < sequence.size(); i++) { 817 CarryReduce cr = sequence.get(i); 818 Iterator<CarryReduce> remainingIter = sequence.listIterator(i + 1); 819 List<CarryReduce> remaining = new ArrayList<CarryReduce>(); 820 remainingIter.forEachRemaining(remaining::add); 821 cr.write(out, params, prefix, remaining); 822 } 823 } 824 825 private static void reduce(CodeBuffer out, FieldParams params, 826 String prefix, int index, Iterable<CarryReduce> remaining) { 827 828 out.record(Reduce.class); 829 830 out.appendLine("//reduce from position " + index); 831 String reduceFrom = indexedExpr(false, prefix, index); 832 boolean referenced = false; 833 for (CarryReduce cr : remaining) { 834 if (cr.index == index) { 835 referenced = true; 836 } 837 } 838 for (Term t : params.getTerms()) { 839 int reduceBits = params.getPower() - t.getPower(); 840 int negatedCoefficient = -1 * t.getCoefficient(); 841 modReduceInBits(out, params, false, prefix, index, reduceBits, 842 negatedCoefficient, reduceFrom); 843 } 844 if (referenced) { 845 out.appendLine(reduceFrom + " = 0;"); 846 } 847 } 848 849 private static void carry(CodeBuffer out, FieldParams params, 850 String prefix, int index) { 851 852 out.record(Carry.class); 853 854 out.appendLine("//carry from position " + index); 855 String carryFrom = prefix + index; 856 String carryTo = prefix + (index + 1); 857 String carry = "(" + carryFrom + " + CARRY_ADD) >> " 858 + params.getBitsPerLimb(); 859 String temp = out.getTemporary("long", carry); 860 out.appendLine(carryFrom + " -= (" + temp + " << " 861 + params.getBitsPerLimb() + ");"); 862 out.appendLine(carryTo + " += " + temp + ";"); 863 out.freeTemporary(temp); 864 } 865 866 private static String indexedExpr( 867 boolean isArray, String prefix, int index) { 868 String result = prefix + index; 869 if (isArray) { 870 result = prefix + "[" + index + "]"; 871 } 872 return result; 873 } 874 875 private static void modReduceInBits(CodeBuffer result, FieldParams params, 876 boolean isArray, String prefix, int index, int reduceBits, 877 int coefficient, String c) { 878 879 String x = coefficient + " * " + c; 880 String accOp = "+="; 881 String temp = null; 882 if (coefficient == 1) { 883 x = c; 884 } else if (coefficient == -1) { 885 x = c; 886 accOp = "-="; 887 } else { 888 temp = result.getTemporary("long", x); 889 x = temp; 890 } 891 892 if (reduceBits % params.getBitsPerLimb() == 0) { 893 int pos = reduceBits / params.getBitsPerLimb(); 894 result.appendLine(indexedExpr(isArray, prefix, (index - pos)) 895 + " " + accOp + " " + x + ";"); 896 } else { 897 int secondPos = reduceBits / params.getBitsPerLimb(); 898 int bitOffset = (secondPos + 1) * params.getBitsPerLimb() 899 - reduceBits; 900 int rightBitOffset = params.getBitsPerLimb() - bitOffset; 901 result.appendLine(indexedExpr(isArray, prefix, 902 (index - (secondPos + 1))) + " " + accOp 903 + " (" + x + " << " + bitOffset + ") & LIMB_MASK;"); 904 result.appendLine(indexedExpr(isArray, prefix, 905 (index - secondPos)) + " " + accOp + " " + x 906 + " >> " + rightBitOffset + ";"); 907 } 908 909 if (temp != null) { 910 result.freeTemporary(temp); 911 } 912 } 913 914 private String readHeader() throws IOException { 915 BufferedReader reader 916 = Files.newBufferedReader(headerPath); 917 StringBuffer result = new StringBuffer(); 918 reader.lines().forEach(s -> result.append(s + "\n")); 919 return result.toString(); 920 } 921 }