/* * Copyright (c) 1995, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. Oracle designates this * particular file as subject to the "Classpath" exception as provided * by Oracle in the LICENSE file that accompanied this code. * * This code is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License * version 2 for more details (a copy is included in the LICENSE file that * accompanied this code). * * You should have received a copy of the GNU General Public License version * 2 along with this work; if not, write to the Free Software Foundation, * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. * * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA * or visit www.oracle.com if you need additional information or have any * questions. 
*/

package rdma.ch;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.FileDescriptor;
import java.util.Set;
import java.util.HashSet;
import java.util.Collections;
import java.net.Socket;
import java.net.ServerSocket;
import java.net.SocketImpl;
import java.net.SocketOption;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.net.InetAddress;
import java.net.SocketAddress;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.net.SocketOptions;
import java.lang.reflect.Field;
import sun.net.ConnectionResetException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import sun.net.ext.RdmaSocketOptions;

/**
 * A {@link SocketImpl} whose socket operations are delegated to a
 * platform-specific {@link PlatformRdmaSocketImpl}.
 *
 * <p>Because this impl sits behind a {@link Socket} or {@link ServerSocket},
 * it updates their private {@code created}, {@code bound} and
 * {@code connected} fields reflectively as the corresponding operations
 * complete.
 */
public class RdmaSocketImpl extends SocketImpl {
    Socket socket = null;
    ServerSocket serverSocket = null;

    int timeout;   // timeout in millisec
    int trafficClass;

    private boolean shut_rd = false;
    private boolean shut_wr = false;

    private RdmaSocketInputStream socketInputStream = null;
    private RdmaSocketOutputStream socketOutputStream = null;

    /* number of threads using the FileDescriptor */
    protected int fdUseCount = 0;

    /* lock when increment/decrementing fdUseCount */
    protected final Object fdLock = new Object();

    /* indicates a close is pending on the file descriptor */
    protected boolean closePending = false;

    /* indicates connection reset state */
    private static final int CONNECTION_NOT_RESET = 0;
    private static final int CONNECTION_RESET_PENDING = 1;
    private static final int CONNECTION_RESET = 2;
    private int resetState;
    private final Object resetLock = new Object();

    /* whether this Socket is a stream (TCP) socket or not (UDP) */
    protected boolean stream;

    /* largest SO_SNDBUF / SO_RCVBUF value passed down; larger requests are clamped */
    private static final int MAX_BUFFER_SIZE = 1024 * 1024 * 1024 - 1;

    static final RdmaSocketOptions rdmaOptions = RdmaSocketOptions.getInstance();

    private static final PlatformRdmaSocketImpl platformRdmaSocketImpl =
        PlatformRdmaSocketImpl.get();

    /* reflective handles on Socket's private state flags */
    private static Field sCreateState;
    private static Field sBoundState;
    private static Field sConnectState;
    /* reflective handles on ServerSocket's private state flags */
    private static Field ssCreateState;
    private static Field ssBoundState;
    // NOTE(review): these flags are plain (non-volatile) statics written
    // without synchronization; concurrent first calls may race.
    private static boolean socketStateSet;
    private static boolean serverSocketStateSet;

    /** Returns whether the platform implementation reports RDMA as available. */
    boolean isRdmaAvailable() {
        return platformRdmaSocketImpl.isRdmaAvailable();
    }

    /**
     * Records the owning {@link Socket}. On the first call only, this looks up
     * Socket's {@code created}/{@code bound}/{@code connected} fields, makes
     * them accessible and sets them to {@code false} on {@code soc}.
     */
    @Override
    protected void setSocket(Socket soc) {
        this.socket = soc;
        try {
            if (!socketStateSet) {
                sCreateState = Socket.class.getDeclaredField("created");
                sCreateState.setAccessible(true);
                sBoundState = Socket.class.getDeclaredField("bound");
                sBoundState.setAccessible(true);
                sConnectState = Socket.class.getDeclaredField("connected");
                sConnectState.setAccessible(true);
                sCreateState.setBoolean(socket, false);
                sBoundState.setBoolean(socket, false);
                sConnectState.setBoolean(socket, false);
            }
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new Error(e);
        }
        socketStateSet = true;
    }

