1 /*
   2  * Copyright (c) 2009, 2010, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /* @test
  25  * @bug 4927640
  26  * @summary Tests the SCTP protocol implementation
  27  * @author chegar
  28  */
  29 
  30 import java.net.InetSocketAddress;
  31 import java.net.SocketAddress;
  32 import java.io.IOException;
  33 import java.util.Set;
  34 import java.util.Iterator;
  35 import java.util.concurrent.CountDownLatch;
  36 import java.util.concurrent.TimeUnit;
  37 import java.nio.ByteBuffer;
  38 import com.sun.nio.sctp.Association;
  39 import com.sun.nio.sctp.InvalidStreamException;
  40 import com.sun.nio.sctp.MessageInfo;
  41 import com.sun.nio.sctp.SctpMultiChannel;
  42 import static java.lang.System.out;
  43 import static java.lang.System.err;
  44 
  45 public class Send {
  46     /* Latches used to synchronize between the client and server so that
  47      * connections without any IO may not be closed without being accepted */
  48     final CountDownLatch clientFinishedLatch = new CountDownLatch(1);
  49     final CountDownLatch serverFinishedLatch = new CountDownLatch(1);
  50 
  51     void test(String[] args) {
  52         SocketAddress address = null;
  53         Server server = null;
  54 
  55         if (!Util.isSCTPSupported()) {
  56             out.println("SCTP protocol is not supported");
  57             out.println("Test cannot be run");
  58             return;
  59         }
  60 
  61         if (args.length == 2) {
  62             /* requested to connecct to a specific address */
  63             try {
  64                 int port = Integer.valueOf(args[1]);
  65                 address = new InetSocketAddress(args[0], port);
  66             } catch (NumberFormatException nfe) {
  67                 err.println(nfe);
  68             }
  69         } else {
  70             /* start server on local machine, default */
  71             try {
  72                 server = new Server();
  73                 server.start();
  74                 address = server.address();
  75                 debug("Server started and listening on " + address);
  76             } catch (IOException ioe) {
  77                 ioe.printStackTrace();
  78                 return;
  79             }
  80         }
  81 
  82         doTest(address);
  83     }
  84 
  85     void doTest(SocketAddress peerAddress) {
  86         SctpMultiChannel channel = null;
  87         ByteBuffer buffer = ByteBuffer.allocate(Util.LARGE_BUFFER);
  88         MessageInfo info = MessageInfo.createOutgoing(null, 0);
  89 
  90         try {
  91             channel = SctpMultiChannel.open();
  92 
  93             /* TEST 1: send small message */
  94             int streamNumber = 0;
  95             debug("sending to " + peerAddress + " on stream number: " + streamNumber);
  96             info = MessageInfo.createOutgoing(peerAddress, streamNumber);
  97             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
  98             buffer.flip();
  99             int position = buffer.position();
 100             int remaining = buffer.remaining();
 101 
 102             debug("sending small message: " + buffer);
 103             int sent = channel.send(buffer, info);
 104 
 105             check(sent == remaining, "sent should be equal to remaining");
 106             check(buffer.position() == (position + sent),
 107                     "buffers position should have been incremented by sent");
 108 
 109             /* TEST 2: receive the echoed message */
 110             buffer.clear();
 111             info = channel.receive(buffer, null, null);
 112             buffer.flip();
 113             check(info != null, "info is null");
 114             check(info.streamNumber() == streamNumber,
 115                     "message not sent on the correct stream");
 116             check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
 117                   length, "bytes received not equal to message length");
 118             check(info.bytes() == buffer.remaining(), "bytes != remaining");
 119             check(Util.compare(buffer, Util.SMALL_MESSAGE),
 120               "received message not the same as sent message");
 121 
 122 
 123             /* TEST 3: send large message */
 124             Set<Association> assocs = channel.associations();
 125             check(assocs.size() == 1, "there should be only one association");
 126             Iterator<Association> it = assocs.iterator();
 127             check(it.hasNext());
 128             Association assoc = it.next();
 129             streamNumber = assoc.maxOutboundStreams() - 1;
 130 
 131             debug("sending on stream number: " + streamNumber);
 132             info = MessageInfo.createOutgoing(assoc, null, streamNumber);
 133             buffer.clear();
 134             buffer.put(Util.LARGE_MESSAGE.getBytes("ISO-8859-1"));
 135             buffer.flip();
 136             position = buffer.position();
 137             remaining = buffer.remaining();
 138 
 139             debug("sending large message: " + buffer);
 140             sent = channel.send(buffer, info);
 141 
 142             check(sent == remaining, "sent should be equal to remaining");
 143             check(buffer.position() == (position + sent),
 144                     "buffers position should have been incremented by sent");
 145 
 146             /* TEST 4: receive the echoed message */
 147             buffer.clear();
 148             info = channel.receive(buffer, null, null);
 149             buffer.flip();
 150             check(info != null, "info is null");
 151             check(info.streamNumber() == streamNumber,
 152                     "message not sent on the correct stream");
 153             check(info.bytes() == Util.LARGE_MESSAGE.getBytes("ISO-8859-1").
 154                   length, "bytes received not equal to message length");
 155             check(info.bytes() == buffer.remaining(), "bytes != remaining");
 156             check(Util.compare(buffer, Util.LARGE_MESSAGE),
 157               "received message not the same as sent message");
 158 
 159 
 160             /* TEST 5: InvalidStreamExcepton */
 161             streamNumber = assoc.maxOutboundStreams() + 1;
 162             info = MessageInfo.createOutgoing(assoc, null, streamNumber);
 163             buffer.clear();
 164             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
 165             buffer.flip();
 166             position = buffer.position();
 167             remaining = buffer.remaining();
 168 
 169             debug("sending on stream number: " + streamNumber);
 170             debug("sending small message: " + buffer);
 171             try {
 172                 sent = channel.send(buffer, info);
 173                 fail("should have thrown InvalidStreamExcepton");
 174             } catch (InvalidStreamException ise){
 175                 pass();
 176             } catch (IOException ioe) {
 177                 unexpected(ioe);
 178             }
 179             check(buffer.remaining() == remaining,
 180                     "remaining should not be changed");
 181             check(buffer.position() == position,
 182                     "buffers position should not be changed");
 183 
 184 
 185             /* TEST 5: getRemoteAddresses(Association) */
 186             channel.getRemoteAddresses(assoc);
 187 
 188             /* TEST 6: Send from heap buffer to force implementation to
 189              * substitute with a native buffer, then check that its position
 190              * is updated correctly */
 191             info = MessageInfo.createOutgoing(assoc, null, 0);
 192             buffer.clear();
 193             buffer.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
 194             buffer.flip();
 195             final int offset = 1;
 196             buffer.position(offset);
 197             remaining = buffer.remaining();
 198 
 199             try {
 200                 sent = channel.send(buffer, info);
 201 
 202                 check(sent == remaining, "sent should be equal to remaining");
 203                 check(buffer.position() == (offset + sent),
 204                         "buffers position should have been incremented by sent");
 205             } catch (IllegalArgumentException iae) {
 206                 fail(iae + ", Error updating buffers position");
 207             }
 208 
 209         } catch (IOException ioe) {
 210             unexpected(ioe);
 211         } finally {
 212             clientFinishedLatch.countDown();
 213             try { serverFinishedLatch.await(10L, TimeUnit.SECONDS); }
 214             catch (InterruptedException ie) { unexpected(ie); }
 215             if (channel != null) {
 216                 try { channel.close(); }
 217                 catch (IOException e) { unexpected (e);}
 218             }
 219         }
 220     }
 221 
 222     class Server implements Runnable
 223     {
 224         final InetSocketAddress serverAddr;
 225         private SctpMultiChannel serverChannel;
 226 
 227         public Server() throws IOException {
 228             serverChannel = SctpMultiChannel.open().bind(null);
 229             java.util.Set<SocketAddress> addrs = serverChannel.getAllLocalAddresses();
 230             if (addrs.isEmpty())
 231                 debug("addrs should not be empty");
 232 
 233             serverAddr = (InetSocketAddress) addrs.iterator().next();
 234         }
 235 
 236         public void start() {
 237             (new Thread(this, "Server-"  + serverAddr.getPort())).start();
 238         }
 239 
 240         public InetSocketAddress address() {
 241             return serverAddr;
 242         }
 243 
 244         @Override
 245         public void run() {
 246             ByteBuffer buffer = ByteBuffer.allocateDirect(Util.LARGE_BUFFER);
 247             try {
 248                 MessageInfo info;
 249 
 250                 /* receive a small message */
 251                 do {
 252                     info = serverChannel.receive(buffer, null, null);
 253                     if (info == null) {
 254                         fail("Server: unexpected null from receive");
 255                             return;
 256                     }
 257                 } while (!info.isComplete());
 258 
 259                 buffer.flip();
 260                 check(info != null, "info is null");
 261                 check(info.streamNumber() == 0,
 262                         "message not sent on the correct stream");
 263                 check(info.bytes() == Util.SMALL_MESSAGE.getBytes("ISO-8859-1").
 264                       length, "bytes received not equal to message length");
 265                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
 266                 check(Util.compare(buffer, Util.SMALL_MESSAGE),
 267                   "received message not the same as sent message");
 268 
 269                 check(info != null, "info is null");
 270                 Set<Association> assocs = serverChannel.associations();
 271                 check(assocs.size() == 1, "there should be only one association");
 272                 Iterator<Association> it = assocs.iterator();
 273                 check(it.hasNext());
 274                 Association assoc = it.next();
 275 
 276                 /* echo the message */
 277                 debug("Server: echoing first message");
 278                 buffer.flip();
 279                 int bytes = serverChannel.send(buffer, info);
 280                 debug("Server: sent " + bytes + "bytes");
 281 
 282                 /* receive a large message */
 283                 buffer.clear();
 284                 do {
 285                     info = serverChannel.receive(buffer, null, null);
 286                     if (info == null) {
 287                         fail("Server: unexpected null from receive");
 288                             return;
 289                     }
 290                 } while (!info.isComplete());
 291 
 292                 buffer.flip();
 293 
 294                 check(info.streamNumber() == assoc.maxInboundStreams() - 1,
 295                         "message not sent on the correct stream");
 296                 check(info.bytes() == Util.LARGE_MESSAGE.getBytes("ISO-8859-1").
 297                       length, "bytes received not equal to message length");
 298                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
 299                 check(Util.compare(buffer, Util.LARGE_MESSAGE),
 300                   "received message not the same as sent message");
 301 
 302                 /* echo the message */
 303                 debug("Server: echoing second message");
 304                 buffer.flip();
 305                 bytes = serverChannel.send(buffer, info);
 306                 debug("Server: sent " + bytes + "bytes");
 307 
 308                 /* TEST 6 */
 309                 ByteBuffer expected = ByteBuffer.allocate(Util.SMALL_BUFFER);
 310                 expected.put(Util.SMALL_MESSAGE.getBytes("ISO-8859-1"));
 311                 expected.flip();
 312                 final int offset = 1;
 313                 expected.position(offset);
 314                 buffer.clear();
 315                 do {
 316                     info = serverChannel.receive(buffer, null, null);
 317                     if (info == null) {
 318                         fail("Server: unexpected null from receive");
 319                         return;
 320                     }
 321                 } while (!info.isComplete());
 322 
 323                 buffer.flip();
 324                 check(info != null, "info is null");
 325                 check(info.streamNumber() == 0, "message not sent on the correct stream");
 326                 check(info.bytes() == expected.remaining(),
 327                     "bytes received not equal to message length");
 328                 check(info.bytes() == buffer.remaining(), "bytes != remaining");
 329                 check(expected.equals(buffer),
 330                     "received message not the same as sent message");
 331 
 332                 clientFinishedLatch.await(10L, TimeUnit.SECONDS);
 333                 serverFinishedLatch.countDown();
 334             } catch (IOException ioe) {
 335                 unexpected(ioe);
 336             } catch (InterruptedException ie) {
 337                 unexpected(ie);
 338             } finally {
 339                 try { if (serverChannel != null) serverChannel.close(); }
 340                 catch (IOException  unused) {}
 341             }
 342         }
 343     }
 344 
 345         //--------------------- Infrastructure ---------------------------
 346     boolean debug = true;
 347     volatile int passed = 0, failed = 0;
 348     void pass() {passed++;}
 349     void fail() {failed++; Thread.dumpStack();}
 350     void fail(String msg) {System.err.println(msg); fail();}
 351     void unexpected(Throwable t) {failed++; t.printStackTrace();}
 352     void check(boolean cond) {if (cond) pass(); else fail();}
 353     void check(boolean cond, String failMessage) {if (cond) pass(); else fail(failMessage);}
 354     void debug(String message) {if(debug) { System.out.println(message); }  }
 355     public static void main(String[] args) throws Throwable {
 356         Class<?> k = new Object(){}.getClass().getEnclosingClass();
 357         try {k.getMethod("instanceMain",String[].class)
 358                 .invoke( k.newInstance(), (Object) args);}
 359         catch (Throwable e) {throw e.getCause();}}
 360     public void instanceMain(String[] args) throws Throwable {
 361         try {test(args);} catch (Throwable t) {unexpected(t);}
 362         System.out.printf("%nPassed = %d, failed = %d%n%n", passed, failed);
 363         if (failed > 0) throw new AssertionError("Some tests failed");}
 364 
 365 }