/* * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. */ /* * @test * @bug 8195160 * @summary Test for basic functionality for RdmaSocketChannel * and RdmaServerSocketChannel * @requires (os.family == "linux") * @library .. 
/test/lib * @build RsocketTest * @run main/othervm BasicSocketChannelTest */

import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.StandardProtocolFamily;
import java.net.Socket;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import jdk.net.RdmaSockets;

import jtreg.SkippedException;

/**
 * Exercises a basic client/server exchange over rsocket-backed channels,
 * first through the {@link SocketChannel} API directly and then through the
 * {@link Socket} adaptors returned by {@code socket()}.
 *
 * <p>The main thread plays the server. Each sub-test starts one client thread
 * running {@link #run()}, which connects, sends {@link #message}, and shuts
 * down both directions. The client thread's name ("Channel" or "Socket")
 * selects which API it uses.
 */
public class BasicSocketChannelTest implements Runnable {

    static ServerSocketChannel ssc;
    static SocketChannel sc1, sc2;
    static InetAddress iaddr;
    static int port = 0;
    static String message = "This is a message!";
    static int length = -1;

    /** First failure observed by the current client thread, or null. */
    private static volatile Throwable clientFailure;

    public static void main(String[] args) throws Exception {
        if (!RsocketTest.isRsocketAvailable())
            throw new SkippedException("rsocket is not available");

        iaddr = InetAddress.getLocalHost();
        // Number of bytes that travel on the wire (not the char count).
        length = message.getBytes(StandardCharsets.UTF_8).length;

        testChannels();
        testSocketAdaptors();
    }

    /** Tests SocketChannel and ServerSocketChannel directly. */
    private static void testChannels() {
        try (ServerSocketChannel server = RdmaSockets.openServerSocketChannel(
                     StandardProtocolFamily.INET);
             SocketChannel client = RdmaSockets.openSocketChannel(
                     StandardProtocolFamily.INET)) {
            ssc = server;
            sc1 = client;
            ssc.bind(new InetSocketAddress(iaddr, port));
            Thread t = startClient("Channel");

            try (SocketChannel accepted = ssc.accept()) {
                sc2 = accepted;
                ByteBuffer output = ByteBuffer.allocate(length);
                // A read may return fewer bytes than requested; -1 means the
                // peer closed before the whole message arrived.
                while (output.hasRemaining()) {
                    if (sc2.read(output) < 0) {
                        throw new RuntimeException("Test Failed! EOF after "
                                + output.position() + " of " + length + " bytes");
                    }
                }
                checkMessage(output.array());
                sc2.shutdownInput();
                sc2.shutdownOutput();
                // The client must finish before its channel is closed below.
                awaitClient(t);
            }
        } catch (Exception e) {
            throw new RuntimeException("Test Failed!", e);
        }
    }

    /** Tests SocketChannel.socket() and ServerSocketChannel.socket(). */
    private static void testSocketAdaptors() {
        try (ServerSocketChannel server = RdmaSockets.openServerSocketChannel(
                     StandardProtocolFamily.INET);
             SocketChannel client = RdmaSockets.openSocketChannel(
                     StandardProtocolFamily.INET)) {
            ssc = server;
            sc1 = client;
            ssc.socket().bind(new InetSocketAddress(iaddr, port));
            Thread t = startClient("Socket");

            try (SocketChannel accepted = ssc.accept()) {
                sc2 = accepted;
                Socket conn = sc2.socket();
                InputStream is = conn.getInputStream();
                byte[] buf = new byte[length];
                int num = 0;
                // Read at the current offset so a chunked arrival does not
                // overwrite bytes already received.
                while (num < length) {
                    int n = is.read(buf, num, length - num);
                    if (n < 0) {
                        throw new RuntimeException("Test Failed! EOF after "
                                + num + " of " + length + " bytes");
                    }
                    num += n;
                }
                checkMessage(buf);
                conn.shutdownInput();
                conn.shutdownOutput();
                if (!conn.isInputShutdown() || !conn.isOutputShutdown())
                    throw new RuntimeException("Test Failed!");
                // The client must finish before its channel is closed below.
                awaitClient(t);
            }
        } catch (Exception e) {
            throw new RuntimeException("Test Failed!", e);
        }
    }

    /** Clears any previous client failure and starts a client thread with the given name. */
    private static Thread startClient(String name) {
        clientFailure = null;
        Thread t = new Thread(new BasicSocketChannelTest(), name);
        t.start();
        return t;
    }

    /** Waits for the client thread and rethrows any failure it recorded. */
    private static void awaitClient(Thread t) throws InterruptedException {
        t.join();
        Throwable failure = clientFailure;
        if (failure != null) {
            throw new RuntimeException("Test Failed! client thread \""
                    + t.getName() + "\" failed", failure);
        }
    }

    /** Fails unless the received bytes decode to {@link #message}. */
    private static void checkMessage(byte[] received) {
        String result = new String(received, StandardCharsets.UTF_8);
        if (!result.equals(message))
            throw new RuntimeException("Test Failed!");
    }

    /**
     * Client side of the exchange. Connects {@link #sc1} to the port
     * {@link #ssc} is bound to, sends {@link #message}, then shuts down both
     * directions. Any failure is recorded for the server thread and rethrown.
     */
    @Override
    public void run() {
        int serverPort = ssc.socket().getLocalPort();
        byte[] arr = message.getBytes(StandardCharsets.UTF_8);
        try {
            String name = Thread.currentThread().getName();
            if (name.startsWith("Channel")) {
                sc1.connect(new InetSocketAddress(iaddr, serverPort));
                if (!sc1.isConnected())
                    throw new RuntimeException("Test Failed!");
                ByteBuffer input = ByteBuffer.wrap(arr);
                // write() may transfer only part of the buffer; loop until drained.
                while (input.hasRemaining()) {
                    sc1.write(input);
                }
                sc1.shutdownInput();
                sc1.shutdownOutput();
            } else if (name.startsWith("Socket")) {
                Socket client = sc1.socket();
                client.connect(new InetSocketAddress(iaddr, serverPort));
                if (!client.isConnected())
                    throw new RuntimeException("Test Failed!");
                OutputStream os = client.getOutputStream();
                os.write(arr);
                client.shutdownInput();
                client.shutdownOutput();
                if (!client.isInputShutdown() || !client.isOutputShutdown())
                    throw new RuntimeException("Test Failed!");
            }
        } catch (Exception e) {
            clientFailure = e;
            throw new RuntimeException("Test Failed!", e);
        }
    }
}