1 /*
   2  * Copyright (c) 2000, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package rdma.ch;
  27 
  28 import java.io.IOException;
  29 import java.io.InputStream;
  30 import java.io.OutputStream;
  31 import java.net.InetAddress;
  32 import java.net.InetSocketAddress;
  33 import java.net.Socket;
  34 import java.net.SocketAddress;
  35 import java.net.SocketException;
  36 import java.net.SocketOption;
  37 import java.net.SocketTimeoutException;
  38 import java.net.StandardSocketOptions;
  39 import java.nio.ByteBuffer;
  40 import java.nio.channels.Channels;
  41 import java.nio.channels.ClosedChannelException;
  42 import java.nio.channels.IllegalBlockingModeException;
  43 import java.nio.channels.SocketChannel;
  44 import java.security.AccessController;
  45 import java.security.PrivilegedExceptionAction;
  46 import sun.nio.ch.ChannelInputStream;
  47 import sun.nio.ch.ExtendedSocketOption;
  48 
  49 import static java.util.concurrent.TimeUnit.*;
  50 
  51 class RdmaSocketAdaptor
  52     extends Socket
  53 {
  54     // The channel being adapted
  55     private final RdmaSocketChannelImpl sc;
  56 
  57     // Timeout "option" value for reads
  58     private volatile int timeout;
  59 
  60     private RdmaSocketAdaptor(RdmaSocketChannelImpl sc) throws SocketException {
  61         super((RdmaSocketImpl) null);
  62         this.sc = sc;
  63     }
  64 
  65     public static Socket create(RdmaSocketChannelImpl sc) {
  66         try {
  67             return new RdmaSocketAdaptor(sc);
  68         } catch (SocketException e) {
  69             throw new InternalError("Should not reach here");
  70         }
  71     }
  72 
  73     public SocketChannel getChannel() {
  74         return sc;
  75     }
  76 
  77     // Override this method just to protect against changes in the superclass
  78     //
  79     public void connect(SocketAddress remote) throws IOException {
  80         connect(remote, 0);
  81     }
  82 
  83     public void connect(SocketAddress remote, int timeout) throws IOException {
  84         if (remote == null)
  85             throw new IllegalArgumentException("connect: The address can't be null");
  86         if (timeout < 0)
  87             throw new IllegalArgumentException("connect: timeout can't be negative");
  88 
  89         synchronized (sc.blockingLock()) {
  90             if (!sc.isBlocking())
  91                 throw new IllegalBlockingModeException();
  92 
  93             try {
  94                 if (timeout == 0) {
  95                     sc.connect(remote);
  96                     return;
  97                 }
  98 
  99                 sc.configureBlocking(false);
 100                 try {
 101                     if (sc.connect(remote))
 102                         return;
 103                 } finally {
 104                     try {
 105                         sc.configureBlocking(true);
 106                     } catch (ClosedChannelException e) { }
 107                 }
 108 
 109                 long timeoutNanos = NANOSECONDS.convert(timeout, MILLISECONDS);
 110                 long to = timeout;
 111                 for (;;) {
 112                     long startTime = System.nanoTime();
 113                     if (sc.pollConnected(to)) {
 114                         boolean connected = sc.finishConnect();
 115                         assert connected;
 116                         break;
 117                     }
 118                     timeoutNanos -= System.nanoTime() - startTime;
 119                     if (timeoutNanos <= 0) {
 120                         try {
 121                             sc.close();
 122                         } catch (IOException x) { }
 123                         throw new SocketTimeoutException();
 124                     }
 125                     to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
 126                 }
 127 
 128             } catch (Exception x) {
 129                 RdmaNet.translateException(x, true);
 130             }
 131         }
 132 
 133     }
 134 
 135     public void bind(SocketAddress local) throws IOException {
 136         try {
 137             sc.bind(local);
 138         } catch (Exception x) {
 139             RdmaNet.translateException(x);
 140         }
 141     }
 142 
 143     public InetAddress getInetAddress() {
 144         InetSocketAddress remote = sc.remoteAddress();
 145         if (remote == null) {
 146             return null;
 147         } else {
 148             return remote.getAddress();
 149         }
 150     }
 151 
 152     public InetAddress getLocalAddress() {
 153         if (sc.isOpen()) {
 154             InetSocketAddress local = sc.localAddress();
 155             if (local != null) {
 156                 return RdmaNet.getRevealedLocalAddress(local).getAddress();
 157             }
 158         }
 159         return new InetSocketAddress(0).getAddress();
 160     }
 161 
 162     public int getPort() {
 163         InetSocketAddress remote = sc.remoteAddress();
 164         if (remote == null) {
 165             return 0;
 166         } else {
 167             return remote.getPort();
 168         }
 169     }
 170 
 171     public int getLocalPort() {
 172         InetSocketAddress local = sc.localAddress();
 173         if (local == null) {
 174             return -1;
 175         } else {
 176             return local.getPort();
 177         }
 178     }
 179 
 180     private class SocketInputStream
 181         extends ChannelInputStream
 182     {
 183         private SocketInputStream() {
 184             super(sc);
 185         }
 186 
 187         protected int read(ByteBuffer bb)
 188             throws IOException
 189         {
 190             synchronized (sc.blockingLock()) {
 191                 if (!sc.isBlocking())
 192                     throw new IllegalBlockingModeException();
 193 
 194                 // no timeout
 195                 long to = RdmaSocketAdaptor.this.timeout;
 196                 if (to == 0)
 197                     return sc.read(bb);
 198 
 199                 // timed read
 200                 long timeoutNanos = NANOSECONDS.convert(to, MILLISECONDS);
 201                 for (;;) {
 202                     long startTime = System.nanoTime();
 203                     if (sc.pollRead(to)) {
 204                         return sc.read(bb);
 205                     }
 206                     timeoutNanos -= System.nanoTime() - startTime;
 207                     if (timeoutNanos <= 0)
 208                         throw new SocketTimeoutException();
 209                     to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
 210                 }
 211             }
 212         }
 213     }
 214 
 215     private InputStream socketInputStream = null;
 216 
 217     public InputStream getInputStream() throws IOException {
 218         if (!sc.isOpen())
 219             throw new SocketException("Socket is closed");
 220         if (!sc.isConnected())
 221             throw new SocketException("Socket is not connected");
 222         if (!sc.isInputOpen())
 223             throw new SocketException("Socket input is shutdown");
 224         if (socketInputStream == null) {
 225             try {
 226                 socketInputStream = AccessController.doPrivileged(
 227                     new PrivilegedExceptionAction<InputStream>() {
 228                         public InputStream run() throws IOException {
 229                             return new SocketInputStream();
 230                         }
 231                     });
 232             } catch (java.security.PrivilegedActionException e) {
 233                 throw (IOException)e.getException();
 234             }
 235         }
 236         return socketInputStream;
 237     }
 238 
 239     public OutputStream getOutputStream() throws IOException {
 240         if (!sc.isOpen())
 241             throw new SocketException("Socket is closed");
 242         if (!sc.isConnected())
 243             throw new SocketException("Socket is not connected");
 244         if (!sc.isOutputOpen())
 245             throw new SocketException("Socket output is shutdown");
 246         OutputStream os = null;
 247         try {
 248             os = AccessController.doPrivileged(
 249                 new PrivilegedExceptionAction<OutputStream>() {
 250                     public OutputStream run() throws IOException {
 251                         return Channels.newOutputStream(sc);
 252                     }
 253                 });
 254         } catch (java.security.PrivilegedActionException e) {
 255             throw (IOException)e.getException();
 256         }
 257         return os;
 258     }
 259 
 260     private void setBooleanOption(SocketOption<Boolean> name, boolean value)
 261         throws SocketException
 262     {
 263         try {
 264             sc.setOption(name, value);
 265         } catch (IOException x) {
 266             RdmaNet.translateToSocketException(x);
 267         }
 268     }
 269 
 270     private void setIntOption(SocketOption<Integer> name, int value)
 271         throws SocketException
 272     {
 273         try {
 274             sc.setOption(name, value);
 275         } catch (IOException x) {
 276             RdmaNet.translateToSocketException(x);
 277         }
 278     }
 279 
 280     private boolean getBooleanOption(SocketOption<Boolean> name) throws SocketException {
 281         try {
 282             return sc.getOption(name).booleanValue();
 283         } catch (IOException x) {
 284             RdmaNet.translateToSocketException(x);
 285             return false;       // keep compiler happy
 286         }
 287     }
 288 
 289     private int getIntOption(SocketOption<Integer> name) throws SocketException {
 290         try {
 291             return sc.getOption(name).intValue();
 292         } catch (IOException x) {
 293             RdmaNet.translateToSocketException(x);
 294             return -1;          // keep compiler happy
 295         }
 296     }
 297 
 298     public void setTcpNoDelay(boolean on) throws SocketException {
 299         setBooleanOption(StandardSocketOptions.TCP_NODELAY, on);
 300     }
 301 
 302     public boolean getTcpNoDelay() throws SocketException {
 303         return getBooleanOption(StandardSocketOptions.TCP_NODELAY);
 304     }
 305 
 306 
 307     public void setSoLinger(boolean on, int linger) throws SocketException {
 308         if (!on)
 309             linger = -1;
 310         setIntOption(StandardSocketOptions.SO_LINGER, linger);
 311     }
 312 
 313     public int getSoLinger() throws SocketException {
 314         return getIntOption(StandardSocketOptions.SO_LINGER);
 315     }
 316 
 317     public void sendUrgentData(int data) throws IOException {
 318         int n = sc.sendOutOfBandData((byte) data);
 319         if (n == 0)
 320             throw new IOException("Socket buffer full");
 321     }
 322 
 323     public void setSoTimeout(int timeout) throws SocketException {
 324         if (timeout < 0)
 325             throw new IllegalArgumentException("timeout can't be negative");
 326         this.timeout = timeout;
 327     }
 328 
 329     public int getSoTimeout() throws SocketException {
 330         return timeout;
 331     }
 332 
 333     public void setSendBufferSize(int size) throws SocketException {
 334         if (size <= 0)
 335             throw new IllegalArgumentException("Invalid send size");
 336         setIntOption(StandardSocketOptions.SO_SNDBUF, size);
 337     }
 338 
 339     public int getSendBufferSize() throws SocketException {
 340         return getIntOption(StandardSocketOptions.SO_SNDBUF);
 341     }
 342 
 343     public void setReceiveBufferSize(int size) throws SocketException {
 344         if (size <= 0)
 345             throw new IllegalArgumentException("Invalid receive size");
 346         setIntOption(StandardSocketOptions.SO_RCVBUF, size);
 347     }
 348 
 349     public int getReceiveBufferSize() throws SocketException {
 350         return getIntOption(StandardSocketOptions.SO_RCVBUF);
 351     }
 352 
 353     public void setReuseAddress(boolean on) throws SocketException {
 354         setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
 355     }
 356 
 357     public boolean getReuseAddress() throws SocketException {
 358         return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
 359     }
 360 
 361     public void close() throws IOException {
 362         sc.close();
 363     }
 364 
 365     public void shutdownInput() throws IOException {
 366         try {
 367             sc.shutdownInput();
 368         } catch (Exception x) {
 369             RdmaNet.translateException(x);
 370         }
 371     }
 372 
 373     public void shutdownOutput() throws IOException {
 374         try {
 375             sc.shutdownOutput();
 376         } catch (Exception x) {
 377             RdmaNet.translateException(x);
 378         }
 379     }
 380 
 381     public String toString() {
 382         if (sc.isConnected())
 383             return "RdmaSocket[addr=" + getInetAddress() +
 384                 ",port=" + getPort() +
 385                 ",localport=" + getLocalPort() + "]";
 386         return "RdmaSocket[unconnected]";
 387     }
 388 
 389     public boolean isConnected() {
 390         return sc.isConnected();
 391     }
 392 
 393     public boolean isBound() {
 394         return sc.localAddress() != null;
 395     }
 396 
 397     public boolean isClosed() {
 398         return !sc.isOpen();
 399     }
 400 
 401     public boolean isInputShutdown() {
 402         return !sc.isInputOpen();
 403     }
 404 
 405     public boolean isOutputShutdown() {
 406         return !sc.isOutputOpen();
 407     }
 408 }