1 /*
   2  * Copyright (c) 2000, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /* @test
  25  * @bug 8195160
  26  * @summary Test RdmaSelector with RdmaServerSocketChannels
  27  * @requires (os.family == "linux")
  28  * @library .. /test/lib
  29  * @build RsocketTest
 * @run main/othervm -Djava.net.preferIPv4Stack=true SelectorTest
  31  */
  32 
  33 import java.io.*;
  34 import java.net.*;
  35 import java.nio.*;
  36 import java.nio.channels.*;
  37 import java.nio.channels.spi.SelectorProvider;
  38 import java.util.*;
  39 import jdk.net.Sockets;
  40 
  41 public class SelectorTest {
  42     private static List clientList = new LinkedList();
  43     private static Random rnd = new Random();
  44     public static int NUM_CLIENTS = 5;
  45     public static int TEST_PORT = 31452;
  46     static PrintStream log = System.err;
  47     private static int FINISH_TIME = 30000;
  48 
  49     /*
  50      * Usage note
  51      *
  52      * java SelectorTest [server] [client <host>] [<port>]
  53      *
  54      * No arguments runs both client and server in separate threads
  55      * using the default port of 31452.
  56      *
  57      * client runs the client on this machine and connects to server
  58      * at the given IP address.
  59      *
  60      * server runs the server on localhost.
  61      */
  62     public static void main(String[] args) throws Exception {
  63         if (!RsocketTest.isRsocketAvailable())
  64             return;
  65 
  66         if (args.length == 0) {
  67             Server server = new Server(0);
  68             server.start();
  69             try {
  70                 Thread.sleep(1000);
  71             } catch (InterruptedException e) { }
  72             InetSocketAddress isa
  73                 = new InetSocketAddress(InetAddress.getLocalHost(), server.port());
  74             Client client = new Client(isa);
  75             client.start();
  76             if ((server.finish(FINISH_TIME) & client.finish(FINISH_TIME)) == 0)
  77                 throw new Exception("Failure");
  78             log.println();
  79 
  80         } else if (args[0].equals("server")) {
  81 
  82             if (args.length > 1)
  83                 TEST_PORT = Integer.parseInt(args[1]);
  84             Server server = new Server(TEST_PORT);
  85             server.start();
  86             if (server.finish(FINISH_TIME) == 0)
  87                 throw new Exception("Failure");
  88             log.println();
  89 
  90         } else if (args[0].equals("client")) {
  91 
  92             if (args.length < 2) {
  93                 log.println("No host specified: terminating.");
  94                 return;
  95             }
  96             String ip = args[1];
  97             if (args.length > 2)
  98                 TEST_PORT = Integer.parseInt(args[2]);
  99             InetAddress ia = InetAddress.getByName(ip);
 100             InetSocketAddress isa = new InetSocketAddress(ia, TEST_PORT);
 101             Client client = new Client(isa);
 102             client.start();
 103             if (client.finish(FINISH_TIME) == 0)
 104                 throw new Exception("Failure");
 105             log.println();
 106 
 107         } else {
 108             System.out.println("Usage note:");
 109             System.out.println("java SelectorTest [server] [client <host>] [<port>]");
 110             System.out.println("No arguments runs both client and server in separate threads using the default port of 31452.");
 111             System.out.println("client runs the client on this machine and connects to the server specified.");
 112             System.out.println("server runs the server on localhost.");
 113         }
 114     }
 115 
 116     static class Client extends TestThread {
 117         InetSocketAddress isa;
 118         Client(InetSocketAddress isa) {
 119             super("Client", SelectorTest.log);
 120             this.isa = isa;
 121         }
 122 
 123         public void go() throws Exception {
 124             log.println("starting client...");
 125             for (int i=0; i<NUM_CLIENTS; i++)
 126                 clientList.add(new RemoteEntity(i, isa, log));
 127 
 128             Collections.shuffle(clientList);
 129 
 130             log.println("created "+NUM_CLIENTS+" clients");
 131             do {
 132                 for (Iterator i = clientList.iterator(); i.hasNext(); ) {
 133                     RemoteEntity re = (RemoteEntity) i.next();
 134                     if (re.cycle()) {
 135                         i.remove();
 136                     }
 137                 }
 138                 Collections.shuffle(clientList);
 139             } while (clientList.size() > 0);
 140         }
 141     }
 142 
 143     static class Server extends TestThread {
 144         private final ServerSocketChannel ssc;
 145         private List socketList = new ArrayList();
 146         private ServerSocket ss;
 147         private int connectionsAccepted = 0;
 148         private Selector pollSelector;
 149         private Selector acceptSelector;
 150         private Set pkeys;
 151         private Set pskeys;
 152 
 153         Server(int port) throws IOException {
 154             super("Server", SelectorTest.log);
 155             this.ssc = Sockets.openRdmaServerSocketChannel();
 156             ssc.bind(new InetSocketAddress(InetAddress.getLocalHost(), port));
 157         }
 158 
 159         int port() {
 160             return ssc.socket().getLocalPort();
 161         }
 162 
 163         public void go() throws Exception {
 164             log.println("starting server...");
 165             acceptSelector = Sockets.openRdmaSelector();
 166             pollSelector = Sockets.openRdmaSelector();
 167             pkeys = pollSelector.keys();
 168             pskeys = pollSelector.selectedKeys();
 169             Set readyKeys = acceptSelector.selectedKeys();
 170             RequestHandler rh = new RequestHandler(pollSelector, log);
 171             Thread requestThread = new Thread(rh);
 172 
 173             requestThread.start();
 174 
 175             ssc.configureBlocking(false);
 176             SelectionKey acceptKey = ssc.register(acceptSelector,
 177                                                   SelectionKey.OP_ACCEPT);
 178             while(connectionsAccepted < SelectorTest.NUM_CLIENTS) {
 179                 int keysAdded = acceptSelector.select(100);
 180                 if (keysAdded > 0) {
 181                     Iterator i = readyKeys.iterator();
 182                     while(i.hasNext()) {
 183                         SelectionKey sk = (SelectionKey)i.next();
 184                         i.remove();
 185                         ServerSocketChannel nextReady =
 186                             (ServerSocketChannel)sk.channel();
 187                         SocketChannel sc = nextReady.accept();
 188                         connectionsAccepted++;
 189                         if (sc != null) {
 190                             sc.configureBlocking(false);
 191                             synchronized (pkeys) {
 192                                sc.register(pollSelector, SelectionKey.OP_READ);
 193                             }
 194                         } else {
 195                             throw new RuntimeException(
 196                                 "Socket does not support Channels");
 197                         }
 198                     }
 199                 }
 200             }
 201             acceptKey.cancel();
 202             requestThread.join();
 203             acceptSelector.close();
 204             pollSelector.close();
 205         }
 206     }
 207 }
 208 
 209 class RemoteEntity {
 210     private static Random rnd = new Random();
 211     int id;
 212     ByteBuffer data;
 213     int dataWrittenIndex;
 214     int totalDataLength;
 215     boolean initiated = false;
 216     boolean connected = false;
 217     boolean written = false;
 218     boolean acked = false;
 219     boolean closed = false;
 220     private SocketChannel sc;
 221     ByteBuffer ackBuffer;
 222     PrintStream log;
 223     InetSocketAddress server;
 224 
 225     RemoteEntity(int id, InetSocketAddress server, PrintStream log)
 226         throws Exception
 227     {
 228         int connectFailures = 0;
 229         this.id = id;
 230         this.log = log;
 231         this.server = server;
 232 
 233         sc = Sockets.openRdmaSocketChannel();
 234         sc.configureBlocking(false);
 235 
 236         // Prepare the data buffer to write out from this entity
 237         // Let's use both slow and fast buffers
 238         if (rnd.nextBoolean())
 239             data = ByteBuffer.allocateDirect(100);
 240         else
 241             data = ByteBuffer.allocate(100);
 242         String number = Integer.toString(id);
 243         if (number.length() == 1)
 244             number = "0"+number;
 245         String source = "Testing from " + number;
 246         data.put(source.getBytes("8859_1"));
 247         data.flip();
 248         totalDataLength = source.length();
 249 
 250         // Allocate an ack buffer
 251         ackBuffer = ByteBuffer.allocateDirect(10);
 252     }
 253 
 254     private void reset() throws Exception {
 255         sc.close();
 256         sc = Sockets.openRdmaSocketChannel();
 257         sc.configureBlocking(false);
 258     }
 259 
 260     private void connect() throws Exception {
 261         try {
 262             connected = sc.connect(server);
 263             initiated = true;
 264         }  catch (ConnectException e) {
 265             initiated = false;
 266             reset();
 267         }
 268     }
 269 
 270     private void finishConnect() throws Exception {
 271         try {
 272             connected = sc.finishConnect();
 273         }  catch (IOException e) {
 274             initiated = false;
 275             reset();
 276         }
 277     }
 278 
 279     int id() {
 280         return id;
 281     }
 282 
 283     boolean cycle() throws Exception {
 284         if (!initiated)
 285             connect();
 286         else if (!connected)
 287             finishConnect();
 288         else if (!written)
 289             writeCycle();
 290         else if (!acked)
 291             ackCycle();
 292         else if (!closed)
 293             close();
 294         return closed;
 295     }
 296 
 297     private void ackCycle() throws Exception {
 298         //log.println("acking from "+id);
 299         int bytesRead = sc.read(ackBuffer);
 300         if (bytesRead > 0) {
 301             acked = true;
 302         }
 303     }
 304 
 305     private void close() throws Exception {
 306         sc.close();
 307         closed = true;
 308     }
 309 
 310     private void writeCycle() throws Exception {
 311         log.println("writing from "+id);
 312         int numBytesToWrite = rnd.nextInt(10)+1;
 313         int newWriteTarget = dataWrittenIndex + numBytesToWrite;
 314         if (newWriteTarget > totalDataLength)
 315             newWriteTarget = totalDataLength;
 316         data.limit(newWriteTarget);
 317         int bytesWritten = sc.write(data);
 318         if (bytesWritten > 0)
 319             dataWrittenIndex += bytesWritten;
 320         if (dataWrittenIndex == totalDataLength) {
 321             written = true;
 322             sc.socket().shutdownOutput();
 323         }
 324     }
 325 
 326 }
 327 
 328 
 329 class RequestHandler implements Runnable {
 330     private static Random rnd = new Random();
 331     private Selector selector;
 332     private int connectionsHandled = 0;
 333     private HashMap dataBin = new HashMap();
 334     PrintStream log;
 335 
 336     public RequestHandler(Selector selector, PrintStream log) {
 337         this.selector = selector;
 338         this.log = log;
 339     }
 340 
 341     public void run() {
 342         log.println("starting request handler...");
 343         int connectionsAccepted = 0;
 344 
 345         Set nKeys = selector.keys();
 346         Set readyKeys = selector.selectedKeys();
 347 
 348         try {
 349             while(connectionsHandled < SelectorTest.NUM_CLIENTS) {
 350                 int numKeys = selector.select(100);
 351 
 352                 // Process channels with data
 353                 synchronized (nKeys) {
 354                     if (readyKeys.size() > 0) {
 355                         Iterator i = readyKeys.iterator();
 356                         while(i.hasNext()) {
 357                             SelectionKey sk = (SelectionKey)i.next();
 358                             i.remove();
 359                             SocketChannel sc = (SocketChannel)sk.channel();
 360                             if (sc.isOpen())
 361                                 read(sk, sc);
 362                         }
 363                     }
 364                 }
 365 
 366                 // Give other threads a chance to run
 367                 if (numKeys == 0) {
 368                     try {
 369                         Thread.sleep(1);
 370                     } catch (Exception x) {}
 371                 }
 372             }
 373         } catch (Exception e) {
 374             log.println("Unexpected error 1: "+e);
 375             e.printStackTrace();
 376         }
 377     }
 378 
 379     private void read(SelectionKey sk, SocketChannel sc) throws Exception {
 380         ByteBuffer bin = (ByteBuffer)dataBin.get(sc);
 381         if (bin == null) {
 382             if (rnd.nextBoolean())
 383                 bin = ByteBuffer.allocateDirect(100);
 384             else
 385                 bin = ByteBuffer.allocate(100);
 386             dataBin.put(sc, bin);
 387         }
 388 
 389         int bytesRead = 0;
 390         do {
 391             bytesRead = sc.read(bin);
 392         } while(bytesRead > 0);
 393 
 394         if (bytesRead == -1) {
 395             sk.interestOps(0);
 396             bin.flip();
 397             int size = bin.limit();
 398             byte[] data = new byte[size];
 399             for(int j=0; j<size; j++)
 400                 data[j] = bin.get();
 401             String message = new String(data, "8859_1");
 402             connectionsHandled++;
 403             acknowledge(sc);
 404             log.println("Received >>>"+message + "<<<");
 405             log.println("Handled: "+connectionsHandled);
 406         }
 407     }
 408 
 409     private void acknowledge(SocketChannel sc) throws Exception {
 410             ByteBuffer ackBuffer = ByteBuffer.allocateDirect(10);
 411             String s = "ack";
 412             ackBuffer.put(s.getBytes("8859_1"));
 413             ackBuffer.flip();
 414             int bytesWritten = 0;
 415             while(bytesWritten == 0) {
 416                 bytesWritten += sc.write(ackBuffer);
 417             }
 418             sc.close();
 419     }
 420 }