/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation. Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
*/

package jdk.internal.net.rdma;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketOption;
import java.net.SocketTimeoutException;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.IllegalBlockingModeException;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import sun.nio.ch.ChannelInputStream;
import sun.nio.ch.Net;

import static java.util.concurrent.TimeUnit.*;

/**
 * Presents an {@link RdmaSocketChannelImpl} through the {@link java.net.Socket} API.
 *
 * <p>The overridden methods delegate to the underlying channel. The blocking
 * operations are {@link #connect(SocketAddress, int)} and reads from
 * {@link #getInputStream()}. They run while holding the channel's blocking
 * lock, throw {@link IllegalBlockingModeException} if the channel is in
 * non-blocking mode, and implement a millisecond timeout by polling the channel.
 *
 * <p>NOTE(review): {@code Socket} methods that are not overridden here (for
 * example {@code setKeepAlive} or {@code setSoLinger}) fall through to the
 * superclass, which was constructed with a {@code null} {@code SocketImpl}.
 * Confirm that they are not expected to work on RDMA sockets.
 */
class RdmaSocketAdaptor
    extends Socket
{
    // The channel being adapted
    private final RdmaSocketChannelImpl sc;

    // Timeout "option" value for reads, in milliseconds; 0 means no timeout
    private volatile int timeout;

    private RdmaSocketAdaptor(RdmaSocketChannelImpl sc) throws SocketException {
        // The cast selects the Socket(SocketImpl) constructor. No impl is
        // supplied because the overridden operations use the channel instead.
        super((RdmaSocketImpl) null);
        this.sc = sc;
    }

    /**
     * Creates a socket adaptor for the given channel.
     *
     * @param sc the channel to adapt
     * @return a {@code Socket} view of {@code sc}
     */
    public static Socket create(RdmaSocketChannelImpl sc) {
        try {
            return new RdmaSocketAdaptor(sc);
        } catch (SocketException e) {
            // Not expected; keep the cause so an unexpected failure is diagnosable.
            throw new InternalError("Should not reach here", e);
        }
    }

    /**
     * Converts the time left in a timed operation to the millisecond timeout
     * passed to the channel's poll methods.
     *
     * <p>A plain millisecond conversion truncates, so a remainder below 1 ms
     * would become a 0 ms poll. Depending on the channel's poll semantics,
     * that either busy-spins until the deadline or blocks indefinitely. The
     * result is therefore clamped to at least 1 ms.
     *
     * @param remainingNanos time left before the deadline; expected to be positive
     * @return the poll timeout in milliseconds, never less than 1
     */
    private static long pollTimeoutMillis(long remainingNanos) {
        return Math.max(1L, MILLISECONDS.convert(remainingNanos, NANOSECONDS));
    }

    /** Returns the channel this socket adapts. */
    @Override
    public SocketChannel getChannel() {
        return sc;
    }

    // Override this method just to protect against changes in the superclass
    @Override
    public void connect(SocketAddress remote) throws IOException {
        connect(remote, 0);
    }

    /**
     * Connects the channel to {@code remote}.
     *
     * <p>With a timeout of 0 the channel's blocking connect is used directly.
     * Otherwise the connection is started in non-blocking mode and the channel
     * is polled until it connects or the timeout expires. On expiry the channel
     * is closed and {@link SocketTimeoutException} is thrown.
     *
     * @param remote  the address to connect to; must not be null
     * @param timeout the timeout in milliseconds; 0 means no timeout
     * @throws IllegalArgumentException if {@code remote} is null or
     *         {@code timeout} is negative
     * @throws IllegalBlockingModeException if the channel is non-blocking
     * @throws IOException if the connect fails (exceptions from the channel
     *         are passed through {@code Net.translateException})
     */
    @Override
    public void connect(SocketAddress remote, int timeout) throws IOException {
        if (remote == null)
            throw new IllegalArgumentException(
                "connect: The address can't be null");
        if (timeout < 0)
            throw new IllegalArgumentException(
                "connect: timeout can't be negative");

        synchronized (sc.blockingLock()) {
            if (!sc.isBlocking())
                throw new IllegalBlockingModeException();

            try {
                // no timeout
                if (timeout == 0) {
                    sc.connect(remote);
                    return;
                }

                // timed connect: initiate the connection without blocking
                sc.configureBlocking(false);
                try {
                    if (sc.connect(remote))
                        return;
                } finally {
                    try {
                        sc.configureBlocking(true);
                    } catch (ClosedChannelException ignored) {
                        // Deliberately ignored: a closed channel has no
                        // blocking mode left to restore.
                    }
                }

                long timeoutNanos = NANOSECONDS.convert(timeout, MILLISECONDS);
                long to = timeout;
                for (;;) {
                    long startTime = System.nanoTime();
                    if (sc.pollConnected(to)) {
                        boolean connected = sc.finishConnect();
                        assert connected;
                        break;
                    }
                    timeoutNanos -= System.nanoTime() - startTime;
                    if (timeoutNanos <= 0) {
                        try {
                            sc.close();
                        } catch (IOException ignored) {
                            // Best effort: the timeout is the error reported.
                        }
                        throw new SocketTimeoutException();
                    }
                    to = pollTimeoutMillis(timeoutNanos);
                }

            } catch (Exception x) {
                Net.translateException(x, true);
            }
        }
    }

    /** Binds the channel to {@code local}, translating channel exceptions. */
    @Override
    public void bind(SocketAddress local) throws IOException {
        try {
            sc.bind(local);
        } catch (Exception x) {
            Net.translateException(x);
        }
    }

    /** Returns the remote address, or null if the channel has none. */
    @Override
    public InetAddress getInetAddress() {
        InetSocketAddress remote = sc.remoteAddress();
        if (remote == null) {
            return null;
        } else {
            return remote.getAddress();
        }
    }

    /**
     * Returns the local address as given by {@code Net.getRevealedLocalAddress}.
     * Returns the wildcard address if the channel is closed or unbound.
     */
    @Override
    public InetAddress getLocalAddress() {
        if (sc.isOpen()) {
            InetSocketAddress local = sc.localAddress();
            if (local != null) {
                return Net.getRevealedLocalAddress(local).getAddress();
            }
        }
        return new InetSocketAddress(0).getAddress();
    }

    /** Returns the remote port, or 0 if the channel has no remote address. */
    @Override
    public int getPort() {
        InetSocketAddress remote = sc.remoteAddress();
        if (remote == null) {
            return 0;
        } else {
            return remote.getPort();
        }
    }

    /** Returns the local port, or -1 if the channel has no local address. */
    @Override
    public int getLocalPort() {
        InetSocketAddress local = sc.localAddress();
        if (local == null) {
            return -1;
        } else {
            return local.getPort();
        }
    }

    /**
     * Input stream over the channel that applies this socket's read timeout
     * ({@code RdmaSocketAdaptor.this.timeout}) to each read.
     */
    private class SocketInputStream
        extends ChannelInputStream
    {
        private SocketInputStream() {
            super(sc);
        }

        protected int read(ByteBuffer bb)
            throws IOException
        {
            synchronized (sc.blockingLock()) {
                if (!sc.isBlocking())
                    throw new IllegalBlockingModeException();

                // no timeout
                long to = RdmaSocketAdaptor.this.timeout;
                if (to == 0)
                    return sc.read(bb);

                // timed read: poll for readability until the deadline passes
                long timeoutNanos = NANOSECONDS.convert(to, MILLISECONDS);
                for (;;) {
                    long startTime = System.nanoTime();
                    if (sc.pollRead(to)) {
                        return sc.read(bb);
                    }
                    timeoutNanos -= System.nanoTime() - startTime;
                    if (timeoutNanos <= 0)
                        throw new SocketTimeoutException();
                    to = pollTimeoutMillis(timeoutNanos);
                }
            }
        }
    }

    // Created lazily by getInputStream and then reused.
    // NOTE(review): this initialization is not synchronized, so concurrent
    // first calls may each create a stream.
    private InputStream socketInputStream = null;

    /**
     * Returns the (cached) input stream for this socket.
     *
     * @throws SocketException if the channel is closed, not connected, or
     *         its input has been shut down
     */
    @Override
    public InputStream getInputStream() throws IOException {
        if (!sc.isOpen())
            throw new SocketException("Socket is closed");
        if (!sc.isConnected())
            throw new SocketException("Socket is not connected");
        if (!sc.isInputOpen())
            throw new SocketException("Socket input is shutdown");
        if (socketInputStream == null) {
            try {
                socketInputStream = AccessController.doPrivileged(
                    new PrivilegedExceptionAction<InputStream>() {
                        public InputStream run() throws IOException {
                            return new SocketInputStream();
                        }
                    });
            } catch (java.security.PrivilegedActionException e) {
                // run() only throws IOException, so the cast is safe.
                throw (IOException)e.getException();
            }
        }
        return socketInputStream;
    }

    /**
     * Returns a new output stream over the channel; a fresh stream is
     * created on every call.
     *
     * @throws SocketException if the channel is closed, not connected, or
     *         its output has been shut down
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        if (!sc.isOpen())
            throw new SocketException("Socket is closed");
        if (!sc.isConnected())
            throw new SocketException("Socket is not connected");
        if (!sc.isOutputOpen())
            throw new SocketException("Socket output is shutdown");
        try {
            return AccessController.doPrivileged(
                new PrivilegedExceptionAction<OutputStream>() {
                    public OutputStream run() throws IOException {
                        return Channels.newOutputStream(sc);
                    }
                });
        } catch (java.security.PrivilegedActionException e) {
            // run() only throws IOException, so the cast is safe.
            throw (IOException)e.getException();
        }
    }

    private void setBooleanOption(SocketOption<Boolean> name, boolean value)
        throws SocketException
    {
        try {
            sc.setOption(name, value);
        } catch (IOException x) {
            Net.translateToSocketException(x);
        }
    }

    private void setIntOption(SocketOption<Integer> name, int value)
        throws SocketException
    {
        try {
            sc.setOption(name, value);
        } catch (IOException x) {
            Net.translateToSocketException(x);
        }
    }

    private boolean getBooleanOption(SocketOption<Boolean> name)
        throws SocketException
    {
        try {
            return sc.getOption(name).booleanValue();
        } catch (IOException x) {
            Net.translateToSocketException(x);
            return false;       // keep compiler happy
        }
    }

    private int getIntOption(SocketOption<Integer> name)
        throws SocketException
    {
        try {
            return sc.getOption(name).intValue();
        } catch (IOException x) {
            Net.translateToSocketException(x);
            return -1;          // keep compiler happy
        }
    }

    @Override
    public void setTcpNoDelay(boolean on) throws SocketException {
        setBooleanOption(StandardSocketOptions.TCP_NODELAY, on);
    }

    @Override
    public boolean getTcpNoDelay() throws SocketException {
        return getBooleanOption(StandardSocketOptions.TCP_NODELAY);
    }

    /**
     * Sends the low byte of {@code data} as out-of-band data on the channel.
     *
     * @throws IOException if the channel reports that no byte was sent
     */
    @Override
    public void sendUrgentData(int data) throws IOException {
        int n = sc.sendOutOfBandData((byte) data);
        if (n == 0)
            throw new IOException("Socket buffer full");
    }

    /**
     * Sets the read timeout in milliseconds; 0 disables the timeout.
     *
     * @throws IllegalArgumentException if {@code timeout} is negative
     */
    @Override
    public void setSoTimeout(int timeout) throws SocketException {
        if (timeout < 0)
            throw new IllegalArgumentException("timeout can't be negative");
        this.timeout = timeout;
    }

    @Override
    public int getSoTimeout() throws SocketException {
        return timeout;
    }

    @Override
    public void setSendBufferSize(int size) throws SocketException {
        // size 0 is invalid
        if (size <= 0)
            throw new IllegalArgumentException("Invalid send size");
        setIntOption(StandardSocketOptions.SO_SNDBUF, size);
    }

    @Override
    public int getSendBufferSize() throws SocketException {
        return getIntOption(StandardSocketOptions.SO_SNDBUF);
    }

    @Override
    public void setReceiveBufferSize(int size) throws SocketException {
        // size 0 is invalid
        if (size <= 0)
            throw new IllegalArgumentException("Invalid receive size");
        setIntOption(StandardSocketOptions.SO_RCVBUF, size);
    }

    @Override
    public int getReceiveBufferSize() throws SocketException {
        return getIntOption(StandardSocketOptions.SO_RCVBUF);
    }

    @Override
    public void setReuseAddress(boolean on) throws SocketException {
        setBooleanOption(StandardSocketOptions.SO_REUSEADDR, on);
    }

    @Override
    public boolean getReuseAddress() throws SocketException {
        return getBooleanOption(StandardSocketOptions.SO_REUSEADDR);
    }

    /** Closes the underlying channel. */
    @Override
    public void close() throws IOException {
        sc.close();
    }

    @Override
    public void shutdownInput() throws IOException {
        try {
            sc.shutdownInput();
        } catch (Exception x) {
            Net.translateException(x);
        }
    }

    @Override
    public void shutdownOutput() throws IOException {
        try {
            sc.shutdownOutput();
        } catch (Exception x) {
            Net.translateException(x);
        }
    }

    @Override
    public String toString() {
        if (sc.isConnected())
            return "RdmaSocket[addr=" + getInetAddress() +
                ",port=" + getPort() +
                ",localport=" + getLocalPort() + "]";
        return "RdmaSocket[unconnected]";
    }

    @Override
    public boolean isConnected() {
        return sc.isConnected();
    }

    @Override
    public boolean isBound() {
        return sc.localAddress() != null;
    }

    @Override
    public boolean isClosed() {
        return !sc.isOpen();
    }

    @Override
    public boolean isInputShutdown() {
        return !sc.isInputOpen();
    }

    @Override
    public boolean isOutputShutdown() {
        return !sc.isOutputOpen();
    }
}