1 /*
   2  * Copyright (c) 2000, 2018, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package rdma.ch;
  27 
  28 import java.io.IOException;
  29 import java.io.InputStream;
  30 import java.io.OutputStream;
  31 import java.net.InetAddress;
  32 import java.net.InetSocketAddress;
  33 import java.net.Socket;
  34 import java.net.SocketAddress;
  35 import java.net.SocketException;
  36 import java.net.SocketOption;
  37 import java.net.SocketTimeoutException;
  38 import java.net.StandardSocketOptions;
  39 import java.nio.ByteBuffer;
  40 import java.nio.channels.Channels;
  41 import java.nio.channels.ClosedChannelException;
  42 import java.nio.channels.IllegalBlockingModeException;
  43 import java.nio.channels.SocketChannel;
  44 import java.security.AccessController;
  45 import java.security.PrivilegedExceptionAction;
  46 import sun.nio.ch.ChannelInputStream;
  47 
  48 import static java.util.concurrent.TimeUnit.*;
  49 
  50 class RdmaSocketAdaptor
  51     extends Socket
  52 {
  53     // The channel being adapted
  54     private final RdmaSocketChannelImpl sc;
  55 
  56     // Timeout "option" value for reads
  57     private volatile int timeout;
  58 
  59     private RdmaSocketAdaptor(RdmaSocketChannelImpl sc) throws SocketException {
  60         super((RdmaSocketImpl) null);
  61         this.sc = sc;
  62     }
  63 
  64     public static Socket create(RdmaSocketChannelImpl sc) {
  65         try {
  66             return new RdmaSocketAdaptor(sc);
  67         } catch (SocketException e) {
  68             throw new InternalError("Should not reach here");
  69         }
  70     }
  71 
  72     public SocketChannel getChannel() {
  73         return sc;
  74     }
  75 
  76     // Override this method just to protect against changes in the superclass
  77     //
  78     public void connect(SocketAddress remote) throws IOException {
  79         connect(remote, 0);
  80     }
  81 
  82     public void connect(SocketAddress remote, int timeout) throws IOException {
  83         if (remote == null)
  84             throw new IllegalArgumentException("connect: The address can't be null");
  85         if (timeout < 0)
  86             throw new IllegalArgumentException("connect: timeout can't be negative");
  87 
  88         synchronized (sc.blockingLock()) {
  89             if (!sc.isBlocking())
  90                 throw new IllegalBlockingModeException();
  91 
  92             try {
  93                 if (timeout == 0) {
  94                     sc.connect(remote);
  95                     return;
  96                 }
  97 
  98                 sc.configureBlocking(false);
  99                 try {
 100                     if (sc.connect(remote))
 101                         return;
 102                 } finally {
 103                     try {
 104                         sc.configureBlocking(true);
 105                     } catch (ClosedChannelException e) { }
 106                 }
 107 
 108                 long timeoutNanos = NANOSECONDS.convert(timeout, MILLISECONDS);
 109                 long to = timeout;
 110                 for (;;) {
 111                     long startTime = System.nanoTime();
 112                     if (sc.pollConnected(to)) {
 113                         boolean connected = sc.finishConnect();
 114                         assert connected;
 115                         break;
 116                     }
 117                     timeoutNanos -= System.nanoTime() - startTime;
 118                     if (timeoutNanos <= 0) {
 119                         try {
 120                             sc.close();
 121                         } catch (IOException x) { }
 122                         throw new SocketTimeoutException();
 123                     }
 124                     to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
 125                 }
 126 
 127             } catch (Exception x) {
 128                 RdmaNet.translateException(x, true);
 129             }
 130         }
 131 
 132     }
 133 
 134     public void bind(SocketAddress local) throws IOException {
 135         try {
 136             sc.bind(local);
 137         } catch (Exception x) {
 138             RdmaNet.translateException(x);
 139         }
 140     }
 141 
 142     public InetAddress getInetAddress() {
 143         InetSocketAddress remote = sc.remoteAddress();
 144         if (remote == null) {
 145             return null;
 146         } else {
 147             return remote.getAddress();
 148         }
 149     }
 150 
 151     public InetAddress getLocalAddress() {
 152         if (sc.isOpen()) {
 153             InetSocketAddress local = sc.localAddress();
 154             if (local != null) {
 155                 return RdmaNet.getRevealedLocalAddress(local).getAddress();
 156             }
 157         }
 158         return new InetSocketAddress(0).getAddress();
 159     }
 160 
 161     public int getPort() {
 162         InetSocketAddress remote = sc.remoteAddress();
 163         if (remote == null) {
 164             return 0;
 165         } else {
 166             return remote.getPort();
 167         }
 168     }
 169 
 170     public int getLocalPort() {
 171         InetSocketAddress local = sc.localAddress();
 172         if (local == null) {
 173             return -1;
 174         } else {
 175             return local.getPort();
 176         }
 177     }
 178 
 179     private class SocketInputStream
 180         extends ChannelInputStream
 181     {
 182         private SocketInputStream() {
 183             super(sc);
 184         }
 185 
 186         protected int read(ByteBuffer bb)
 187             throws IOException
 188         {
 189             synchronized (sc.blockingLock()) {
 190                 if (!sc.isBlocking())
 191                     throw new IllegalBlockingModeException();
 192 
 193                 // no timeout
 194                 long to = RdmaSocketAdaptor.this.timeout;
 195                 if (to == 0)
 196                     return sc.read(bb);
 197 
 198                 // timed read
 199                 long timeoutNanos = NANOSECONDS.convert(to, MILLISECONDS);
 200                 for (;;) {
 201                     long startTime = System.nanoTime();
 202                     if (sc.pollRead(to)) {
 203                         return sc.read(bb);
 204                     }
 205                     timeoutNanos -= System.nanoTime() - startTime;
 206                     if (timeoutNanos <= 0)
 207                         throw new SocketTimeoutException();
 208                     to = MILLISECONDS.convert(timeoutNanos, NANOSECONDS);
 209                 }
 210             }
 211         }
 212     }
 213 
 214     private InputStream socketInputStream = null;
 215 
 216     public InputStream getInputStream() throws IOException {
 217         if (!sc.isOpen())
 218             throw new SocketException("Socket is closed");
 219         if (!sc.isConnected())
 220             throw new SocketException("Socket is not connected");
 221         if (!sc.isInputOpen())
 222             throw new SocketException("Socket input is shutdown");
 223         if (socketInputStream == null) {
 224             try {
 225                 socketInputStream = AccessController.doPrivileged(
 226                     new PrivilegedExceptionAction<InputStream>() {
 227                         public InputStream run() throws IOException {
 228                             return new SocketInputStream();
 229                         }
 230                     });
 231             } catch (java.security.PrivilegedActionException e) {
 232                 throw (IOException)e.getException();
 233             }
 234         }
 235         return socketInputStream;
 236     }
 237 
 238     public OutputStream getOutputStream() throws IOException {
 239         if (!sc.isOpen())
 240             throw new SocketException("Socket is closed");
 241         if (!sc.isConnected())
 242             throw new SocketException("Socket is not connected");
 243         if (!sc.isOutputOpen())
 244             throw new SocketException("Socket output is shutdown");
 245         OutputStream os = null;
 246         try {
 247             os = AccessController.doPrivileged(
 248                 new PrivilegedExceptionAction<OutputStream>() {
 249                     public OutputStream run() throws IOException {
 250                         return Channels.newOutputStream(sc);
 251                     }
 252                 });
 253         } catch (java.security.PrivilegedActionException e) {
 254             throw (IOException)e.getException();
 255         }
 256         return os;
 257     }
 258 
 259     private void setBooleanOption(SocketOption<Boolean> name, boolean value)
 260         throws SocketException
 261     {
 262         try {
 263             sc.setOption(name, value);
 264         } catch (IOException x) {
 265             RdmaNet.translateToSocketException(x);
 266         }
 267     }
 268 
 269     private void setIntOption(SocketOption<Integer> name, int value)
 270         throws SocketException
 271     {
 272         try {
 273             sc.setOption(name, value);
 274         } catch (IOException x) {
 275             RdmaNet.translateToSocketException(x);
 276         }
 277     }
 278 
 279     private boolean getBooleanOption(SocketOption<Boolean> name) throws SocketException {
 280         try {
 281             return sc.getOption(name).booleanValue();
 282         } catch (IOException x) {
 283             RdmaNet.translateToSocketException(x);
 284             return false;       // keep compiler happy
 285         }
 286     }
 287 
 288     private int getIntOption(SocketOption<Integer> name) throws SocketException {
 289         try {
 290             return sc.getOption(name).intValue();
 291         } catch (IOException x) {
 292             RdmaNet.translateToSocketException(x);
 293             return -1;          // keep compiler happy
 294         }
 295     }
 296 
 297     public void setTcpNoDelay(boolean on) throws SocketException {
 298         setBooleanOption(StandardSocketOptions.TCP_NODELAY, on);
 299     }
 300 
 301     public boolean getTcpNoDelay() throws SocketException {
 302         return getBooleanOption(StandardSocketOptions.TCP_NODELAY);
 303     }
 304 
 305 
 306     public void setSoLinger(boolean on, int linger) throws SocketException {
 307         if (!on)
 308             linger = -1;
 309         setIntOption(StandardSocketOptions.SO_LINGER, linger);
 310     }
 311 
 312     public int getSoLinger() throws SocketException {
 313         return getIntOption(StandardSocketOptions.SO_LINGER);
 314     }
 315 
 316     public void sendUrgentData(int data) throws IOException {
 317         int n = sc.sendOutOfBandData((byte) data);
 318         if (n == 0)
 319             throw new IOException("Socket buffer full");
 320     }
 321 
 322     public void setSoTimeout(int timeout) throws SocketException {
 323         if (timeout < 0)
 324             throw new IllegalArgumentException("timeout can't be negative");
 325         this.timeout = timeout;
 326     }
 327 
 328     public int getSoTimeout() throws SocketException {
 329         return timeout;
 330     }
 331 
 332     public void setSendBufferSize(int size) throws SocketException {
 333         if (size <= 0)
 334             throw new IllegalArgumentException("Invalid send size");
 335         setIntOption(StandardSocketOptions.SO_SNDBUF, size);
 336     }
 337 
 338     public int getSendBufferSize() throws SocketException {
 339         return getIntOption(StandardSocketOptions.SO_SNDBUF);
 340     }
 341 
 342     public void setReceiveBufferSize(int size) throws SocketException {
 343         if (size <= 0)
 344             throw new IllegalArgumentException("Invalid receive size");
 345         setIntOption(StandardSocketOptions.SO_RCVBUF, size);
 346     }
 347 
 348     public int getReceiveBufferSize() throws SocketException {
 349         return getIntOption(StandardSocketOptions.SO_RCVBUF);
 350     }
 351 
 352     public void setReuseAddress(boolean on) throws SocketException {
 353         setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
 354     }
 355 
 356     public boolean getReuseAddress() throws SocketException {
 357         return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
 358     }
 359 
 360     public void close() throws IOException {
 361         sc.close();
 362     }
 363 
 364     public void shutdownInput() throws IOException {
 365         try {
 366             sc.shutdownInput();
 367         } catch (Exception x) {
 368             RdmaNet.translateException(x);
 369         }
 370     }
 371 
 372     public void shutdownOutput() throws IOException {
 373         try {
 374             sc.shutdownOutput();
 375         } catch (Exception x) {
 376             RdmaNet.translateException(x);
 377         }
 378     }
 379 
 380     public String toString() {
 381         if (sc.isConnected())
 382             return "RdmaSocket[addr=" + getInetAddress() +
 383                 ",port=" + getPort() +
 384                 ",localport=" + getLocalPort() + "]";
 385         return "RdmaSocket[unconnected]";
 386     }
 387 
 388     public boolean isConnected() {
 389         return sc.isConnected();
 390     }
 391 
 392     public boolean isBound() {
 393         return sc.localAddress() != null;
 394     }
 395 
 396     public boolean isClosed() {
 397         return !sc.isOpen();
 398     }
 399 
 400     public boolean isInputShutdown() {
 401         return !sc.isInputOpen();
 402     }
 403 
 404     public boolean isOutputShutdown() {
 405         return !sc.isOutputOpen();
 406     }
 407 }