--- /dev/null 2018-11-27 03:18:47.532777276 -0500 +++ new/test/jdk/jdk/net/RdmaSockets/rsocket/Selector/SelectorTest.java 2018-11-30 08:33:20.394538646 -0500 @@ -0,0 +1,429 @@ +/* + * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* @test + * @bug 8195160 + * @summary Test RdmaSelector with RdmaServerSocketChannels + * @requires (os.family == "linux") + * @library .. /test/lib /test/jdk/java/nio/channels + * @build RsocketTest + * @run main/othervm BasicAccept + */ + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.ByteBuffer; +import java.nio.channels.Selector; +import java.nio.channels.SelectionKey; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.Iterator; +import java.util.Set; +import jdk.net.RdmaSockets; + +import jtreg.SkippedException; + +public class SelectorTest { + private static List clientList = new LinkedList(); + private static Random rnd = new Random(); + public static int NUM_CLIENTS = 5; + public static int TEST_PORT = 31452; + static PrintStream log = System.err; + private static int FINISH_TIME = 30000; + + /* + * Usage note + * + * java SelectorTest [server] [client ] [] + * + * No arguments runs both client and server in separate threads + * using the default port of 31452. + * + * client runs the client on this machine and connects to server + * at the given IP address. + * + * server runs the server on localhost. + */ + public static void main(String[] args) throws Exception { + if (!RsocketTest.isRsocketAvailable()) + throw new SkippedException("rsocket is not available"); + + if (args.length == 0) { + Server server = new Server(0); + server.start(); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { } + InetSocketAddress isa + = new InetSocketAddress(InetAddress.getLocalHost(), server.port()); + Client client = new Client(isa); + client.start(); + if ((server.finish(FINISH_TIME) & client.finish(FINISH_TIME)) == 0) + throw new Exception("Failure"); + log.println(); + + } else if (args[0].equals("server")) { + + if (args.length > 1) + TEST_PORT = Integer.parseInt(args[1]); + Server server = new Server(TEST_PORT); + server.start(); + if (server.finish(FINISH_TIME) == 0) + throw new Exception("Failure"); + log.println(); + + } else if (args[0].equals("client")) { + + if (args.length < 2) { + log.println("No host specified: terminating."); + return; + } + String ip = args[1]; + if (args.length > 2) + TEST_PORT = Integer.parseInt(args[2]); + InetAddress ia = InetAddress.getByName(ip); + InetSocketAddress isa = new InetSocketAddress(ia, TEST_PORT); + Client client = new Client(isa); + client.start(); + if (client.finish(FINISH_TIME) == 0) + throw new Exception("Failure"); + log.println(); + + } else { + System.out.println("Usage note:"); + System.out.println("java SelectorTest [server] [client ] []"); + System.out.println("No arguments runs both client and server in separate threads using the default port of 31452."); + System.out.println("client runs the client on this machine and connects to the server specified."); + System.out.println("server runs the server on localhost."); + } + } + + static class Client extends TestThread { + InetSocketAddress isa; + Client(InetSocketAddress isa) { + super("Client", SelectorTest.log); + this.isa = isa; + } + + public void go() throws Exception { + log.println("starting client..."); + for (int i=0; i 0); + } + } + + static class Server extends TestThread { + private final ServerSocketChannel ssc; + private List socketList = new ArrayList(); + private ServerSocket ss; + private int connectionsAccepted = 0; + private Selector pollSelector; + private Selector acceptSelector; + private Set pkeys; + private Set pskeys; + + Server(int port) throws IOException { + super("Server", SelectorTest.log); + this.ssc = RdmaSockets.openServerSocketChannel( + StandardProtocolFamily.INET); + ssc.bind(new InetSocketAddress(InetAddress.getLocalHost(), port)); + } + + int port() { + return ssc.socket().getLocalPort(); + } + + public void go() throws Exception { + log.println("starting server..."); + acceptSelector = RdmaSockets.openSelector(); + pollSelector = RdmaSockets.openSelector(); + pkeys = pollSelector.keys(); + pskeys = pollSelector.selectedKeys(); + Set readyKeys = acceptSelector.selectedKeys(); + RequestHandler rh = new RequestHandler(pollSelector, log); + Thread requestThread = new Thread(rh); + + requestThread.start(); + + ssc.configureBlocking(false); + SelectionKey acceptKey = ssc.register(acceptSelector, + SelectionKey.OP_ACCEPT); + while(connectionsAccepted < SelectorTest.NUM_CLIENTS) { + int keysAdded = acceptSelector.select(100); + if (keysAdded > 0) { + Iterator i = readyKeys.iterator(); + while(i.hasNext()) { + SelectionKey sk = (SelectionKey)i.next(); + i.remove(); + ServerSocketChannel nextReady = + (ServerSocketChannel)sk.channel(); + SocketChannel sc = nextReady.accept(); + connectionsAccepted++; + if (sc != null) { + sc.configureBlocking(false); + synchronized (pkeys) { + sc.register(pollSelector, SelectionKey.OP_READ); + } + } else { + throw new RuntimeException( + "Socket does not support Channels"); + } + } + } + } + acceptKey.cancel(); + requestThread.join(); + acceptSelector.close(); + pollSelector.close(); + } + } +} + +class RemoteEntity { + private static Random rnd = new Random(); + int id; + ByteBuffer data; + int dataWrittenIndex; + int totalDataLength; + boolean initiated = false; + boolean connected = false; + boolean written = false; + boolean acked = false; + boolean closed = false; + private SocketChannel sc; + ByteBuffer ackBuffer; + PrintStream log; + InetSocketAddress server; + + RemoteEntity(int id, InetSocketAddress server, PrintStream log) + throws Exception + { + int connectFailures = 0; + this.id = id; + this.log = log; + this.server = server; + + sc = RdmaSockets.openSocketChannel(StandardProtocolFamily.INET); + sc.configureBlocking(false); + + // Prepare the data buffer to write out from this entity + // Let's use both slow and fast buffers + if (rnd.nextBoolean()) + data = ByteBuffer.allocateDirect(100); + else + data = ByteBuffer.allocate(100); + String number = Integer.toString(id); + if (number.length() == 1) + number = "0"+number; + String source = "Testing from " + number; + data.put(source.getBytes("8859_1")); + data.flip(); + totalDataLength = source.length(); + + // Allocate an ack buffer + ackBuffer = ByteBuffer.allocateDirect(10); + } + + private void reset() throws Exception { + sc.close(); + sc = RdmaSockets.openSocketChannel(StandardProtocolFamily.INET); + sc.configureBlocking(false); + } + + private void connect() throws Exception { + try { + connected = sc.connect(server); + initiated = true; + } catch (ConnectException e) { + initiated = false; + reset(); + } + } + + private void finishConnect() throws Exception { + try { + connected = sc.finishConnect(); + } catch (IOException e) { + initiated = false; + reset(); + } + } + + int id() { + return id; + } + + boolean cycle() throws Exception { + if (!initiated) + connect(); + else if (!connected) + finishConnect(); + else if (!written) + writeCycle(); + else if (!acked) + ackCycle(); + else if (!closed) + close(); + return closed; + } + + private void ackCycle() throws Exception { + //log.println("acking from "+id); + int bytesRead = sc.read(ackBuffer); + if (bytesRead > 0) { + acked = true; + } + } + + private void close() throws Exception { + sc.close(); + closed = true; + } + + private void writeCycle() throws Exception { + log.println("writing from "+id); + int numBytesToWrite = rnd.nextInt(10)+1; + int newWriteTarget = dataWrittenIndex + numBytesToWrite; + if (newWriteTarget > totalDataLength) + newWriteTarget = totalDataLength; + data.limit(newWriteTarget); + int bytesWritten = sc.write(data); + if (bytesWritten > 0) + dataWrittenIndex += bytesWritten; + if (dataWrittenIndex == totalDataLength) { + written = true; + sc.socket().shutdownOutput(); + } + } + +} + + +class RequestHandler implements Runnable { + private static Random rnd = new Random(); + private Selector selector; + private int connectionsHandled = 0; + private HashMap dataBin = new HashMap(); + PrintStream log; + + public RequestHandler(Selector selector, PrintStream log) { + this.selector = selector; + this.log = log; + } + + public void run() { + log.println("starting request handler..."); + int connectionsAccepted = 0; + + Set nKeys = selector.keys(); + Set readyKeys = selector.selectedKeys(); + + try { + while(connectionsHandled < SelectorTest.NUM_CLIENTS) { + int numKeys = selector.select(100); + + // Process channels with data + synchronized (nKeys) { + if (readyKeys.size() > 0) { + Iterator i = readyKeys.iterator(); + while(i.hasNext()) { + SelectionKey sk = (SelectionKey)i.next(); + i.remove(); + SocketChannel sc = (SocketChannel)sk.channel(); + if (sc.isOpen()) + read(sk, sc); + } + } + } + + // Give other threads a chance to run + if (numKeys == 0) { + try { + Thread.sleep(1); + } catch (Exception x) {} + } + } + } catch (Exception e) { + log.println("Unexpected error 1: "+e); + e.printStackTrace(); + } + } + + private void read(SelectionKey sk, SocketChannel sc) throws Exception { + ByteBuffer bin = (ByteBuffer)dataBin.get(sc); + if (bin == null) { + if (rnd.nextBoolean()) + bin = ByteBuffer.allocateDirect(100); + else + bin = ByteBuffer.allocate(100); + dataBin.put(sc, bin); + } + + int bytesRead = 0; + do { + bytesRead = sc.read(bin); + } while(bytesRead > 0); + + if (bytesRead == -1) { + sk.interestOps(0); + bin.flip(); + int size = bin.limit(); + byte[] data = new byte[size]; + for(int j=0; j>>"+message + "<<<"); + log.println("Handled: "+connectionsHandled); + } + } + + private void acknowledge(SocketChannel sc) throws Exception { + ByteBuffer ackBuffer = ByteBuffer.allocateDirect(10); + String s = "ack"; + ackBuffer.put(s.getBytes("8859_1")); + ackBuffer.flip(); + int bytesWritten = 0; + while(bytesWritten == 0) { + bytesWritten += sc.write(ackBuffer); + } + sc.close(); + } +}