< prev index next >

test/jdk/jdk/net/RdmaSockets/rsocket/SocketChannel/BasicSocketChannelTest.java

Print this page
rev 52802 : [mq]: BasicSocketChannelTest.24
rev 52801 : imported patch jdk12-8195160-version24.patch

@@ -35,10 +35,11 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.UncheckedIOException;
 import java.net.*;
+import java.nio.channels.ClosedChannelException;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.nio.ByteBuffer;
 import jdk.net.RdmaSockets;
 import jtreg.SkippedException;

@@ -85,33 +86,34 @@
                         int inputNum = 0;
                         while (inputNum < MESSAGE_LENGTH) {
                             inputNum += sc1.write(input);
                         }
                         sc1.shutdownInput();
+                        assertInputShutdown(sc1);
                         sc1.shutdownOutput();
+                        assertOutputShutdown(sc1);
                     } catch (IOException e) {
                         throw new UncheckedIOException(e);
                     }
                 }
             };
             t.start();
 
             try (SocketChannel sc2 = ssc.accept()) {
-                ByteBuffer output = ByteBuffer.allocate(MESSAGE_LENGTH);
-                //while (sc2.read(output) != -1);  // TODO: why no EOF??
-                int outputNum = 0;
-                while (outputNum < MESSAGE_LENGTH) {
-                    outputNum += sc2.read(output);
-                }
+                ByteBuffer output = ByteBuffer.allocate(MESSAGE_LENGTH + 1);
+                while (sc2.read(output) != -1);
+                output.flip();
 
-                String result = new String(output.array(), UTF_8);
+                String result = UTF_8.decode(output).toString();
                 if (!result.equals(MESSAGE)) {
                     String msg = format("Expected [%s], received [%s]", MESSAGE, result);
                     throw new RuntimeException("Test Failed! " + msg);
                 }
                 sc2.shutdownInput();
+                assertInputShutdown(sc2);
                 sc2.shutdownOutput();
+                assertOutputShutdown(sc2);
             }
             t.join();
             out.printf("passed%n");
         }
     }

@@ -129,45 +131,118 @@
                     try (SocketChannel sc1 = RdmaSockets.openSocketChannel(family);
                          Socket client = sc1.socket()) {
                         client.connect(new InetSocketAddress(iaddr, port));
                         if (!client.isConnected())
                             throw new RuntimeException("Test Failed!");
+                        InputStream is = client.getInputStream();
                         OutputStream os = client.getOutputStream();
                         os.write(MESSAGE.getBytes(UTF_8));
+
                         client.shutdownInput();
+                        assertInputShutdown(sc1);
+                        assertInputShutdown(client, is);
                         client.shutdownOutput();
-                        if (!client.isInputShutdown()) {
-                            throw new RuntimeException("Unexpected open input:" + client);
-                        }
-                        if (!client.isOutputShutdown()) {
-                            throw new RuntimeException("Unexpected open output:" + client);
-                        }
+                        assertOutputShutdown(sc1);
+                        assertOutputShutdown(client, os);
                     } catch (IOException e) {
                         throw new UncheckedIOException(e);
                     }
                 }
             };
             t.start();
 
             try (Socket conn = ssc.socket().accept()) {
                 InputStream is = conn.getInputStream();
+                OutputStream os = conn.getOutputStream();
                 byte[] buf = is.readAllBytes();
 
                 String result = new String(buf, UTF_8);
                 if (!result.equals(MESSAGE)) {
                     String msg = format("Expected [%s], received [%s]", MESSAGE, result);
                     throw new RuntimeException("Test Failed! " + msg);
                 }
                 conn.shutdownInput();
+                assertInputShutdown(conn, is);
                 conn.shutdownOutput();
-                if (!conn.isInputShutdown()) {
-                    throw new RuntimeException("Unexpected open input: " + conn);
-                }
-                if (!conn.isOutputShutdown()) {
-                    throw new RuntimeException("Unexpected open output: " + conn);
-                }
+                assertOutputShutdown(conn, os);
             }
             t.join();
             out.printf("passed%n");
         }
     }
