1 /*
   2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8195160
 * @summary Test basic functionality of RdmaSocketChannel
 *          and RdmaServerSocketChannel
  29  * @requires (os.family == "linux")
  30  * @library .. /test/lib
  31  * @build RsocketTest
  32  * @run main/othervm BasicSocketChannelTest
  33  */
  34 
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.StandardProtocolFamily;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import jdk.net.RdmaSockets;

import jtreg.SkippedException;
  47 
  48 public class BasicSocketChannelTest implements Runnable {
  49     static ServerSocketChannel ssc;
  50     static SocketChannel sc1, sc2;
  51     static InetAddress iaddr;
  52     static int port = 0;
  53     static String message = "This is a message!";
  54     static int length = -1;
  55 
  56     public static void main(String args[]) throws Exception {
  57         if (!RsocketTest.isRsocketAvailable())
  58             throw new SkippedException("rsocket is not available");
  59 
  60         iaddr = InetAddress.getLocalHost();
  61         length = message.length();
  62         String result;
  63 
  64         //test SocketChannel and ServerSocketChannel
  65         try {
  66             ssc = RdmaSockets.openServerSocketChannel(
  67                 StandardProtocolFamily.INET);
  68             sc1 = RdmaSockets.openSocketChannel(
  69                 StandardProtocolFamily.INET);
  70 
  71             ssc.bind(new InetSocketAddress(iaddr, port));
  72 
  73             Thread t = new Thread(new BasicSocketChannelTest(), "Channel");
  74             t.start(); 
  75 
  76             sc2 = ssc.accept();
  77             ByteBuffer output = ByteBuffer.allocate(length);
  78             int outputNum = 0;
  79             while (outputNum < length) {
  80                 outputNum += sc2.read(output);
  81             }
  82 
  83             result = new String(output.array());
  84             if(!result.equals(message))
  85                 throw new RuntimeException("Test Failed!");
  86             sc2.shutdownInput();
  87             sc2.shutdownOutput();
  88         } catch (Exception e) {
  89             e.printStackTrace();
  90             throw new RuntimeException("Test Failed!");
  91         } finally {
  92             ssc.close();
  93             sc1.close();
  94             sc2.close();
  95         }
  96 
  97         //test SocketChannel.socket() and ServerSocketChannel.socket()
  98         try {
  99             ssc = RdmaSockets.openServerSocketChannel(StandardProtocolFamily.INET);
 100             sc1 = RdmaSockets.openSocketChannel(StandardProtocolFamily.INET);
 101 
 102             ssc.socket().bind(new InetSocketAddress(iaddr, port));
 103 
 104             Thread t = new Thread(new BasicSocketChannelTest(), "Socket");
 105             t.start();
 106 
 107             sc2 = ssc.accept();
 108             Socket conn = sc2.socket();
 109             InputStream is = conn.getInputStream();
 110 
 111             int num = 0;
 112             byte[] buf = new byte[length];
 113             while (num < length) {
 114                 int l = is.read(buf);
 115                 num += l;
 116             }
 117 
 118             result = new String(buf);
 119             if(!result.equals(message))
 120                 throw new RuntimeException("Test Failed!");
 121             conn.shutdownInput();
 122             conn.shutdownOutput();
 123             if (!conn.isInputShutdown() || !conn.isOutputShutdown())
 124                 throw new RuntimeException("Test Failed!");
 125         } catch (Exception e) {
 126             e.printStackTrace();
 127             throw new RuntimeException("Test Failed!");
 128         } finally {
 129             ssc.close();
 130             sc1.close();
 131             sc2.close();
 132         }
 133     }
 134 
 135     public void run() {
 136         int port = ssc.socket().getLocalPort();
 137         byte[] arr = new byte[length];
 138         arr = message.getBytes();
 139 
 140         try {
 141             if (Thread.currentThread().getName().startsWith("Channel")) {
 142                 sc1.connect(new InetSocketAddress(iaddr, port));
 143                 if (!sc1.isConnected())
 144                     throw new RuntimeException("Test Failed!");
 145                 ByteBuffer input = ByteBuffer.allocate(length);
 146                 input.put(arr);
 147                 input.flip();
 148                 int inputNum = 0;
 149                 while (inputNum < length) {
 150                     inputNum += sc1.write(input);
 151                 }
 152                 sc1.shutdownInput();
 153                 sc1.shutdownOutput();
 154             } else if (Thread.currentThread().getName().startsWith("Socket")) {
 155                 Socket client = sc1.socket();
 156                 client.connect(new InetSocketAddress(iaddr, port));
 157                 if (!client.isConnected())
 158                     throw new RuntimeException("Test Failed!");
 159                 OutputStream os = client.getOutputStream();
 160                 os.write(arr);
 161                 client.shutdownInput();
 162                 client.shutdownOutput();
 163                 if (!client.isInputShutdown() || !client.isOutputShutdown())
 164                     throw new RuntimeException("Test Failed!");
 165             }
 166         } catch (Exception e) {
 167             e.printStackTrace();
 168             throw new RuntimeException("Test Failed!");
 169         }
 170     }
 171 }