    /** Returns the owning {@link Socket}, or {@code null} if none was set. */
    Socket getSocket() {
        return socket;
    }

    /**
     * Records the owning {@link ServerSocket}. On the first call only, this
     * looks up ServerSocket's {@code created}/{@code bound} fields, makes them
     * accessible and sets them to {@code false} on {@code soc}.
     */
    @Override
    protected void setServerSocket(ServerSocket soc) {
        this.serverSocket = soc;
        try {
            if (!serverSocketStateSet) {
                ssCreateState = ServerSocket.class.getDeclaredField("created");
                ssCreateState.setAccessible(true);
                ssBoundState = ServerSocket.class.getDeclaredField("bound");
                ssBoundState.setAccessible(true);
                ssCreateState.setBoolean(serverSocket, false);
                ssBoundState.setBoolean(serverSocket, false);
            }
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new Error(e);
        }
        serverSocketStateSet = true;
    }

    /** Returns the owning {@link ServerSocket}, or {@code null} if none was set. */
    ServerSocket getServerSocket() {
        return serverSocket;
    }

    private static final Set<SocketOption<?>> socketOptions;

    private static final Set<SocketOption<?>> serverSocketOptions;

    static {
        socketOptions = Set.of(StandardSocketOptions.SO_SNDBUF,
                               StandardSocketOptions.SO_RCVBUF,
                               StandardSocketOptions.SO_REUSEADDR,
                               StandardSocketOptions.SO_LINGER,
                               StandardSocketOptions.TCP_NODELAY);

        serverSocketOptions = Set.of(StandardSocketOptions.SO_RCVBUF,
                                     StandardSocketOptions.SO_REUSEADDR);
    }

    /**
     * Returns an unmodifiable set of the options supported by this impl.
     * Client-socket or server-socket standard options are chosen by whether a
     * {@link Socket} was set. The RDMA-specific options are added when RDMA is
     * available.
     */
    @Override
    protected Set<SocketOption<?>> supportedOptions() {
        Set<SocketOption<?>> options = new HashSet<>();
        if (socket != null)
            options.addAll(socketOptions);
        else
            options.addAll(serverSocketOptions);

        if (isRdmaAvailable()) {
            options.addAll(rdmaOptions.options());
        }
        return Collections.unmodifiableSet(options);
    }

