/*
 * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have
 * questions.
*/
package jdk.incubator.vector;

import jdk.internal.vm.annotation.ForceInline;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Base class for vectors whose lanes hold {@code byte} values, parameterized by
 * the vector shape {@code S}.
 *
 * <p>Most lanewise operations are written here in terms of a small set of
 * abstract scalar-lambda primitives ({@code uOp}, {@code bOp}, {@code tOp},
 * {@code rOp}, {@code bTest}, {@code forEach}) that shape-specific subclasses
 * implement. Scalar-operand overloads are left abstract for those subclasses.
 * Lane arithmetic is evaluated in {@code int} (Java binary numeric promotion)
 * and narrowed back to {@code byte}, so the non-saturating operations wrap.
 */
@SuppressWarnings("cast")
public abstract class ByteVector<S extends Vector.Shape> extends Vector<Byte,S> {

    ByteVector() {}

    // Unary operator

    /** Per-lane unary function: receives the lane index and the lane value. */
    interface FUnOp {
        byte apply(int i, byte a);
    }

    /** Applies {@code f} to each lane; implemented per shape. */
    abstract ByteVector<S> uOp(FUnOp f);

    /** Masked form of {@link #uOp(FUnOp)}; handling of unset lanes is up to the implementation. */
    abstract ByteVector<S> uOp(Mask<Byte, S> m, FUnOp f);

    // Binary operator

    /** Per-lane binary function: receives the lane index and the two lane values. */
    interface FBinOp {
        byte apply(int i, byte a, byte b);
    }

    /** Combines each lane of this vector with the matching lane of {@code o} using {@code f}. */
    abstract ByteVector<S> bOp(Vector<Byte,S> o, FBinOp f);

    /** Masked form of {@link #bOp(Vector, FBinOp)}; handling of unset lanes is up to the implementation. */
    abstract ByteVector<S> bOp(Vector<Byte,S> o, Mask<Byte, S> m, FBinOp f);

    // Trinary operator

    /** Per-lane ternary function: receives the lane index and three lane values. */
    interface FTriOp {
        byte apply(int i, byte a, byte b, byte c);
    }

    /** Combines matching lanes of this vector, {@code o1} and {@code o2} using {@code f}. */
    abstract ByteVector<S> tOp(Vector<Byte,S> o1, Vector<Byte,S> o2, FTriOp f);

    /** Masked form of the ternary primitive; handling of unset lanes is up to the implementation. */
    abstract ByteVector<S> tOp(Vector<Byte,S> o1, Vector<Byte,S> o2, Mask<Byte, S> m, FTriOp f);

    // Reduction operator

    /**
     * Folds all lanes into a single value with {@code f}, seeded with {@code v}.
     * NOTE(review): the fold order and the index passed to {@code f} are defined
     * by the implementation and are not visible here.
     */
    abstract byte rOp(byte v, FBinOp f);

    // Binary test

    /** Per-lane predicate over the lane index and two lane values. */
    interface FBinTest {
        boolean apply(int i, byte a, byte b);
    }

    /** Evaluates {@code f} on matching lanes of this vector and {@code o}, producing a mask. */
    abstract Mask<Byte, S> bTest(Vector<Byte,S> o, FBinTest f);

    // Foreach

    /** Per-lane consumer: receives the lane index and the lane value. */
    interface FUnCon {
        void apply(int i, byte a);
    }

    /** Invokes {@code f} for every lane. */
    abstract void forEach(FUnCon f);

    /** Invokes {@code f} for the lanes selected by {@code m}. */
    abstract void forEach(Mask<Byte, S> m, FUnCon f);

    // Scalar helpers shared by the lanewise operations below

    /** Clamps {@code v} to the signed byte range {@code [Byte.MIN_VALUE, Byte.MAX_VALUE]}. */
    private static byte saturateToByte(int v) {
        return (byte) Math.max(Byte.MIN_VALUE, Math.min(Byte.MAX_VALUE, v));
    }

