1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
/* @test
 * @bug 8195160
 * @summary Test RDMA selectors with RDMA server socket and socket channels
 * @requires (os.family == "linux")
 * @library .. /test/lib /test/jdk/java/nio/channels
 * @build RsocketTest
 * @run main/othervm SelectorTest
 */
  32 
import java.io.IOException;
import java.io.PrintStream;
import java.net.ConnectException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.StandardProtocolFamily;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import jdk.net.RdmaSockets;

import jtreg.SkippedException;
  48 
  49 public class SelectorTest {
  50     private static List clientList = new LinkedList();
  51     private static Random rnd = new Random();
  52     public static int NUM_CLIENTS = 5;
  53     public static int TEST_PORT = 31452;
  54     static PrintStream log = System.err;
  55     private static int FINISH_TIME = 30000;
  56 
  57     /*
  58      * Usage note
  59      *
  60      * java SelectorTest [server] [client <host>] [<port>]
  61      *
  62      * No arguments runs both client and server in separate threads
  63      * using the default port of 31452.
  64      *
  65      * client runs the client on this machine and connects to server
  66      * at the given IP address.
  67      *
  68      * server runs the server on localhost.
  69      */
  70     public static void main(String[] args) throws Exception {
  71         if (!RsocketTest.isRsocketAvailable())
  72             throw new SkippedException("rsocket is not available");
  73 
  74         if (args.length == 0) {
  75             Server server = new Server(0);
  76             server.start();
  77             try {
  78                 Thread.sleep(1000);
  79             } catch (InterruptedException e) { }
  80             InetSocketAddress isa
  81                 = new InetSocketAddress(InetAddress.getLocalHost(), server.port());
  82             Client client = new Client(isa);
  83             client.start();
  84             if ((server.finish(FINISH_TIME) & client.finish(FINISH_TIME)) == 0)
  85                 throw new Exception("Failure");
  86             log.println();
  87 
  88         } else if (args[0].equals("server")) {
  89 
  90             if (args.length > 1)
  91                 TEST_PORT = Integer.parseInt(args[1]);
  92             Server server = new Server(TEST_PORT);
  93             server.start();
  94             if (server.finish(FINISH_TIME) == 0)
  95                 throw new Exception("Failure");
  96             log.println();
  97 
  98         } else if (args[0].equals("client")) {
  99 
 100             if (args.length < 2) {
 101                 log.println("No host specified: terminating.");
 102                 return;
 103             }
 104             String ip = args[1];
 105             if (args.length > 2)
 106                 TEST_PORT = Integer.parseInt(args[2]);
 107             InetAddress ia = InetAddress.getByName(ip);
 108             InetSocketAddress isa = new InetSocketAddress(ia, TEST_PORT);
 109             Client client = new Client(isa);
 110             client.start();
 111             if (client.finish(FINISH_TIME) == 0)
 112                 throw new Exception("Failure");
 113             log.println();
 114 
 115         } else {
 116             System.out.println("Usage note:");
 117             System.out.println("java SelectorTest [server] [client <host>] [<port>]");
 118             System.out.println("No arguments runs both client and server in separate threads using the default port of 31452.");
 119             System.out.println("client runs the client on this machine and connects to the server specified.");
 120             System.out.println("server runs the server on localhost.");
 121         }
 122     }
 123 
 124     static class Client extends TestThread {
 125         InetSocketAddress isa;
 126         Client(InetSocketAddress isa) {
 127             super("Client", SelectorTest.log);
 128             this.isa = isa;
 129         }
 130 
 131         public void go() throws Exception {
 132             log.println("starting client...");
 133             for (int i=0; i<NUM_CLIENTS; i++)
 134                 clientList.add(new RemoteEntity(i, isa, log));
 135 
 136             Collections.shuffle(clientList);
 137 
 138             log.println("created "+NUM_CLIENTS+" clients");
 139             do {
 140                 for (Iterator i = clientList.iterator(); i.hasNext(); ) {
 141                     RemoteEntity re = (RemoteEntity) i.next();
 142                     if (re.cycle()) {
 143                         i.remove();
 144                     }
 145                 }
 146                 Collections.shuffle(clientList);
 147             } while (clientList.size() > 0);
 148         }
 149     }
 150 
 151     static class Server extends TestThread {
 152         private final ServerSocketChannel ssc;
 153         private List socketList = new ArrayList();
 154         private ServerSocket ss;
 155         private int connectionsAccepted = 0;
 156         private Selector pollSelector;
 157         private Selector acceptSelector;
 158         private Set pkeys;
 159         private Set pskeys;
 160 
 161         Server(int port) throws IOException {
 162             super("Server", SelectorTest.log);
 163             this.ssc = RdmaSockets.openServerSocketChannel(
 164                 StandardProtocolFamily.INET);
 165             ssc.bind(new InetSocketAddress(InetAddress.getLocalHost(), port));
 166         }
 167 
 168         int port() {
 169             return ssc.socket().getLocalPort();
 170         }
 171 
 172         public void go() throws Exception {
 173             log.println("starting server...");
 174             acceptSelector = RdmaSockets.openSelector();
 175             pollSelector = RdmaSockets.openSelector();
 176             pkeys = pollSelector.keys();
 177             pskeys = pollSelector.selectedKeys();
 178             Set readyKeys = acceptSelector.selectedKeys();
 179             RequestHandler rh = new RequestHandler(pollSelector, log);
 180             Thread requestThread = new Thread(rh);
 181 
 182             requestThread.start();
 183 
 184             ssc.configureBlocking(false);
 185             SelectionKey acceptKey = ssc.register(acceptSelector,
 186                                                   SelectionKey.OP_ACCEPT);
 187             while(connectionsAccepted < SelectorTest.NUM_CLIENTS) {
 188                 int keysAdded = acceptSelector.select(100);
 189                 if (keysAdded > 0) {
 190                     Iterator i = readyKeys.iterator();
 191                     while(i.hasNext()) {
 192                         SelectionKey sk = (SelectionKey)i.next();
 193                         i.remove();
 194                         ServerSocketChannel nextReady =
 195                             (ServerSocketChannel)sk.channel();
 196                         SocketChannel sc = nextReady.accept();
 197                         connectionsAccepted++;
 198                         if (sc != null) {
 199                             sc.configureBlocking(false);
 200                             synchronized (pkeys) {
 201                                sc.register(pollSelector, SelectionKey.OP_READ);
 202                             }
 203                         } else {
 204                             throw new RuntimeException(
 205                                 "Socket does not support Channels");
 206                         }
 207                     }
 208                 }
 209             }
 210             acceptKey.cancel();
 211             requestThread.join();
 212             acceptSelector.close();
 213             pollSelector.close();
 214         }
 215     }
 216 }
 217 
 218 class RemoteEntity {
 219     private static Random rnd = new Random();
 220     int id;
 221     ByteBuffer data;
 222     int dataWrittenIndex;
 223     int totalDataLength;
 224     boolean initiated = false;
 225     boolean connected = false;
 226     boolean written = false;
 227     boolean acked = false;
 228     boolean closed = false;
 229     private SocketChannel sc;
 230     ByteBuffer ackBuffer;
 231     PrintStream log;
 232     InetSocketAddress server;
 233 
 234     RemoteEntity(int id, InetSocketAddress server, PrintStream log)
 235         throws Exception
 236     {
 237         int connectFailures = 0;
 238         this.id = id;
 239         this.log = log;
 240         this.server = server;
 241 
 242         sc = RdmaSockets.openSocketChannel(StandardProtocolFamily.INET);
 243         sc.configureBlocking(false);
 244 
 245         // Prepare the data buffer to write out from this entity
 246         // Let's use both slow and fast buffers
 247         if (rnd.nextBoolean())
 248             data = ByteBuffer.allocateDirect(100);
 249         else
 250             data = ByteBuffer.allocate(100);
 251         String number = Integer.toString(id);
 252         if (number.length() == 1)
 253             number = "0"+number;
 254         String source = "Testing from " + number;
 255         data.put(source.getBytes("8859_1"));
 256         data.flip();
 257         totalDataLength = source.length();
 258 
 259         // Allocate an ack buffer
 260         ackBuffer = ByteBuffer.allocateDirect(10);
 261     }
 262 
 263     private void reset() throws Exception {
 264         sc.close();
 265         sc = RdmaSockets.openSocketChannel(StandardProtocolFamily.INET);
 266         sc.configureBlocking(false);
 267     }
 268 
 269     private void connect() throws Exception {
 270         try {
 271             connected = sc.connect(server);
 272             initiated = true;
 273         }  catch (ConnectException e) {
 274             initiated = false;
 275             reset();
 276         }
 277     }
 278 
 279     private void finishConnect() throws Exception {
 280         try {
 281             connected = sc.finishConnect();
 282         }  catch (IOException e) {
 283             initiated = false;
 284             reset();
 285         }
 286     }
 287 
 288     int id() {
 289         return id;
 290     }
 291 
 292     boolean cycle() throws Exception {
 293         if (!initiated)
 294             connect();
 295         else if (!connected)
 296             finishConnect();
 297         else if (!written)
 298             writeCycle();
 299         else if (!acked)
 300             ackCycle();
 301         else if (!closed)
 302             close();
 303         return closed;
 304     }
 305 
 306     private void ackCycle() throws Exception {
 307         //log.println("acking from "+id);
 308         int bytesRead = sc.read(ackBuffer);
 309         if (bytesRead > 0) {
 310             acked = true;
 311         }
 312     }
 313 
 314     private void close() throws Exception {
 315         sc.close();
 316         closed = true;
 317     }
 318 
 319     private void writeCycle() throws Exception {
 320         log.println("writing from "+id);
 321         int numBytesToWrite = rnd.nextInt(10)+1;
 322         int newWriteTarget = dataWrittenIndex + numBytesToWrite;
 323         if (newWriteTarget > totalDataLength)
 324             newWriteTarget = totalDataLength;
 325         data.limit(newWriteTarget);
 326         int bytesWritten = sc.write(data);
 327         if (bytesWritten > 0)
 328             dataWrittenIndex += bytesWritten;
 329         if (dataWrittenIndex == totalDataLength) {
 330             written = true;
 331             sc.socket().shutdownOutput();
 332         }
 333     }
 334 
 335 }
 336 
 337 
 338 class RequestHandler implements Runnable {
 339     private static Random rnd = new Random();
 340     private Selector selector;
 341     private int connectionsHandled = 0;
 342     private HashMap dataBin = new HashMap();
 343     PrintStream log;
 344 
 345     public RequestHandler(Selector selector, PrintStream log) {
 346         this.selector = selector;
 347         this.log = log;
 348     }
 349 
 350     public void run() {
 351         log.println("starting request handler...");
 352         int connectionsAccepted = 0;
 353 
 354         Set nKeys = selector.keys();
 355         Set readyKeys = selector.selectedKeys();
 356 
 357         try {
 358             while(connectionsHandled < SelectorTest.NUM_CLIENTS) {
 359                 int numKeys = selector.select(100);
 360 
 361                 // Process channels with data
 362                 synchronized (nKeys) {
 363                     if (readyKeys.size() > 0) {
 364                         Iterator i = readyKeys.iterator();
 365                         while(i.hasNext()) {
 366                             SelectionKey sk = (SelectionKey)i.next();
 367                             i.remove();
 368                             SocketChannel sc = (SocketChannel)sk.channel();
 369                             if (sc.isOpen())
 370                                 read(sk, sc);
 371                         }
 372                     }
 373                 }
 374 
 375                 // Give other threads a chance to run
 376                 if (numKeys == 0) {
 377                     try {
 378                         Thread.sleep(1);
 379                     } catch (Exception x) {}
 380                 }
 381             }
 382         } catch (Exception e) {
 383             log.println("Unexpected error 1: "+e);
 384             e.printStackTrace();
 385         }
 386     }
 387 
 388     private void read(SelectionKey sk, SocketChannel sc) throws Exception {
 389         ByteBuffer bin = (ByteBuffer)dataBin.get(sc);
 390         if (bin == null) {
 391             if (rnd.nextBoolean())
 392                 bin = ByteBuffer.allocateDirect(100);
 393             else
 394                 bin = ByteBuffer.allocate(100);
 395             dataBin.put(sc, bin);
 396         }
 397 
 398         int bytesRead = 0;
 399         do {
 400             bytesRead = sc.read(bin);
 401         } while(bytesRead > 0);
 402 
 403         if (bytesRead == -1) {
 404             sk.interestOps(0);
 405             bin.flip();
 406             int size = bin.limit();
 407             byte[] data = new byte[size];
 408             for(int j=0; j<size; j++)
 409                 data[j] = bin.get();
 410             String message = new String(data, "8859_1");
 411             connectionsHandled++;
 412             acknowledge(sc);
 413             log.println("Received >>>"+message + "<<<");
 414             log.println("Handled: "+connectionsHandled);
 415         }
 416     }
 417 
 418     private void acknowledge(SocketChannel sc) throws Exception {
 419             ByteBuffer ackBuffer = ByteBuffer.allocateDirect(10);
 420             String s = "ack";
 421             ackBuffer.put(s.getBytes("8859_1"));
 422             ackBuffer.flip();
 423             int bytesWritten = 0;
 424             while(bytesWritten == 0) {
 425                 bytesWritten += sc.write(ackBuffer);
 426             }
 427             sc.close();
 428     }
 429 }