< prev index next >

src/java.naming/share/classes/com/sun/jndi/ldap/Connection.java

Print this page

        

@@ -44,13 +44,21 @@
 
 import java.lang.reflect.Method;
 import java.lang.reflect.InvocationTargetException;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
+import java.security.cert.Certificate;
+import java.security.cert.X509Certificate;
 import java.util.Arrays;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import javax.net.SocketFactory;
 import javax.net.ssl.SSLParameters;
+import javax.net.ssl.HandshakeCompletedEvent;
+import javax.net.ssl.HandshakeCompletedListener;
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.security.sasl.SaslException;
 
 /**
   * A thread that creates a connection to an LDAP server.
   * After the connection, the thread reads from the connection.
   * A caller can invoke methods on the instance to read LDAP responses

@@ -107,11 +115,11 @@
   *
   * @author Vincent Ryan
   * @author Rosanna Lee
   * @author Jagane Sundar
   */
-public final class Connection implements Runnable {
+public final class Connection implements Runnable, HandshakeCompletedListener {
 
     private static final boolean debug = false;
     private static final int dump = 0; // > 0 r, > 1 rw
 
 

@@ -340,10 +348,11 @@
             if (!IS_HOSTNAME_VERIFICATION_DISABLED) {
                 SSLParameters param = sslSocket.getSSLParameters();
                 param.setEndpointIdentificationAlgorithm("LDAPS");
                 sslSocket.setSSLParameters(param);
             }
+            sslSocket.addHandshakeCompletedListener(this);
             if (connectTimeout > 0) {
                 int socketTimeout = sslSocket.getSoTimeout();
                 sslSocket.setSoTimeout(connectTimeout); // reuse full timeout value
                 sslSocket.startHandshake();
                 sslSocket.setSoTimeout(socketTimeout);

@@ -635,10 +644,19 @@
                         while (ldr != null) {
                             ldr.cancel();
                             ldr = ldr.next;
                         }
                     }
+                    if (isTlsConnection()) {
+                        if (closureReason != null) {
+                            CommunicationException ce = new CommunicationException();
+                            ce.setRootCause(closureReason);
+                            tlsHandshakeCompleted.completeExceptionally(ce);
+                        } else {
+                            tlsHandshakeCompleted.cancel(false);
+                        }
+                    }
                     sock = null;
                 }
                 nparent = notifyParent;
             }
             if (nparent) {

@@ -970,6 +988,48 @@
             }
             nread += count;
         }
         return buf;
     }
+
+    // Holds the outcome of the TLS handshake on this connection's socket.
+    // handshakeCompleted() completes it with the certificate it extracts
+    // (or null when no X.509 certificate is available); the connection
+    // cleanup path completes it exceptionally or cancels it.
+    private CompletableFuture<X509Certificate> tlsHandshakeCompleted =
+            new CompletableFuture<>();
+
+    /**
+     * Records the result of a finished TLS handshake in
+     * {@code tlsHandshakeCompleted}.
+     *
+     * When the socket is in client mode the peer certificate chain is
+     * consulted, otherwise the local certificate chain. If the first entry
+     * of that chain is an {@code X509Certificate} the future is completed
+     * with it; if the chain is absent, empty, or starts with another kind
+     * of certificate, the future is completed with {@code null}.
+     *
+     * If the peer certificates cannot be obtained, the future is completed
+     * exceptionally with the {@code SSLPeerUnverifiedException}.
+     *
+     * @param event the event describing the completed handshake
+     */
+    @Override
+    public void handshakeCompleted(HandshakeCompletedEvent event) {
+        try {
+            X509Certificate tlsServerCert = null;
+            Certificate[] certs;
+            if (event.getSocket().getUseClientMode()) {
+                certs = event.getPeerCertificates();
+            } else {
+                certs = event.getLocalCertificates();
+            }
+            if (certs != null && certs.length > 0 &&
+                    certs[0] instanceof X509Certificate) {
+                tlsServerCert = (X509Certificate) certs[0];
+            }
+            tlsHandshakeCompleted.complete(tlsServerCert);
+        } catch (SSLPeerUnverifiedException ex) {
+            // Propagate the verification failure itself to anyone waiting
+            // on the future.
+            tlsHandshakeCompleted.completeExceptionally(ex);
+        }
+    }
+
+    /**
+     * Tells whether this connection's socket is an {@code SSLSocket}.
+     *
+     * @return {@code true} if {@code sock} is an instance of
+     *         {@code SSLSocket}; {@code false} otherwise, including when
+     *         {@code sock} is {@code null}
+     */
+    public boolean isTlsConnection() {
+        final boolean overTls = (sock instanceof SSLSocket);
+        return overTls;
+    }
+
+    /**
+     * Returns the certificate recorded for this connection's TLS handshake,
+     * waiting for the handshake result if it is not yet available.
+     *
+     * @return the certificate the handshake future was completed with
+     *         (possibly {@code null}), or {@code null} if this is not a
+     *         TLS connection
+     * @throws SaslException if the wait is interrupted (the thread's
+     *         interrupt status is restored before throwing) or if the
+     *         handshake future was completed exceptionally, in which case
+     *         the original failure is the cause
+     */
+    public X509Certificate getTlsServerCertificate()
+            throws SaslException {
+        try {
+            if (isTlsConnection()) {
+                return tlsHandshakeCompleted.get();
+            }
+        } catch (InterruptedException iex) {
+            // Do not swallow the interrupt: callers further up the stack
+            // must still be able to observe it.
+            Thread.currentThread().interrupt();
+            throw new SaslException("TLS Handshake Exception ", iex);
+        } catch (ExecutionException eex) {
+            throw new SaslException("TLS Handshake Exception ", eex.getCause());
+        }
+        return null;
+    }
 }
< prev index next >