    /**
     * Creates the underlying socket. For stream sockets, a fresh
     * {@link FileDescriptor} is allocated and the platform create is invoked.
     * In all cases the owning socket's {@code created} flag is set.
     */
    @Override
    protected synchronized void create(boolean stream) throws IOException {
        this.stream = stream;
        if (stream) {
            fd = new FileDescriptor();
            // NOTE(review): the first argument is named isServer in
            // PlatformRdmaSocketImpl but is always true here - confirm intent.
            platformRdmaSocketImpl.rdmaSocketCreate(true, this);
        }
        try {
            if (socket != null) {
                sCreateState.setBoolean(socket, true);
            }
            if (serverSocket != null) {
                ssCreateState.setBoolean(serverSocket, true);
            }
        } catch (IllegalAccessException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Resolves {@code host} and connects to it using the current timeout.
     * The socket is closed if the connection attempt fails.
     */
    @Override
    protected void connect(String host, int port)
        throws UnknownHostException, IOException
    {
        boolean connected = false;
        try {
            InetAddress address = InetAddress.getByName(host);
            this.port = port;
            this.address = address;

            connectToAddress(address, port, timeout);
            connected = true;
        } finally {
            if (!connected) {
                try {
                    close();
                } catch (IOException ignored) {
                    // best-effort cleanup; the original failure propagates
                }
            }
        }
    }

    /**
     * Connects to {@code address}:{@code port} using the current timeout,
     * closing the socket and rethrowing if the attempt fails.
     */
    @Override
    protected void connect(InetAddress address, int port) throws IOException {
        this.port = port;
        this.address = address;

        try {
            connectToAddress(address, port, timeout);
        } catch (IOException e) {
            close();
            throw e;
        }
    }

    /**
     * Connects to the given {@link InetSocketAddress} with the given timeout.
     *
     * @throws IllegalArgumentException if {@code address} is null or not an
     *         {@link InetSocketAddress}
     * @throws UnknownHostException if the address is unresolved
     */
    @Override
    protected void connect(SocketAddress address, int timeout)
            throws IOException {
        boolean connected = false;
        try {
            if (!(address instanceof InetSocketAddress))
                throw new IllegalArgumentException("unsupported address type");
            InetSocketAddress addr = (InetSocketAddress) address;
            if (addr.isUnresolved())
                throw new UnknownHostException(addr.getHostName());
            this.port = addr.getPort();
            this.address = addr.getAddress();

            connectToAddress(this.address, port, timeout);
            connected = true;
        } finally {
            if (!connected) {
                try {
                    close();
                } catch (IOException ignored) {
                    // best-effort cleanup; the original failure propagates
                }
            }
        }
    }

    /** Connects, substituting the local host for a wildcard address. */
    private void connectToAddress(InetAddress address, int port, int timeout)
            throws IOException {
        if (address.isAnyLocalAddress()) {
            doConnect(InetAddress.getLocalHost(), port, timeout);
        } else {
            doConnect(address, port, timeout);
        }
    }

    /**
     * Sets a socket option. RDMA-specific options go to {@link #rdmaOptions}.
     * Standard options are mapped to their {@link SocketOptions} integer id and
     * set via {@link #setOption(int, Object)}.
     *
     * @throws UnsupportedOperationException if the option is unsupported, or
     *         if a buffer size or SO_REUSEADDR is changed after connect/bind
     */
    @Override
    protected <T> void setOption(SocketOption<T> name, T value) throws IOException {
        if (!rdmaOptions.isOptionSupported(name)) {
            int opt;
            if (name == StandardSocketOptions.SO_SNDBUF && socket != null) {
                if (socket.isConnected())
                    throw new UnsupportedOperationException(
                        "RDMA socket cannot set SO_SNDBUF after connect.");
                opt = SocketOptions.SO_SNDBUF;
            } else if (name == StandardSocketOptions.SO_RCVBUF) {
                if (socket != null && socket.isConnected())
                    throw new UnsupportedOperationException(
                        "RDMA socket cannot set SO_RCVBUF after connect.");
                if (serverSocket != null && serverSocket.isBound())
                    throw new UnsupportedOperationException(
                        "RDMA server socket cannot set SO_RCVBUF after bind.");
                opt = SocketOptions.SO_RCVBUF;
            } else if (name == StandardSocketOptions.SO_REUSEADDR) {
                opt = SocketOptions.SO_REUSEADDR;
            } else if (name == StandardSocketOptions.TCP_NODELAY
                    && (socket != null)) {
                opt = SocketOptions.TCP_NODELAY;
            } else if (name == StandardSocketOptions.SO_LINGER
                    && (socket != null)) {
                opt = SocketOptions.SO_LINGER;
            } else {
                throw new UnsupportedOperationException("unsupported option");
            }
            setOption(opt, value);
        } else {
            rdmaOptions.setOption(fd, name, value);
        }
    }

    /**
     * Returns a socket option value. RDMA-specific options are read from
     * {@link #rdmaOptions}. Standard options are read via
     * {@link #getOption(int)}. A disabled SO_LINGER is reported as {@code -1},
     * since that option's value type is {@code Integer}.
     *
     * @throws UnsupportedOperationException if the option is unsupported
     */
    @Override
    @SuppressWarnings("unchecked")
    protected <T> T getOption(SocketOption<T> name) throws IOException {
        if (!rdmaOptions.isOptionSupported(name)) {
            int opt;
            if (name == StandardSocketOptions.SO_SNDBUF && (socket != null)) {
                opt = SocketOptions.SO_SNDBUF;
            } else if (name == StandardSocketOptions.SO_RCVBUF) {
                opt = SocketOptions.SO_RCVBUF;
            } else if (name == StandardSocketOptions.SO_REUSEADDR) {
                opt = SocketOptions.SO_REUSEADDR;
            } else if (name == StandardSocketOptions.SO_LINGER
                    && (socket != null)) {
                Object linger = getOption(SocketOptions.SO_LINGER);
                // getOption(int) reports "linger disabled" as Boolean.FALSE,
                // which must not escape through a SocketOption<Integer>.
                if (linger instanceof Boolean) {
                    return (T) Integer.valueOf(-1);
                }
                return (T) linger;
            } else if (name == StandardSocketOptions.TCP_NODELAY
                    && (socket != null)) {
                opt = SocketOptions.TCP_NODELAY;
            } else {
                throw new UnsupportedOperationException("unsupported option");
            }
            return (T) getOption(opt);
        } else {
            return (T) rdmaOptions.getOption(fd, name);
        }
    }

    /**
     * Validates {@code val} for the integer option id {@code opt} and passes it
     * to {@link #socketSetOption}. SO_TIMEOUT is stored locally in
     * {@link #timeout} (and still passed on). SO_SNDBUF/SO_RCVBUF values of
     * {@link #MAX_BUFFER_SIZE} or more are clamped to that limit.
     *
     * @throws SocketException if the socket is closed, the value is invalid,
     *         or the option id is not recognized
     */
    @Override
    public void setOption(int opt, Object val) throws SocketException {
        if (isClosedOrPending()) {
            throw new SocketException("Socket Closed");
        }
        boolean on = true;
        switch (opt) {
            case SO_LINGER:
                if (val == null || (!(val instanceof Integer) && !(val instanceof Boolean)))
                    throw new SocketException("Bad parameter for option");
                if (val instanceof Boolean) {
                    /* true only if disabling - enabling should be Integer */
                    on = false;
                }
                break;
            case SO_TIMEOUT:
                if (val == null || (!(val instanceof Integer)))
                    throw new SocketException("Bad parameter for SO_TIMEOUT");
                int tmp = ((Integer) val).intValue();
                if (tmp < 0)
                    throw new IllegalArgumentException("timeout < 0");
                timeout = tmp;
                break;
            case SO_BINDADDR:
                throw new SocketException("Cannot re-bind socket");
            case TCP_NODELAY:
                if (val == null || !(val instanceof Boolean))
                    throw new SocketException("bad parameter for TCP_NODELAY");
                on = ((Boolean) val).booleanValue();
                break;
            case SO_SNDBUF:
            case SO_RCVBUF:
                // Validate before unboxing so a null or non-Integer value is
                // reported as a SocketException rather than NPE/CCE.
                if (!(val instanceof Integer) || ((Integer) val).intValue() <= 0) {
                    throw new SocketException("bad parameter for SO_SNDBUF " +
                                              "or SO_RCVBUF");
                }
                if (((Integer) val).intValue() >= MAX_BUFFER_SIZE) {
                    // Clamp oversized requests; the clamped value is the one
                    // handed to socketSetOption below.
                    val = Integer.valueOf(MAX_BUFFER_SIZE);
                }
                break;
            case SO_REUSEADDR:
                if (val == null || !(val instanceof Boolean))
                    throw new SocketException("bad parameter for SO_REUSEADDR");
                on = ((Boolean) val).booleanValue();
                if (serverSocket != null && serverSocket.isBound())
                    throw new UnsupportedOperationException(
                            "RDMA server socket cannot set " +
                            "SO_REUSEADDR after bind.");
                if (socket != null && socket.isConnected())
                    throw new UnsupportedOperationException(
                            "RDMA socket cannot set " +
                            "SO_REUSEADDR after connect.");
                break;
            default:
                throw new SocketException("unrecognized TCP option: " + opt);
        }
        socketSetOption(opt, on, val);
    }

    /**
     * Returns the value of the integer option id {@code opt}. SO_TIMEOUT is
     * answered from {@link #timeout}. Other options are queried from the
     * platform impl. Unrecognized ids yield {@code null}.
     *
     * @throws SocketException if the socket is closed
     */
    @Override
    public Object getOption(int opt) throws SocketException {
        if (isClosedOrPending()) {
            throw new SocketException("Socket Closed");
        }
        if (opt == SO_TIMEOUT) {
            return timeout;
        }
        int ret = 0;
        switch (opt) {
            case TCP_NODELAY:
                ret = platformRdmaSocketImpl.rdmaSocketGetOption(this, opt, null);
                return Boolean.valueOf(ret != -1);
            case SO_LINGER:
                ret = platformRdmaSocketImpl.rdmaSocketGetOption(this, opt, null);
                // -1 from the platform means "disabled"
                return (ret == -1) ? Boolean.FALSE : (Object) (ret);
            case SO_REUSEADDR:
                ret = platformRdmaSocketImpl.rdmaSocketGetOption(this, opt, null);
                return Boolean.valueOf(ret != -1);
            case SO_BINDADDR:
                RdmaInetAddressContainer in = new RdmaInetAddressContainer();
                ret = platformRdmaSocketImpl.rdmaSocketGetOption(this, opt, in);
                return in.addr;
            case SO_SNDBUF:
            case SO_RCVBUF:
                ret = platformRdmaSocketImpl.rdmaSocketGetOption(this, opt, null);
                return ret;
            default:
                return null;
        }
    }

    /**
     * Passes an already-validated option to the platform impl. A
     * {@link SocketException} from the platform is rethrown only when there is
     * no owning socket or it is not yet connected; otherwise it is ignored.
     */
    protected void socketSetOption(int opt, boolean b, Object val)
            throws SocketException {
        if (opt == SocketOptions.SO_REUSEPORT &&
            !supportedOptions().contains(StandardSocketOptions.SO_REUSEPORT)) {
            throw new UnsupportedOperationException("unsupported option");
        }
        try {
            platformRdmaSocketImpl.rdmaSocketSetOption(this, opt, b, val);
        } catch (SocketException se) {
            if (socket == null || !socket.isConnected())
                throw se;
        }
    }

    /**
     * Performs the platform connect while holding a use-count on the fd, then
     * marks the owning socket bound and connected. On any IOException (including
     * a close that raced with the connect) the socket is closed and the
     * exception rethrown.
     */
    synchronized void doConnect(InetAddress address, int port, int timeout)
            throws IOException {
        try {
            acquireFD();
            try {
                platformRdmaSocketImpl.rdmaSocketConnect(this, address, port, timeout);
                synchronized (fdLock) {
                    if (closePending) {
                        throw new SocketException("Socket closed");
                    }
                }
                try {
                    if (socket != null) {
                        sBoundState.setBoolean(socket, true);
                        sConnectState.setBoolean(socket, true);
                    }
                } catch (IllegalAccessException e) {
                    throw new AssertionError(e);
                }
            } finally {
                releaseFD();
            }
        } catch (IOException e) {
            close();
            throw e;
        }
    }

    /** Binds via the platform impl and marks the owning socket as bound. */
    @Override
    protected synchronized void bind(InetAddress address, int lport)
        throws IOException
    {
        platformRdmaSocketImpl.rdmaSocketBind(this, address, lport);
        try {
            if (socket != null)
                sBoundState.setBoolean(socket, true);
            if (serverSocket != null)
                ssBoundState.setBoolean(serverSocket, true);
        } catch (IllegalAccessException e) {
            throw new AssertionError(e);
        }
    }

    /** Starts listening with the given backlog via the platform impl. */
    @Override
    protected synchronized void listen(int count) throws IOException {
        platformRdmaSocketImpl.rdmaSocketListen(this, count);
    }

    /** Accepts a connection into {@code s} while holding a use-count on the fd. */
    @Override
    protected void accept(SocketImpl s) throws IOException {
        acquireFD();
        try {
            platformRdmaSocketImpl.rdmaSocketAccept(s, this);
        } finally {
            releaseFD();
        }
    }

    /**
     * Returns this socket's input stream, creating it lazily.
     *
     * @throws IOException if the socket is closed or input is shut down
     */
    @Override
    protected synchronized InputStream getInputStream() throws IOException {
        synchronized (fdLock) {
            if (isClosedOrPending())
                throw new IOException("Socket Closed");
            if (shut_rd)
                throw new IOException("Socket input is shutdown");
            if (socketInputStream == null)
                socketInputStream = new RdmaSocketInputStream(this);
        }
        return socketInputStream;
    }

    void setInputStream(RdmaSocketInputStream in) {
        socketInputStream = in;
    }

    /**
     * Returns this socket's output stream, creating it lazily.
     *
     * @throws IOException if the socket is closed or output is shut down
     */
    @Override
    protected synchronized OutputStream getOutputStream() throws IOException {
        synchronized (fdLock) {
            if (isClosedOrPending())
                throw new IOException("Socket Closed");
            if (shut_wr)
                throw new IOException("Socket output is shutdown");
            if (socketOutputStream == null)
                socketOutputStream = new RdmaSocketOutputStream(this);
        }
        return socketOutputStream;
    }

    @Override
    protected FileDescriptor getFileDescriptor() {
        return fd;
    }

    protected void setFileDescriptor(FileDescriptor fd) {
        this.fd = fd;
    }

    protected void setAddress(InetAddress address) {
        this.address = address;
    }

    void setPort(int port) {
        this.port = port;
    }

    void setLocalPort(int localport) {
        this.localport = localport;
    }

    /**
     * Returns the number of bytes reported readable by the platform impl. It
     * returns 0 once the connection is reset or input is shut down.
     *
     * <p>A first {@link ConnectionResetException} moves the state to
     * "reset pending" and the call is retried. A zero result while reset is
     * pending marks the connection reset.
     *
     * @throws IOException if the socket is closed
     */
    @Override
    protected synchronized int available() throws IOException {
        if (isClosedOrPending()) {
            throw new IOException("Stream closed.");
        }

        if (isConnectionReset() || shut_rd) {
            return 0;
        }

        int n = 0;
        try {
            n = platformRdmaSocketImpl.rdmaSocketAvailable(this);
            if (n == 0 && isConnectionResetPending()) {
                setConnectionReset();
            }
        } catch (ConnectionResetException exc1) {
            setConnectionResetPending();
            try {
                n = platformRdmaSocketImpl.rdmaSocketAvailable(this);
                if (n == 0) {
                    setConnectionReset();
                }
            } catch (ConnectionResetException ignored) {
                // a second reset leaves n == 0 and the state "reset pending"
            }
        }
        return n;
    }

    /**
     * Closes the socket. If no thread currently holds the fd, this performs a
     * deferred-mode platform close followed by a final close, then clears
     * {@link #fd}. Otherwise only the deferred-mode close is issued: the close
     * is marked pending and the use-count is decremented, so the last
     * {@link #releaseFD()} performs the final close.
     */
    @Override
    protected void close() throws IOException {
        synchronized (fdLock) {
            if (fd != null) {
                if (fdUseCount == 0) {
                    if (closePending) {
                        return;
                    }
                    closePending = true;
                    try {
                        platformRdmaSocketImpl.rdmaSocketClose(true, this);
                    } finally {
                        platformRdmaSocketImpl.rdmaSocketClose(false, this);
                    }
                    fd = null;
                    return;
                } else {
                    if (!closePending) {
                        closePending = true;
                        fdUseCount--;
                        platformRdmaSocketImpl.rdmaSocketClose(true, this);
                    }
                }
            }
        }
    }

    /** Closes the fd immediately (if any) and clears address and port state. */
    void reset() throws IOException {
        if (fd != null) {
            platformRdmaSocketImpl.rdmaSocketClose(false, this);
        }
        fd = null;
        postReset();
    }

    void postReset() throws IOException {
        address = null;
        port = 0;
        localport = 0;
    }

    /** Shuts down the read half and signals EOF to an existing input stream. */
    @Override
    protected void shutdownInput() throws IOException {
        if (fd != null) {
            platformRdmaSocketImpl.rdmaSocketShutdown(SHUT_RD, this);
            if (socketInputStream != null) {
                socketInputStream.setEOF(true);
            }
            shut_rd = true;
        }
    }

    /** Shuts down the write half of the connection. */
    @Override
    protected void shutdownOutput() throws IOException {
        if (fd != null) {
            platformRdmaSocketImpl.rdmaSocketShutdown(SHUT_WR, this);
            shut_wr = true;
        }
    }

    @Override
    protected boolean supportsUrgentData() {
        return true;
    }

    @Override
    protected void sendUrgentData(int data) throws IOException {
        if (fd == null) {
            throw new IOException("Socket Closed");
        }
        platformRdmaSocketImpl.rdmaSocketSendUrgentData(this, data);
    }

    /** Registers the calling thread as a user of the fd and returns it. */
    FileDescriptor acquireFD() {
        synchronized (fdLock) {
            fdUseCount++;
            return fd;
        }
    }

    /**
     * Unregisters a user of the fd. A count of -1 means a deferred close is
     * pending and this was the last user, so the fd is closed and cleared here.
     */
    void releaseFD() {
        synchronized (fdLock) {
            fdUseCount--;
            if (fdUseCount == -1) {
                if (fd != null) {
                    try {
                        platformRdmaSocketImpl.rdmaSocketClose(false, this);
                    } catch (IOException ignored) {
                        // best-effort final close; fd is cleared regardless
                    } finally {
                        fd = null;
                    }
                }
            }
        }
    }

    public boolean isConnectionReset() {
        synchronized (resetLock) {
            return (resetState == CONNECTION_RESET);
        }
    }

    public boolean isConnectionResetPending() {
        synchronized (resetLock) {
            return (resetState == CONNECTION_RESET_PENDING);
        }
    }

    public void setConnectionReset() {
        synchronized (resetLock) {
            resetState = CONNECTION_RESET;
        }
    }

    /** Moves to "reset pending" only from the not-reset state. */
    public void setConnectionResetPending() {
        synchronized (resetLock) {
            if (resetState == CONNECTION_NOT_RESET) {
                resetState = CONNECTION_RESET_PENDING;
            }
        }
    }

    /** Returns true if a close is pending or the fd has already been cleared. */
    public boolean isClosedOrPending() {
        synchronized (fdLock) {
            return closePending || (fd == null);
        }
    }

    public int getTimeout() {
        return timeout;
    }

    @Override
    protected InetAddress getInetAddress() {
        return address;
    }

    @Override
    protected int getPort() {
        return port;
    }

    @Override
    protected int getLocalPort() {
        return localport;
    }

    public static final int SHUT_RD = 0;
    public static final int SHUT_WR = 1;

    /**
     * Platform hook for the native socket operations. This base class reports
     * RDMA as unavailable, and every operation throws
     * {@link UnsupportedOperationException}. On Linux the singleton is
     * instead the class {@code rdma.ch.LinuxRdmaSocketImpl}, loaded
     * reflectively.
     */
    static class PlatformRdmaSocketImpl {

        @SuppressWarnings("unchecked")
        private static PlatformRdmaSocketImpl newInstance(String cn) {
            Class<PlatformRdmaSocketImpl> c;
            try {
                c = (Class<PlatformRdmaSocketImpl>) Class.forName(cn);
                return c.getConstructor().newInstance();
            } catch (ReflectiveOperationException x) {
                throw new AssertionError(x);
            }
        }

        private static PlatformRdmaSocketImpl create() {
            String osname = AccessController.doPrivileged(
                new PrivilegedAction<String>() {
                    public String run() {
                        return System.getProperty("os.name");
                    }
                });
            if ("Linux".equals(osname))
                return newInstance("rdma.ch.LinuxRdmaSocketImpl");
            return new PlatformRdmaSocketImpl();
        }

        private static final PlatformRdmaSocketImpl instance = create();

        static PlatformRdmaSocketImpl get() {
            return instance;
        }

        boolean isRdmaAvailable() {
            return false;
        }

        void rdmaSocketClose(boolean useDeferredClose, RdmaSocketImpl impl)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketCreate(boolean isServer, RdmaSocketImpl impl)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketConnect(RdmaSocketImpl impl, InetAddress address,
                int port, int timeout) throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketBind(RdmaSocketImpl impl, InetAddress address, int port)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketListen(RdmaSocketImpl impl, int count)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketAccept(SocketImpl s, RdmaSocketImpl impl)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        int rdmaSocketAvailable(RdmaSocketImpl impl) throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketShutdown(int howto, RdmaSocketImpl impl)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketSetOption(RdmaSocketImpl impl, int cmd, boolean on,
                Object value) throws SocketException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        int rdmaSocketGetOption(RdmaSocketImpl impl, int opt,
                Object iaContainerObj) throws SocketException {
            throw new UnsupportedOperationException("unsupported socket option");
        }

        void rdmaSocketSendUrgentData(RdmaSocketImpl impl, int data)
                throws IOException {
            throw new UnsupportedOperationException("unsupported socket option");
        }
    }
}