< prev index next >

test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java

Print this page

        

@@ -1,7 +1,7 @@
 /*
- * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
  * under the terms of the GNU General Public License version 2 only, as
  * published by the Free Software Foundation.

@@ -44,17 +44,18 @@
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Function;
+import java.util.function.BiFunction;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 import static java.lang.String.format;
 import static java.lang.System.err;
 import static java.nio.charset.StandardCharsets.ISO_8859_1;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Arrays.asList;
 import static java.util.Objects.requireNonNull;
 
 /**
  * Dummy WebSocket Server.

@@ -90,44 +91,73 @@
     private volatile ServerSocketChannel ssc;
     private volatile InetSocketAddress address;
     private ByteBuffer read = ByteBuffer.allocate(16384);
+    // Counted down in the connection loop's finally block, after each accepted channel is closed.
     private final CountDownLatch readReady = new CountDownLatch(1);
 
+    /**
+     * Immutable user name / password pair that the server accepts when
+     * HTTP Basic authentication is enabled.
+     */
+    private static class Credentials {
+
+        private final String name;
+        private final String password;
+
+        private Credentials(String userName, String userPassword) {
+            name = userName;
+            password = userPassword;
+        }
+
+        /** Returns the accepted user name. */
+        public String name() {
+            return name;
+        }
+
+        /** Returns the accepted password. */
+        public String password() {
+            return password;
+        }
+    }
+
+    /**
+     * Creates a server that uses the default handshake mapping and does not
+     * require authentication (no credentials are configured).
+     */
     public DummyWebSocketServer() {
-        this(defaultMapping());
+        this(defaultMapping(), null, null);
     }
 
-    public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
+    /**
+     * Creates a server that uses the default handshake mapping and requires
+     * HTTP Basic authentication with the given user name and password.
+     * Handshake requests without matching credentials are answered with
+     * {@code 401 Unauthorized}. A {@code null} username disables authentication.
+     *
+     * @param username the user name the server accepts, or {@code null}
+     * @param password the password the server accepts
+     */
+    public DummyWebSocketServer(String username, String password) {
+        this(defaultMapping(), username, password);
+    }
+
+    public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
+                                String username,
+                                String password) {
         requireNonNull(mapping);
+        Credentials credentials = username != null ?
+                new Credentials(username, password) : null;
+
         thread = new Thread(() -> {
             try {
                 while (!Thread.currentThread().isInterrupted()) {
                     err.println("Accepting next connection at: " + ssc);
                     SocketChannel channel = ssc.accept();
                     err.println("Accepted: " + channel);
                     try {
                         channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
                         channel.configureBlocking(true);
+                        while (true) {
                         StringBuilder request = new StringBuilder();
                         if (!readRequest(channel, request)) {
-                            throw new IOException("Bad request:" + request);
+                                throw new IOException("Bad request:[" + request + "]");
                         }
                         List<String> strings = asList(request.toString().split("\r\n"));
-                        List<String> response = mapping.apply(strings);
+                            List<String> response = mapping.apply(strings, credentials);
                         writeResponse(channel, response);
+
+                            if (response.get(0).startsWith("HTTP/1.1 401")) {
+                                err.println("Sent 401 Authentication response " + channel);
+                                continue;
+                            } else {
                         serve(channel);
+                                break;
+                            }
+                        }
                     } catch (IOException e) {
                         err.println("Error in connection: " + channel + ", " + e);
                     } finally {
                         err.println("Closed: " + channel);
                         close(channel);
                         readReady.countDown();
                     }
                 }
             } catch (ClosedByInterruptException ignored) {
             } catch (Exception e) {
-                err.println(e);
+                e.printStackTrace(err);
             } finally {
                 close(ssc);
                 err.println("Stopped at: " + getURI());
             }
         });

@@ -254,12 +284,12 @@
         while (encoded.hasRemaining()) {
             channel.write(encoded);
         }
     }
 
-    private static Function<List<String>, List<String>> defaultMapping() {
-        return request -> {
+    private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
+        return (request, credentials) -> {
             List<String> response = new LinkedList<>();
             Iterator<String> iterator = request.iterator();
             if (!iterator.hasNext()) {
                 throw new IllegalStateException("The request is empty");
             }

@@ -307,18 +337,61 @@
             }
             String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
             sha1.update(x.getBytes(ISO_8859_1));
             String v = Base64.getEncoder().encodeToString(sha1.digest());
             response.add("Sec-WebSocket-Accept: " + v);
+
+            // check authorization credentials, if required by the server
+            if (credentials != null && !authorized(credentials, requestHeaders)) {
+                response.clear();
+                response.add("HTTP/1.1 401 Unauthorized");
+                response.add("Content-Length: 0");
+                response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
+            }
+
             return response;
         };
     }
 
+    /**
+     * Checks the request's Basic {@code Authorization} header against the
+     * credentials the server accepts.
+     *
+     * @param credentials    the user name and password the server accepts
+     * @param requestHeaders the request headers, looked up by header name
+     * @return {@code true} if exactly one {@code Authorization} header is
+     *         present and its decoded user name and password both match;
+     *         {@code false} if the header is absent or the values differ
+     * @throws IllegalStateException if more than one {@code Authorization}
+     *         header is present, the scheme is not Basic, the token is not
+     *         valid Base64, or the decoded value has no ':' separator
+     */
+    private static boolean authorized(Credentials credentials,
+                                      Map<String,List<String>> requestHeaders) {
+        List<String> authorization = requestHeaders.get("Authorization");
+        if (authorization == null) {
+            return false;
+        }
+        if (authorization.size() != 1) {
+            throw new IllegalStateException("Authorization unexpected count:" + authorization);
+        }
+        String header = authorization.get(0);
+        // The auth-scheme token is case-insensitive (RFC 7235, section 2.1).
+        String scheme = "Basic ";
+        if (!header.regionMatches(true, 0, scheme, 0, scheme.length())) {
+            throw new IllegalStateException("Authorization not Basic: " + header);
+        }
+
+        String token = header.substring(scheme.length()).trim();
+        String values;
+        try {
+            values = new String(Base64.getDecoder().decode(token), UTF_8);
+        } catch (IllegalArgumentException e) {
+            throw new IllegalStateException("Authorization not Base64: " + header, e);
+        }
+        // RFC 7617 permits an empty user-id, so only a missing ':' is malformed.
+        int sep = values.indexOf(':');
+        if (sep < 0) {
+            throw new IllegalStateException("Authorization not colon: " +  values);
+        }
+        String name = values.substring(0, sep);
+        String password = values.substring(sep + 1);
+
+        return name.equals(credentials.name()) && password.equals(credentials.password());
+    }
+
     protected static String expectHeader(Map<String, List<String>> headers,
                                          String name,
                                          String value) {
         List<String> v = headers.get(name);
+        if (v == null) {
+            throw new IllegalStateException(
+                    format("Expected '%s' header, not present in %s",
+                           name, headers));
+        }
         if (!v.contains(value)) {
             throw new IllegalStateException(
                     format("Expected '%s: %s', actual: '%s: %s'",
                            name, value, name, v)
             );
< prev index next >