1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /**
  25  * @test
  26  * @run testng ConnectionReset
  27  * @summary Test behavior of SocketChannel.read and the Socket adaptor read
  28  *          and available methods when a connection is reset
  29  */
  30 
  31 import java.io.InputStream;
  32 import java.io.IOException;
  33 import java.net.InetAddress;
  34 import java.net.InetSocketAddress;
  35 import java.net.ServerSocket;
  36 import java.net.Socket;
  37 import java.nio.ByteBuffer;
  38 import java.nio.channels.SocketChannel;
  39 import java.lang.reflect.Method;
  40 
  41 import org.testng.annotations.Test;
  42 import static org.testng.Assert.*;
  43 
  44 @Test
  45 public class ConnectionReset {
  46 
  47     static final int REPEAT_COUNT = 5;
  48 
  49     /**
  50      * Tests SocketChannel.read when the connection is reset and there are no
  51      * bytes to read.
  52      */
  53     public void testSocketChannelReadNoData() throws IOException {
  54         System.out.println("testSocketChannelReadNoData");
  55         withResetConnection(null, sc -> {
  56             ByteBuffer bb = ByteBuffer.allocate(100);
  57             for (int i=0; i<REPEAT_COUNT; i++) {
  58                 try {
  59                     sc.read(bb);
  60                     assertTrue(false);
  61                 } catch (IOException ioe) {
  62                     System.out.format("read => %s (expected)%n", ioe);
  63                 }
  64             }
  65         });
  66     }
  67 
  68     /**
  69      * Tests SocketChannel.read when the connection is reset and there are bytes
  70      * to read.
  71      */
  72     public void testSocketChannelReadData() throws IOException {
  73         System.out.println("testSocketChannelReadData");
  74         byte[] data = { 1, 2, 3 };
  75         withResetConnection(data, sc -> {
  76             int remaining = data.length;
  77             ByteBuffer bb = ByteBuffer.allocate(remaining + 100);
  78             for (int i=0; i<REPEAT_COUNT; i++) {
  79                 try {
  80                     int bytesRead = sc.read(bb);
  81                     if (bytesRead == -1) {
  82                         System.out.println("read => EOF");
  83                     } else {
  84                         System.out.println("read => " + bytesRead + " byte(s)");
  85                     }
  86                     assertTrue(bytesRead > 0);
  87                     remaining -= bytesRead;
  88                     assertTrue(remaining >= 0);
  89                 } catch (IOException ioe) {
  90                     System.out.format("read => %s%n", ioe);
  91                     remaining = 0;
  92                 }
  93             }
  94         });
  95     }
  96 
  97 
  98     /**
  99      * Tests available before Socket read when the connection is reset and there
 100      * are no bytes to read.
 101      */
 102     public void testAvailableBeforeSocketReadNoData() throws IOException {
 103         System.out.println("testAvailableBeforeSocketReadNoData");
 104         withResetConnection(null, sc -> {
 105             Socket s = sc.socket();
 106             InputStream in = s.getInputStream();
 107             for (int i=0; i<REPEAT_COUNT; i++) {
 108                 int bytesAvailable = in.available();
 109                 System.out.format("available => %d%n", bytesAvailable);
 110                 assertTrue(bytesAvailable == 0);
 111                 try {
 112                     int bytesRead = in.read();
 113                     if (bytesRead == -1) {
 114                         System.out.println("read => EOF");
 115                     } else {
 116                         System.out.println("read => 1 byte");
 117                     }
 118                     assertTrue(false);
 119                 } catch (IOException ioe) {
 120                     System.out.format("read => %s (expected)%n", ioe);
 121                 }
 122             }
 123         });
 124     }
 125 
 126     /**
 127      * Tests available before Socket read when the connection is reset and there
 128      * are bytes to read.
 129      */
 130     public void testAvailableBeforeSocketReadData() throws IOException {
 131         System.out.println("testAvailableBeforeSocketReadData");
 132         byte[] data = { 1, 2, 3 };
 133         withResetConnection(data, sc -> {
 134             Socket s = sc.socket();
 135             InputStream in = s.getInputStream();
 136             int remaining = data.length;
 137             for (int i=0; i<REPEAT_COUNT; i++) {
 138                 int bytesAvailable = in.available();
 139                 System.out.format("available => %d%n", bytesAvailable);
 140                 assertTrue(bytesAvailable <= remaining);
 141                 try {
 142                     int bytesRead = in.read();
 143                     if (bytesRead == -1) {
 144                         System.out.println("read => EOF");
 145                         assertTrue(false);
 146                     } else {
 147                         System.out.println("read => 1 byte");
 148                         assertTrue(remaining > 0);
 149                         remaining--;
 150                     }
 151                 } catch (IOException ioe) {
 152                     System.out.format("read => %s%n", ioe);
 153                     remaining = 0;
 154                 }
 155             }
 156         });
 157     }
 158 
 159     /**
 160      * Tests Socket read before available when the connection is reset and there
 161      * are no bytes to read.
 162      */
 163     public void testSocketReadNoDataBeforeAvailable() throws IOException {
 164         System.out.println("testSocketReadNoDataBeforeAvailable");
 165         withResetConnection(null, sc -> {
 166             Socket s = sc.socket();
 167             InputStream in = s.getInputStream();
 168             for (int i=0; i<REPEAT_COUNT; i++) {
 169                 try {
 170                     int bytesRead = in.read();
 171                     if (bytesRead == -1) {
 172                         System.out.println("read => EOF");
 173                     } else {
 174                         System.out.println("read => 1 byte");
 175                     }
 176                     assertTrue(false);
 177                 } catch (IOException ioe) {
 178                     System.out.format("read => %s (expected)%n", ioe);
 179                 }
 180                 int bytesAvailable = in.available();
 181                 System.out.format("available => %d%n", bytesAvailable);
 182                 assertTrue(bytesAvailable == 0);
 183             }
 184         });
 185     }
 186 
 187     /**
 188      * Tests Socket read before available when the connection is reset and there
 189      * are bytes to read.
 190      */
 191     public void testSocketReadDataBeforeAvailable() throws IOException {
 192         System.out.println("testSocketReadDataBeforeAvailable");
 193         byte[] data = { 1, 2, 3 };
 194         withResetConnection(data, sc -> {
 195             Socket s = sc.socket();
 196             InputStream in = s.getInputStream();
 197             int remaining = data.length;
 198             for (int i=0; i<REPEAT_COUNT; i++) {
 199                 try {
 200                     int bytesRead = in.read();
 201                     if (bytesRead == -1) {
 202                         System.out.println("read => EOF");
 203                         assertTrue(false);
 204                     } else {
 205                         System.out.println("read => 1 byte");
 206                         assertTrue(remaining > 0);
 207                         remaining--;
 208                     }
 209                 } catch (IOException ioe) {
 210                     System.out.format("read => %s%n", ioe);
 211                     remaining = 0;
 212                 }
 213                 int bytesAvailable = in.available();
 214                 System.out.format("available => %d%n", bytesAvailable);
 215                 assertTrue(bytesAvailable <= remaining);
 216             }
 217         });
 218     }
 219 
 220     interface ThrowingConsumer<T> {
 221         void accept(T t) throws IOException;
 222     }
 223 
 224     /**
 225      * Invokes a consumer with a SocketChannel connected to a peer that has closed
 226      * the connection with a "connection reset". The peer sends the given data
 227      * bytes before closing (when data is not null).
 228      */
 229     static void withResetConnection(byte[] data, ThrowingConsumer<SocketChannel> consumer)
 230         throws IOException
 231     {
 232         var loopback = InetAddress.getLoopbackAddress();
 233         try (var listener = new ServerSocket()) {
 234             listener.bind(new InetSocketAddress(loopback, 0));
 235             try (var sc = SocketChannel.open()) {
 236                 sc.connect(listener.getLocalSocketAddress());
 237                 try (Socket peer = listener.accept()) {
 238                     if (data != null) {
 239                         peer.getOutputStream().write(data);
 240                     }
 241                     peer.setSoLinger(true, 0);
 242                 }
 243                 consumer.accept(sc);
 244             }
 245         }
 246     }
 247 }