    /** Rotates the 8 bits of {@code a} left by {@code n} positions, taking {@code n} modulo 8. */
    private static byte rotateLeftByte(byte a, int n) {
        int bits = a & 0xFF; // zero-extend so sign bits cannot leak into the rotation
        int s = n & 7;
        // For s == 0, bits >>> 8 is 0, so the lane is returned unchanged.
        return (byte) ((bits << s) | (bits >>> (8 - s)));
    }

    // Lanewise arithmetic

    /** Adds {@code o} lanewise; overflow wraps (two's complement). */
    @Override
    public ByteVector<S> add(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a + b));
    }

    /** Scalar-operand form of {@link #add(Vector)}; supplied by shape-specific subclasses. */
    public abstract ByteVector<S> add(byte o);

    /** Masked form of {@link #add(Vector)}. */
    @Override
    public ByteVector<S> add(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a + b));
    }

    /** Masked scalar-operand form of {@link #add(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> add(byte o, Mask<Byte, S> m);

    /**
     * Adds {@code o} lanewise with signed saturation: each exact sum is clamped
     * to {@code [Byte.MIN_VALUE, Byte.MAX_VALUE]} instead of wrapping.
     */
    @Override
    public ByteVector<S> addSaturate(Vector<Byte,S> o) {
        // byte + byte cannot overflow int, so clamping the int sum is exact.
        return bOp(o, (i, a, b) -> saturateToByte(a + b));
    }

    /** Scalar-operand form of {@link #addSaturate(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> addSaturate(byte o);

    /** Masked form of {@link #addSaturate(Vector)}: clamps each exact sum to the byte range. */
    @Override
    public ByteVector<S> addSaturate(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> saturateToByte(a + b));
    }

    /** Masked scalar-operand form of {@link #addSaturate(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> addSaturate(byte o, Mask<Byte, S> m);

    /** Subtracts {@code o} lanewise; overflow wraps (two's complement). */
    @Override
    public ByteVector<S> sub(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a - b));
    }

    /** Scalar-operand form of {@link #sub(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> sub(byte o);

    /** Masked form of {@link #sub(Vector)}. */
    @Override
    public ByteVector<S> sub(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a - b));
    }

    /** Masked scalar-operand form of {@link #sub(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> sub(byte o, Mask<Byte, S> m);

    /**
     * Subtracts {@code o} lanewise with signed saturation: each exact difference
     * is clamped to {@code [Byte.MIN_VALUE, Byte.MAX_VALUE]} instead of wrapping.
     */
    @Override
    public ByteVector<S> subSaturate(Vector<Byte,S> o) {
        // byte - byte cannot overflow int, so clamping the int difference is exact.
        return bOp(o, (i, a, b) -> saturateToByte(a - b));
    }

    /** Scalar-operand form of {@link #subSaturate(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> subSaturate(byte o);

    /** Masked form of {@link #subSaturate(Vector)}: clamps each exact difference to the byte range. */
    @Override
    public ByteVector<S> subSaturate(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> saturateToByte(a - b));
    }

    /** Masked scalar-operand form of {@link #subSaturate(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> subSaturate(byte o, Mask<Byte, S> m);

    /** Multiplies by {@code o} lanewise, keeping the low 8 bits of each product. */
    @Override
    public ByteVector<S> mul(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a * b));
    }

    /** Scalar-operand form of {@link #mul(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> mul(byte o);

    /** Masked form of {@link #mul(Vector)}. */
    @Override
    public ByteVector<S> mul(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a * b));
    }

    /** Masked scalar-operand form of {@link #mul(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> mul(byte o, Mask<Byte, S> m);

    /** Negates each lane; {@code Byte.MIN_VALUE} wraps to itself. */
    @Override
    public ByteVector<S> neg() {
        return uOp((i, a) -> (byte) (-a));
    }

    /** Masked form of {@link #neg()}. */
    @Override
    public ByteVector<S> neg(Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) (-a));
    }

    /** Absolute value of each lane; {@code Byte.MIN_VALUE} wraps to itself. */
    @Override
    public ByteVector<S> abs() {
        return uOp((i, a) -> (byte) Math.abs(a));
    }

    /** Masked form of {@link #abs()}. */
    @Override
    public ByteVector<S> abs(Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) Math.abs(a));
    }

    /** Lanewise signed minimum of this vector and {@code o}. */
    @Override
    public ByteVector<S> min(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (a <= b) ? a : b);
    }

    /** Scalar-operand form of {@link #min(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> min(byte o);

    /** Lanewise signed maximum of this vector and {@code o}. */
    @Override
    public ByteVector<S> max(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (a >= b) ? a : b);
    }

    /** Scalar-operand form of {@link #max(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> max(byte o);

    // Lanewise comparisons (signed), each evaluated through bTest

    /** Tests {@code a == b} on matching lanes. */
    @Override
    public Mask<Byte, S> equal(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a == b);
    }

    /** Scalar-operand form of {@link #equal(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> equal(byte o);

    /** Tests {@code a != b} on matching lanes. */
    @Override
    public Mask<Byte, S> notEqual(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a != b);
    }

    /** Scalar-operand form of {@link #notEqual(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> notEqual(byte o);

    /** Tests {@code a < b} on matching lanes. */
    @Override
    public Mask<Byte, S> lessThan(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a < b);
    }

    /** Scalar-operand form of {@link #lessThan(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> lessThan(byte o);

    /** Tests {@code a <= b} on matching lanes. */
    @Override
    public Mask<Byte, S> lessThanEq(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a <= b);
    }

    /** Scalar-operand form of {@link #lessThanEq(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> lessThanEq(byte o);

    /** Tests {@code a > b} on matching lanes. */
    @Override
    public Mask<Byte, S> greaterThan(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a > b);
    }

    /** Scalar-operand form of {@link #greaterThan(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> greaterThan(byte o);

    /** Tests {@code a >= b} on matching lanes. */
    @Override
    public Mask<Byte, S> greaterThanEq(Vector<Byte,S> o) {
        return bTest(o, (i, a, b) -> a >= b);
    }

    /** Scalar-operand form of {@link #greaterThanEq(Vector)}; supplied by subclasses. */
    public abstract Mask<Byte, S> greaterThanEq(byte o);

    /** Takes the lane from {@code o} where {@code m} is set and from this vector elsewhere. */
    @Override
    public ByteVector<S> blend(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, (i, a, b) -> m.getElement(i) ? b : a);
    }

    /** Scalar-operand form of {@link #blend(Vector, Mask)}; supplied by subclasses. */
    public abstract ByteVector<S> blend(byte o, Mask<Byte, S> m);

    // Element rearrangement and shape conversion

    /** Shuffle across this vector and {@code o}; implemented per shape. */
    @Override
    public abstract ByteVector<S> shuffle(Vector<Byte,S> o, Shuffle<Byte, S> m);

    /** Rearranges the lanes of this vector according to {@code m}; implemented per shape. */
    @Override
    public abstract ByteVector<S> swizzle(Shuffle<Byte, S> m);

    /** Converts this vector to {@code species} by delegating to {@code species.resize(this)}. */
    @Override
    @ForceInline
    public <T extends Shape> ByteVector<T> resize(Species<Byte, T> species) {
        return (ByteVector<T>) species.resize(this);
    }

    /** Element-wise rotation by {@code i} lanes (the "EL" direction); implemented per shape. */
    @Override
    public abstract ByteVector<S> rotateEL(int i);

    /** Element-wise rotation by {@code i} lanes (the "ER" direction); implemented per shape. */
    @Override
    public abstract ByteVector<S> rotateER(int i);

    /** Element-wise shift by {@code i} lanes (the "EL" direction); implemented per shape. */
    @Override
    public abstract ByteVector<S> shiftEL(int i);

    /** Element-wise shift by {@code i} lanes (the "ER" direction); implemented per shape. */
    @Override
    public abstract ByteVector<S> shiftER(int i);

    // Lanewise bitwise operations

    /** Bitwise AND of matching lanes. */
    public ByteVector<S> and(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a & b));
    }

    /** Scalar-operand form of {@link #and(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> and(byte o);

    /** Masked form of {@link #and(Vector)}. */
    public ByteVector<S> and(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a & b));
    }

    /** Masked scalar-operand form of {@link #and(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> and(byte o, Mask<Byte, S> m);

    /** Bitwise OR of matching lanes. */
    public ByteVector<S> or(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a | b));
    }

    /** Scalar-operand form of {@link #or(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> or(byte o);

    /** Masked form of {@link #or(Vector)}. */
    public ByteVector<S> or(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a | b));
    }

    /** Masked scalar-operand form of {@link #or(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> or(byte o, Mask<Byte, S> m);

    /** Bitwise XOR of matching lanes. */
    public ByteVector<S> xor(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a ^ b));
    }

    /** Scalar-operand form of {@link #xor(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> xor(byte o);

    /** Masked form of {@link #xor(Vector)}. */
    public ByteVector<S> xor(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a ^ b));
    }

    /** Masked scalar-operand form of {@link #xor(Vector)}; supplied by subclasses. */
    public abstract ByteVector<S> xor(byte o, Mask<Byte, S> m);

    /** Bitwise complement of each lane. */
    public ByteVector<S> not() {
        return uOp((i, a) -> (byte) (~a));
    }

    /** Masked form of {@link #not()}. */
    public ByteVector<S> not(Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) (~a));
    }

    // Lanewise shifts. Shift counts follow Java int-shift semantics on the
    // promoted lane value (count taken modulo 32); results keep the low 8 bits.

    /** Logical left shift of each lane by the matching lane of {@code o}. */
    public ByteVector<S> shiftL(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a << b));
    }

    /** Logical left shift of each lane by {@code s}. */
    public ByteVector<S> shiftL(int s) {
        return uOp((i, a) -> (byte) (a << s));
    }

    /** Masked form of {@link #shiftL(Vector)}. */
    public ByteVector<S> shiftL(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a << b));
    }

    /** Masked form of {@link #shiftL(int)}. */
    public ByteVector<S> shiftL(int s, Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) (a << s));
    }

    /**
     * Logical (unsigned) right shift of each lane by the matching lane of {@code o}.
     * The lane is zero-extended before shifting so vacated high bits are zeros;
     * applying {@code >>>} to the sign-extended byte would behave like an
     * arithmetic shift for negative lanes.
     */
    public ByteVector<S> shiftR(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) ((a & 0xFF) >>> b));
    }

    /** Logical (unsigned) right shift of each zero-extended lane by {@code s}. */
    public ByteVector<S> shiftR(int s) {
        return uOp((i, a) -> (byte) ((a & 0xFF) >>> s));
    }

    /** Masked form of {@link #shiftR(Vector)}. */
    public ByteVector<S> shiftR(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) ((a & 0xFF) >>> b));
    }

    /** Masked form of {@link #shiftR(int)}. */
    public ByteVector<S> shiftR(int s, Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) ((a & 0xFF) >>> s));
    }

    // NOTE: the vector-count overloads are spelled "ashiftR" and the scalar-count
    // overloads "aShiftR"; both spellings are public API and kept as-is.

    /** Arithmetic (signed) right shift of each lane by the matching lane of {@code o}. */
    public ByteVector<S> ashiftR(Vector<Byte,S> o) {
        return bOp(o, (i, a, b) -> (byte) (a >> b));
    }

    /** Arithmetic (signed) right shift of each lane by {@code s}. */
    public ByteVector<S> aShiftR(int s) {
        return uOp((i, a) -> (byte) (a >> s));
    }

    /** Masked form of {@link #ashiftR(Vector)}. */
    public ByteVector<S> ashiftR(Vector<Byte,S> o, Mask<Byte, S> m) {
        return bOp(o, m, (i, a, b) -> (byte) (a >> b));
    }

    /** Masked form of {@link #aShiftR(int)}. */
    public ByteVector<S> aShiftR(int s, Mask<Byte, S> m) {
        return uOp(m, (i, a) -> (byte) (a >> s));
    }

    /** Rotates the 8 bits of each lane left by {@code j} (modulo 8). */
    public ByteVector<S> rotateL(int j) {
        return uOp((i, a) -> rotateLeftByte(a, j));
    }

    /** Rotates the 8 bits of each lane right by {@code j} (modulo 8). */
    public ByteVector<S> rotateR(int j) {
        // Rotating right by j is rotating left by -j; (-j) & 7 is the matching left count.
        return uOp((i, a) -> rotateLeftByte(a, -j));
    }

    // Byte-array / byte-buffer stores

    /** Stores the lanes into {@code a} starting at offset {@code ix}. */
    @Override
    public void intoByteArray(byte[] a, int ix) {
        ByteBuffer bb = ByteBuffer.wrap(a, ix, a.length - ix);
        intoByteBuffer(bb);
    }

    /** Stores the lanes selected by {@code m} into {@code a}, lane i going to {@code ix + i}. */
    @Override
    public void intoByteArray(byte[] a, int ix, Mask<Byte, S> m) {
        ByteBuffer bb = ByteBuffer.wrap(a, ix, a.length - ix);
        intoByteBuffer(bb, m);
    }

    /** Writes every lane with relative puts, advancing {@code bb}'s position by one per lane. */
    @Override
    public void intoByteBuffer(ByteBuffer bb) {
        forEach((i, a) -> bb.put(a));
    }

    /**
     * Writes the lanes selected by {@code m} with relative puts. Each unset lane
     * still advances the position by one byte, so every lane keeps its slot.
     */
    @Override
    public void intoByteBuffer(ByteBuffer bb, Mask<Byte, S> m) {
        forEach((i, a) -> {
            if (m.getElement(i)) {
                bb.put(a);
            } else {
                bb.position(bb.position() + 1);
            }
        });
    }

    /** Writes lane i at absolute index {@code ix + i}; {@code bb}'s own position is not changed. */
    @Override
    public void intoByteBuffer(ByteBuffer bb, int ix) {
        // position(ix) on a duplicate validates ix without disturbing bb. Absolute
        // puts ignore the position, so the offset has to be added explicitly.
        ByteBuffer fb = bb.duplicate().position(ix);
        forEach((i, a) -> fb.put(ix + i, a));
    }

    /** Masked form of {@link #intoByteBuffer(ByteBuffer, int)}: writes selected lanes at {@code ix + i}. */
    @Override
    public void intoByteBuffer(ByteBuffer bb, int ix, Mask<Byte, S> m) {
        ByteBuffer fb = bb.duplicate().position(ix);
        forEach(m, (i, a) -> fb.put(ix + i, a));
    }

    // Type specific horizontal reductions (wrapping arithmetic, via rOp)

    /** Sum of all lanes, wrapping on overflow. */
    public byte addAll() {
        return rOp((byte) 0, (i, a, b) -> (byte) (a + b));
    }

    /** Subtraction-fold of all lanes seeded with 0; the result depends on rOp's fold order. */
    public byte subAll() {
        return rOp((byte) 0, (i, a, b) -> (byte) (a - b));
    }

    /** Product of all lanes, keeping the low 8 bits. */
    public byte mulAll() {
        return rOp((byte) 1, (i, a, b) -> (byte) (a * b));
    }

    /** Signed minimum over all lanes. */
    public byte minAll() {
        return rOp(Byte.MAX_VALUE, (i, a, b) -> a > b ? b : a);
    }

    /** Signed maximum over all lanes. */
    public byte maxAll() {
        return rOp(Byte.MIN_VALUE, (i, a, b) -> a < b ? b : a);
    }

    /** Bitwise OR over all lanes. */
    public byte orAll() {
        return rOp((byte) 0, (i, a, b) -> (byte) (a | b));
    }

    /** Bitwise AND over all lanes. */
    public byte andAll() {
        return rOp((byte) -1, (i, a, b) -> (byte) (a & b));
    }

    /** Bitwise XOR over all lanes. */
    public byte xorAll() {
        return rOp((byte) 0, (i, a, b) -> (byte) (a ^ b));
    }

    // Type specific accessors

    /** Returns lane {@code i}; implemented per shape. */
    public abstract byte get(int i);

    /** Returns a vector equal to this one but with lane {@code i} replaced by {@code e}; implemented per shape. */
    public abstract ByteVector<S> with(int i, byte e);

    // Type specific extractors

    /** Copies all lanes into a fresh array of length {@code species().length()}. */
    @ForceInline
    public byte[] toArray() {
        byte[] a = new byte[species().length()];
        intoArray(a, 0);
        return a;
    }

    /** Stores lane i into {@code a[ax + i]}. */
    public void intoArray(byte[] a, int ax) {
        forEach((i, a_) -> a[ax + i] = a_);
    }

    /** Stores each lane i selected by {@code m} into {@code a[ax + i]}. */
    public void intoArray(byte[] a, int ax, Mask<Byte, S> m) {
        forEach(m, (i, a_) -> a[ax + i] = a_);
    }

    /** Scatter store: lane i goes to {@code a[ax + indexMap[mx + i]]}. */
    public void intoArray(byte[] a, int ax, int[] indexMap, int mx) {
        forEach((i, a_) -> a[ax + indexMap[mx + i]] = a_);
    }

    /** Masked scatter store: each lane i selected by {@code m} goes to {@code a[ax + indexMap[mx + i]]}. */
    public void intoArray(byte[] a, int ax, Mask<Byte, S> m, int[] indexMap, int mx) {
        forEach(m, (i, a_) -> a[ax + indexMap[mx + i]] = a_);
    }

    // Species

    /** Returns this vector's byte species; implemented per shape. */
    @Override
    public abstract ByteSpecies<S> species();

    /**
     * Species (factory) for {@code byte} vectors of shape {@code S}. Factories are
     * expressed through the abstract lane-generator primitives {@code op}.
     */
    public static abstract class ByteSpecies<S extends Vector.Shape> extends Vector.Species<Byte, S> {
        /** Per-lane element generator: receives the lane index. */
        interface FOp {
            byte apply(int i);
        }

        /** Builds a vector whose lane i is {@code f.apply(i)}; implemented per shape. */
        abstract ByteVector<S> op(FOp f);

        /** Masked form of {@link #op(FOp)}; handling of unset lanes is up to the implementation. */
        abstract ByteVector<S> op(Mask<Byte, S> m, FOp f);

        // Factories

        /** A vector with every lane 0. */
        @Override
        public ByteVector<S> zero() {
            return op(i -> 0);
        }

        /** A vector with every lane equal to {@code e}. */
        public ByteVector<S> broadcast(byte e) {
            return op(i -> e);
        }

        /** A vector with lane 0 equal to {@code e} and all other lanes 0. */
        public ByteVector<S> single(byte e) {
            return op(i -> i == 0 ? e : (byte) 0);
        }

        /** A vector of pseudo-random lanes drawn from the calling thread's {@link ThreadLocalRandom}. */
        public ByteVector<S> random() {
            ThreadLocalRandom r = ThreadLocalRandom.current();
            return op(i -> (byte) r.nextInt());
        }

        /** A vector whose lane i is {@code es[i]}. */
        public ByteVector<S> scalars(byte... es) {
            return op(i -> es[i]);
        }

        /** Loads lane i from {@code a[ax + i]}. */
        public ByteVector<S> fromArray(byte[] a, int ax) {
            return op(i -> a[ax + i]);
        }

        /** Masked load: lane i from {@code a[ax + i]} where {@code m} is set. */
        public ByteVector<S> fromArray(byte[] a, int ax, Mask<Byte, S> m) {
            return op(m, i -> a[ax + i]);
        }

        /** Gather load: lane i from {@code a[ax + indexMap[mx + i]]}. */
        public ByteVector<S> fromArray(byte[] a, int ax, int[] indexMap, int mx) {
            return op(i -> a[ax + indexMap[mx + i]]);
        }

        /** Masked gather load: lane i from {@code a[ax + indexMap[mx + i]]} where {@code m} is set. */
        public ByteVector<S> fromArray(byte[] a, int ax, Mask<Byte, S> m, int[] indexMap, int mx) {
            return op(m, i -> a[ax + indexMap[mx + i]]);
        }

        /** Loads lanes from {@code a} starting at offset {@code ix}. */
        @Override
        public ByteVector<S> fromByteArray(byte[] a, int ix) {
            ByteBuffer bb = ByteBuffer.wrap(a, ix, a.length - ix);
            return fromByteBuffer(bb);
        }

        /** Masked load from {@code a} starting at offset {@code ix}. */
        @Override
        public ByteVector<S> fromByteArray(byte[] a, int ix, Mask<Byte, S> m) {
            ByteBuffer bb = ByteBuffer.wrap(a, ix, a.length - ix);
            return fromByteBuffer(bb, m);
        }

        /** Reads one byte per lane with relative gets, advancing {@code bb}'s position. */
        @Override
        public ByteVector<S> fromByteBuffer(ByteBuffer bb) {
            return op(i -> bb.get());
        }

        /**
         * Masked relative read. Lanes selected by {@code m} read the next byte.
         * Each unset lane yields 0 and still skips one byte, so every lane keeps its slot.
         */
        @Override
        public ByteVector<S> fromByteBuffer(ByteBuffer bb, Mask<Byte, S> m) {
            return op(i -> {
                if (m.getElement(i)) {
                    return bb.get();
                }
                bb.position(bb.position() + 1);
                return (byte) 0;
            });
        }

        /** Reads lane i from absolute index {@code ix + i}; {@code bb}'s own position is not changed. */
        @Override
        public ByteVector<S> fromByteBuffer(ByteBuffer bb, int ix) {
            // position(ix) on a duplicate validates ix without disturbing bb. Absolute
            // gets ignore the position, so the offset has to be added explicitly.
            ByteBuffer fb = bb.duplicate().position(ix);
            return op(i -> fb.get(ix + i));
        }

        /** Masked form of {@link #fromByteBuffer(ByteBuffer, int)}: reads selected lanes from {@code ix + i}. */
        @Override
        public ByteVector<S> fromByteBuffer(ByteBuffer bb, int ix, Mask<Byte, S> m) {
            ByteBuffer fb = bb.duplicate().position(ix);
            return op(m, i -> fb.get(ix + i));
        }

        /**
         * Reinterprets the bits of {@code o} as byte lanes of this species. The bits
         * are round-tripped through a native-order heap buffer large enough for
         * whichever of the two bit sizes is larger.
         */
        @Override
        @ForceInline
        public <F, T extends Shape> ByteVector<S> reshape(Vector<F, T> o) {
            int blen = Math.max(o.species().bitSize(), bitSize()) / Byte.SIZE;
            ByteBuffer bb = ByteBuffer.allocate(blen).order(ByteOrder.nativeOrder());
            o.intoByteBuffer(bb, 0);
            return fromByteBuffer(bb, 0);
        }

        /** Same-shape bit reinterpretation; delegates to {@link #reshape(Vector)}. */
        @Override
        @ForceInline
        public <F> ByteVector<S> rebracket(Vector<F, S> o) {
            return reshape(o);
        }

        /** Byte-vector shape change; delegates to {@link #reshape(Vector)}. */
        @Override
        @ForceInline
        public <T extends Shape> ByteVector<S> resize(Vector<Byte, T> o) {
            return reshape(o);
        }

        /**
         * Value-converts the lanes of {@code v} to {@code byte} using Java primitive
         * narrowing casts. The first {@code min(v.length, length())} lanes are
         * converted and any remaining lanes are 0.
         *
         * @throws UnsupportedOperationException if {@code v}'s element type is not
         *         Byte, Short, Integer, Long, Float or Double
         */
        @Override
        @SuppressWarnings("unchecked")
        public <F, T extends Shape> ByteVector<S> cast(Vector<F, T> v) {
            // Allocate array of required size; lanes past the limit stay 0.
            byte[] a = new byte[length()];

            Class<?> vtype = v.species().elementType();
            int limit = Math.min(v.species().length(), length());
            if (vtype == Byte.class) {
                ByteVector<T> tv = (ByteVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else if (vtype == Short.class) {
                ShortVector<T> tv = (ShortVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else if (vtype == Integer.class) {
                IntVector<T> tv = (IntVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else if (vtype == Long.class) {
                LongVector<T> tv = (LongVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else if (vtype == Float.class) {
                FloatVector<T> tv = (FloatVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else if (vtype == Double.class) {
                DoubleVector<T> tv = (DoubleVector<T>)v;
                for (int i = 0; i < limit; i++) {
                    a[i] = (byte) tv.get(i);
                }
            } else {
                throw new UnsupportedOperationException("Bad lane type for casting.");
            }

            return scalars(a);
        }

    }
}