+
+    /**
+     * Asserts that the input side of {@code sc} has been shut down.
+     * A one-byte read must report end-of-stream (-1). Obtaining the socket
+     * adaptor's input stream must fail with an IOException whose message
+     * mentions both "input" and "shutdown".
+     *
+     * @param sc the channel whose input is expected to be shut down
+     * @throws RuntimeException if any of the expectations above is not met
+     * @throws IOException if the read itself fails
+     */
+    static void assertInputShutdown(SocketChannel sc) throws IOException {
+        ByteBuffer bb = ByteBuffer.allocate(1);
+        int r = sc.read(bb);
+        if (r != -1)
+            throw new RuntimeException(format("Unexpected read of %d bytes", r));
+
+        try {
+            sc.socket().getInputStream();
+            // No exception means the shutdown was not detected; fail loudly
+            // instead of falling through and passing silently.
+            throw new RuntimeException("Expected IOException from getInputStream: " + sc);
+        } catch (IOException expected) {
+            String msg = expected.getMessage();
+            // Guard against a null message so the check fails with a
+            // diagnostic rather than a NullPointerException.
+            if (msg == null || !msg.contains("input"))
+                throw new RuntimeException("Expected to find \"input\" in " + expected, expected);
+            if (!msg.contains("shutdown"))
+                throw new RuntimeException("Expected to find \"shutdown\" in " + expected, expected);
+        }
+    }
+
+    /**
+     * Asserts that the output side of {@code sc} has been shut down.
+     * Writing one byte must fail with ClosedChannelException. Obtaining the
+     * socket adaptor's output stream must fail with an IOException whose
+     * message mentions both "output" and "shutdown".
+     *
+     * @param sc the channel whose output is expected to be shut down
+     * @throws RuntimeException if any of the expectations above is not met
+     * @throws IOException if the write fails with an IOException other than
+     *         ClosedChannelException
+     */
+    static void assertOutputShutdown(SocketChannel sc) throws IOException {
+        ByteBuffer bb = ByteBuffer.allocate(1);
+        bb.put((byte)0x05);
+        bb.flip();
+        try {
+            sc.write(bb);
+            throw new RuntimeException("Unexpected write of bytes");
+        } catch (ClosedChannelException expected) {
+            // Expected: a write after shutdownOutput must be refused.
+        }
+
+        try {
+            sc.socket().getOutputStream();
+            // No exception means the shutdown was not detected; fail loudly
+            // instead of falling through and passing silently.
+            throw new RuntimeException("Expected IOException from getOutputStream: " + sc);
+        } catch (IOException expected) {
+            String msg = expected.getMessage();
+            // Guard against a null message so the check fails with a
+            // diagnostic rather than a NullPointerException.
+            if (msg == null || !msg.contains("output"))
+                throw new RuntimeException("Expected to find \"output\" in " + expected, expected);
+            if (!msg.contains("shutdown"))
+                throw new RuntimeException("Expected to find \"shutdown\" in " + expected, expected);
+        }
+    }
+
+    /**
+     * Asserts that the input side of socket {@code s} has been shut down.
+     * {@code s.isInputShutdown()} must be true, and a read from {@code is}
+     * must return -1. {@code s.getInputStream()} must fail with an
+     * IOException whose message mentions both "input" and "shutdown".
+     *
+     * @param s  the socket whose input is expected to be shut down
+     * @param is an input stream of {@code s}, obtained before the shutdown
+     *           (as the callers in this test do)
+     * @throws RuntimeException if any of the expectations above is not met
+     * @throws IOException if the read itself fails
+     */
+    static void assertInputShutdown(Socket s, InputStream is) throws IOException {
+        if (!s.isInputShutdown()) {
+            throw new RuntimeException("Unexpected open input: " + s);
+        }
+        int r;
+        if ((r = is.read()) != -1)
+            throw new RuntimeException(format("Unexpected read of %d", r));
+
+        try {
+            s.getInputStream();
+            // No exception means the shutdown was not detected; fail loudly
+            // instead of falling through and passing silently.
+            throw new RuntimeException("Expected IOException from getInputStream: " + s);
+        } catch (IOException expected) {
+            String msg = expected.getMessage();
+            // Guard against a null message so the check fails with a
+            // diagnostic rather than a NullPointerException.
+            if (msg == null || !msg.contains("input"))
+                throw new RuntimeException("Expected to find \"input\" in " + expected, expected);
+            if (!msg.contains("shutdown"))
+                throw new RuntimeException("Expected to find \"shutdown\" in " + expected, expected);
+        }
+    }
+
+    /**
+     * Asserts that the output side of socket {@code s} has been shut down.
+     * {@code s.isOutputShutdown()} must be true, and writing one byte to
+     * {@code os} must fail with ClosedChannelException.
+     * {@code s.getOutputStream()} must fail with an IOException whose
+     * message mentions both "output" and "shutdown".
+     *
+     * @param s  the socket whose output is expected to be shut down
+     * @param os an output stream of {@code s}, obtained before the shutdown
+     *           (as the callers in this test do)
+     * @throws RuntimeException if any of the expectations above is not met
+     * @throws IOException if the write fails with an IOException other than
+     *         ClosedChannelException
+     */
+    static void assertOutputShutdown(Socket s, OutputStream os) throws IOException {
+        if (!s.isOutputShutdown()) {
+            throw new RuntimeException("Unexpected open output: " + s);
+        }
+        try {
+            os.write((byte)0x07);
+            throw new RuntimeException("Unexpected write of bytes");
+        } catch (ClosedChannelException expected) {
+            // Expected: this test requires the stream's write to be refused
+            // with ClosedChannelException once output is shut down.
+        }
+
+        try {
+            s.getOutputStream();
+            // No exception means the shutdown was not detected; fail loudly
+            // instead of falling through and passing silently.
+            throw new RuntimeException("Expected IOException from getOutputStream: " + s);
+        } catch (IOException expected) {
+            String msg = expected.getMessage();
+            // Guard against a null message so the check fails with a
+            // diagnostic rather than a NullPointerException.
+            if (msg == null || !msg.contains("output"))
+                throw new RuntimeException("Expected to find \"output\" in " + expected, expected);
+            if (!msg.contains("shutdown"))
+                throw new RuntimeException("Expected to find \"shutdown\" in " + expected, expected);
+        }
+    }
 }
< prev index next >