< prev index next >

src/java.base/share/classes/sun/nio/ch/DatagramSocketAdaptor.java

Print this page
rev 57619 : [mq]: MulticastSocketAdaptor

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2001, 2019, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2001, 2020, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.  Oracle designates this

@@ -24,57 +24,67 @@
  */
 
 package sun.nio.ch;
 
 import java.io.IOException;
+import java.lang.invoke.MethodHandle;
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.MethodHandles.Lookup;
+import java.lang.invoke.MethodType;
 import java.lang.invoke.VarHandle;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
-import java.net.DatagramSocketImpl;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.NetworkInterface;
+import java.net.MulticastSocket;
 import java.net.SocketAddress;
 import java.net.SocketException;
 import java.net.SocketOption;
 import java.net.StandardSocketOptions;
 import java.nio.ByteBuffer;
 import java.nio.channels.AlreadyConnectedException;
 import java.nio.channels.ClosedChannelException;
 import java.nio.channels.DatagramChannel;
+import java.nio.channels.MembershipKey;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
+import java.security.PrivilegedExceptionAction;
+import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.locks.ReentrantLock;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 
-// Make a datagram-socket channel look like a datagram socket.
-//
-// The methods in this class are defined in exactly the same order as in
-// java.net.DatagramSocket so as to simplify tracking future changes to that
-// class.
-//
-
-class DatagramSocketAdaptor
-    extends DatagramSocket
+/**
+ * A multicast datagram socket based on a datagram channel.
+ *
+ * This class overrides every public method defined by java.net.DatagramSocket
+ * and java.net.MulticastSocket. The methods in this class are defined in exactly
+ * the same order as in java.net.DatagramSocket and java.net.MulticastSocket so
+ * as to simplify tracking changes.
+ */
+public class DatagramSocketAdaptor
+    extends MulticastSocket
 {
     // The channel being adapted
     private final DatagramChannelImpl dc;
 
     // Timeout "option" value for receives
     private volatile int timeout;
 
-    // create DatagramSocket with useless impl
-    private DatagramSocketAdaptor(DatagramChannelImpl dc) {
-        super(new DummyDatagramSocketImpl());
+    private DatagramSocketAdaptor(DatagramChannelImpl dc) throws IOException {
+        super((SocketAddress) null);
         this.dc = dc;
     }
 
     static DatagramSocket create(DatagramChannelImpl dc) {
+        try {
         return new DatagramSocketAdaptor(dc);
+        } catch (IOException e) {
+            throw new InternalError(e);  // not expected; reported like the other internal failures in this class
+        }
     }
 
     private void connectInternal(SocketAddress remote) throws SocketException {
         try {
             dc.connect(remote, false); // skips check for already connected

@@ -407,118 +417,275 @@
     @Override
     public Set<SocketOption<?>> supportedOptions() {
         return dc.supportedOptions();
     }
 
+    // -- java.net.MulticastSocket --
 
-    /**
-     * DatagramSocketImpl implementation where all methods throw an error.
-     */
-    private static class DummyDatagramSocketImpl extends DatagramSocketImpl {
-        private static <T> T shouldNotGetHere() {
-            throw new InternalError("Should not get here");
-        }
+    // used to coordinate changing TTL with the deprecated send method
+    private final ReentrantLock sendLock = new ReentrantLock();
 
-        @Override
-        protected void create() {
-            shouldNotGetHere();
-        }
+    // cached outgoing interface (for use by setInterface/getInterface)
+    private final Object outgoingInterfaceLock = new Object();
+    private NetworkInterface outgoingNetworkInterface;
+    private InetAddress outgoingInetAddress;
 
         @Override
-        protected void bind(int lport, InetAddress laddr) {
-            shouldNotGetHere();
+    @Deprecated
+    public void setTTL(byte ttl) throws IOException {
+        setTimeToLive(Byte.toUnsignedInt(ttl));
         }
 
         @Override
-        protected void send(DatagramPacket p) {
-            shouldNotGetHere();
+    public void setTimeToLive(int ttl) throws IOException {
+        sendLock.lock();
+        try {
+            setIntOption(StandardSocketOptions.IP_MULTICAST_TTL, ttl);
+        } finally {
+            sendLock.unlock();
+        }
         }
 
         @Override
-        protected int peek(InetAddress address) {
-            return shouldNotGetHere();
+    @Deprecated
+    public byte getTTL() throws IOException {
+        return (byte) getTimeToLive();
         }
 
         @Override
-        protected int peekData(DatagramPacket p) {
-            return shouldNotGetHere();
+    public int getTimeToLive() throws IOException {
+        sendLock.lock();
+        try {
+            return getIntOption(StandardSocketOptions.IP_MULTICAST_TTL);
+        } finally {
+            sendLock.unlock();
+        }
         }
 
         @Override
-        protected void receive(DatagramPacket p) {
-            shouldNotGetHere();
+    @Deprecated
+    public void joinGroup(InetAddress group) throws IOException {
+        Objects.requireNonNull(group);
+        try {
+            joinGroup(new InetSocketAddress(group, 0), null);
+        } catch (IllegalArgumentException iae) {
+            // 1-arg joinGroup does not specify IllegalArgumentException
+            throw (SocketException) new SocketException("joinGroup failed").initCause(iae);
+        }
         }
 
+    @Override
         @Deprecated
-        protected void setTTL(byte ttl) {
-            shouldNotGetHere();
+    public void leaveGroup(InetAddress group) throws IOException {
+        Objects.requireNonNull(group);
+        try {
+            leaveGroup(new InetSocketAddress(group, 0), null);
+        } catch (IllegalArgumentException iae) {
+            // 1-arg leaveGroup does not specify IllegalArgumentException
+            throw (SocketException) new SocketException("leaveGroup failed").initCause(iae);
+        }
         }
 
-        @Deprecated
-        protected byte getTTL() {
-            return shouldNotGetHere();
+    /**
+     * Checks a SocketAddress to ensure that it is a multicast address.
+     *
+     * @return the multicast group
+     * @throws IllegalArgumentException if group is null, an unsupported address
+     *         type, or an unresolved address
+     * @throws SocketException if group is not a multicast address
+     */
+    private static InetAddress checkGroup(SocketAddress mcastaddr) throws SocketException {
+        if (mcastaddr == null || !(mcastaddr instanceof InetSocketAddress))
+            throw new IllegalArgumentException("Unsupported address type");
+        InetAddress group = ((InetSocketAddress) mcastaddr).getAddress();
+        if (group == null)
+            throw new IllegalArgumentException("Unresolved address");
+        if (!group.isMulticastAddress())
+            throw new SocketException("Not a multicast address");
+        return group;
         }
 
         @Override
-        protected void setTimeToLive(int ttl) {
-            shouldNotGetHere();
+    public void joinGroup(SocketAddress mcastaddr, NetworkInterface netIf) throws IOException {
+        InetAddress group = checkGroup(mcastaddr);
+        NetworkInterface ni = (netIf != null) ? netIf : defaultNetworkInterface();
+        if (isClosed())
+            throw new SocketException("Socket is closed");
+        synchronized (this) {
+            MembershipKey key = dc.findMembership(group, ni);
+            if (key != null) {
+                // already a member but need to check permission anyway
+                SecurityManager sm = System.getSecurityManager();
+                if (sm != null)
+                    sm.checkMulticast(group);
+                throw new SocketException("Already a member of group");
+            }
+            dc.join(group, ni);  // checks permission
+        }
         }
 
         @Override
-        protected int getTimeToLive() {
-            return shouldNotGetHere();
+    public void leaveGroup(SocketAddress mcastaddr, NetworkInterface netIf) throws IOException {
+        InetAddress group = checkGroup(mcastaddr);
+        NetworkInterface ni = (netIf != null) ? netIf : defaultNetworkInterface();
+        if (isClosed())
+            throw new SocketException("Socket is closed");
+        SecurityManager sm = System.getSecurityManager();
+        if (sm != null)
+            sm.checkMulticast(group);
+        synchronized (this) {
+            MembershipKey key = dc.findMembership(group, ni);
+            if (key == null)
+                throw new SocketException("Not a member of group");
+            key.drop();
+        }
         }
 
         @Override
-        protected void join(InetAddress group) {
-            shouldNotGetHere();
+    @Deprecated
+    public void setInterface(InetAddress inf) throws SocketException {
+        if (inf == null)
+            throw new SocketException("Invalid value 'null'");
+        NetworkInterface ni = NetworkInterface.getByInetAddress(inf);
+        if (ni == null) {
+            String address = inf.getHostAddress();
+            throw new SocketException("No network interface with address " + address);
+        }
+        synchronized (outgoingInterfaceLock) {
+            // set interface and update cached values
+            setNetworkInterface(ni);
+            outgoingNetworkInterface = ni;
+            outgoingInetAddress = inf;
+        }
         }
 
         @Override
-        protected void leave(InetAddress inetaddr) {
-            shouldNotGetHere();
+    @Deprecated
+    public InetAddress getInterface() throws SocketException {
+        synchronized (outgoingInterfaceLock) {
+            NetworkInterface ni = outgoingNetworkInterface();
+            if (ni != null) {
+                if (ni.equals(outgoingNetworkInterface)) {
+                    return outgoingInetAddress;
+                } else {
+                    // network interface has changed so update cached values
+                    PrivilegedAction<InetAddress> pa;
+                    pa = () -> ni.inetAddresses().findFirst().orElse(null);
+                    InetAddress ia = AccessController.doPrivileged(pa);
+                    if (ia == null)
+                        throw new SocketException("Network interface has no IP address");
+                    outgoingNetworkInterface = ni;
+                    outgoingInetAddress = ia;
+                    return ia;
+                }
+            }
         }
 
-        @Override
-        protected void joinGroup(SocketAddress group, NetworkInterface netIf) {
-            shouldNotGetHere();
+        // no interface set
+        return anyInetAddress();
         }
 
         @Override
-        protected void leaveGroup(SocketAddress mcastaddr, NetworkInterface netIf) {
-            shouldNotGetHere();
+    public void setNetworkInterface(NetworkInterface netIf) throws SocketException {
+        try {
+            setOption(StandardSocketOptions.IP_MULTICAST_IF, netIf);
+        } catch (IOException e) {
+            Net.translateToSocketException(e);
+        }
         }
 
         @Override
-        protected void close() {
-            shouldNotGetHere();
+    public NetworkInterface getNetworkInterface() throws SocketException {
+        NetworkInterface ni = outgoingNetworkInterface();
+        if (ni == null) {
+            // return NetworkInterface with index == 0 as placeholder
+            ni = anyNetworkInterface();
+        }
+        return ni;
         }
 
         @Override
-        public Object getOption(int optID) {
-            return shouldNotGetHere();
+    @Deprecated
+    public void setLoopbackMode(boolean disable) throws SocketException {
+        boolean enable = !disable;
+        setBooleanOption(StandardSocketOptions.IP_MULTICAST_LOOP, enable);
         }
 
         @Override
-        public void setOption(int optID, Object value) {
-            shouldNotGetHere();
+    @Deprecated
+    public boolean getLoopbackMode() throws SocketException {
+        boolean enabled = getBooleanOption(StandardSocketOptions.IP_MULTICAST_LOOP);
+        return !enabled;
         }
 
         @Override
-        protected <T> void setOption(SocketOption<T> name, T value) {
-            shouldNotGetHere();
+    @Deprecated
+    public void send(DatagramPacket p, byte ttl) throws IOException {
+        sendLock.lock();
+        try {
+            int oldValue = getTimeToLive();
+            try {
+                setTTL(ttl);
+                send(p);
+            } catch (Throwable t) {
+                // restore the TTL before rethrowing; a failure to restore it is
+                // recorded as suppressed rather than masking the original exception
+                try {
+                    setTimeToLive(oldValue);
+                } catch (IOException ioe) {
+                    t.addSuppressed(ioe);
+                }
+                throw t;
+            }
+            setTimeToLive(oldValue);
+        } finally {
+            sendLock.unlock();
+        }
         }
 
-        @Override
-        protected <T> T getOption(SocketOption<T> name) {
-            return shouldNotGetHere();
+    /**
+     * Returns the outgoing NetworkInterface or null if not set.
+     */
+    private NetworkInterface outgoingNetworkInterface() throws SocketException {
+        try {
+            return getOption(StandardSocketOptions.IP_MULTICAST_IF);
+        } catch (IOException e) {
+            Net.translateToSocketException(e);
+            return null; // keep compiler happy
+        }
         }
 
-        @Override
-        protected Set<SocketOption<?>> supportedOptions() {
-            return shouldNotGetHere();
+    /**
+     * Returns the default NetworkInterface to use when joining or leaving a
+     * multicast group and a network interface is not specified.
+     * This method will return the outgoing NetworkInterface if set, otherwise
+     * the result of NetworkInterface.getDefault(), otherwise a NetworkInterface
+     * with index == 0 as a placeholder for "any network interface".
+     */
+    private NetworkInterface defaultNetworkInterface() throws SocketException {
+        NetworkInterface ni = outgoingNetworkInterface();
+        if (ni == null)
+            ni = NetworkInterfaces.getDefault();   // macOS
+        if (ni == null)
+            ni = anyNetworkInterface();
+        return ni;
+    }
+
+    /**
+     * Returns the placeholder for "any network interface", its index is 0.
+     */
+    private NetworkInterface anyNetworkInterface() {
+        InetAddress[] addrs = new InetAddress[1];
+        addrs[0] = anyInetAddress();
+        return NetworkInterfaces.newNetworkInterface(addrs[0].getHostName(), 0, addrs);
         }
+
+    /**
+     * Returns the InetAddress representing anyLocalAddress.
+     */
+    private InetAddress anyInetAddress() {
+        return new InetSocketAddress(0).getAddress();
     }
 
     /**
      * Defines static methods to get/set DatagramPacket fields and workaround
      * DatagramPacket deficiencies.

@@ -526,17 +685,12 @@
     private static class DatagramPackets {
         private static final VarHandle LENGTH;
         private static final VarHandle BUF_LENGTH;
         static {
             try {
-                PrivilegedAction<Lookup> pa = () -> {
-                    try {
-                        return MethodHandles.privateLookupIn(DatagramPacket.class, MethodHandles.lookup());
-                    } catch (Exception e) {
-                        throw new ExceptionInInitializerError(e);
-                    }
-                };
+                Lookup caller = MethodHandles.lookup();
+                PrivilegedExceptionAction<Lookup> pa = () -> MethodHandles.privateLookupIn(DatagramPacket.class, caller);
                 MethodHandles.Lookup l = AccessController.doPrivileged(pa);
                 LENGTH = l.findVarHandle(DatagramPacket.class, "length", int.class);
                 BUF_LENGTH = l.findVarHandle(DatagramPacket.class, "bufLength", int.class);
             } catch (Exception e) {
                 throw new ExceptionInInitializerError(e);

@@ -560,6 +714,49 @@
             synchronized (p) {
                 return (int) BUF_LENGTH.get(p);
             }
         }
     }
+
+    /**
+     * Defines static methods to invoke non-public NetworkInterface methods.
+     */
+    private static class NetworkInterfaces {
+        private static final MethodHandle GET_DEFAULT;
+        private static final MethodHandle CONSTRUCTOR;
+        static {
+            try {
+                PrivilegedExceptionAction<Lookup> pa = () ->
+                    MethodHandles.privateLookupIn(NetworkInterface.class, MethodHandles.lookup());
+                MethodHandles.Lookup l = AccessController.doPrivileged(pa);
+                MethodType methodType = MethodType.methodType(NetworkInterface.class);
+                GET_DEFAULT = l.findStatic(NetworkInterface.class, "getDefault", methodType);
+                methodType = MethodType.methodType(void.class, String.class, int.class, InetAddress[].class);
+                CONSTRUCTOR = l.findConstructor(NetworkInterface.class, methodType);
+            } catch (Exception e) {
+                throw new ExceptionInInitializerError(e);
+            }
+        }
+
+        /**
+         * Returns the result of NetworkInterface.getDefault(), which may be null.
+         */
+        static NetworkInterface getDefault() {
+            try {
+                return (NetworkInterface) GET_DEFAULT.invokeExact();
+            } catch (Throwable e) {
+                throw new InternalError(e);
+            }
+        }
+
+        /**
+         * Creates a NetworkInterface with the given name, index and addresses.
+         */
+        static NetworkInterface newNetworkInterface(String name, int index, InetAddress[] addrs) {
+            try {
+                return (NetworkInterface) CONSTRUCTOR.invokeExact(name, index, addrs);
+            } catch (Throwable e) {
+                throw new InternalError(e);
+            }
+        }
+    }
 }
\ No newline at end of file
< prev index next >