--- old/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java 2019-01-25 16:52:01.000000000 +0000 +++ new/test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java 2019-01-25 16:52:01.000000000 +0000 @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -46,13 +46,14 @@ import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; +import java.util.function.BiFunction; import java.util.regex.Pattern; import java.util.stream.Collectors; import static java.lang.String.format; import static java.lang.System.err; import static java.nio.charset.StandardCharsets.ISO_8859_1; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; @@ -92,12 +93,32 @@ private ByteBuffer read = ByteBuffer.allocate(16384); private final CountDownLatch readReady = new CountDownLatch(1); + private static class Credentials { + private final String name; + private final String password; + private Credentials(String name, String password) { + this.name = name; + this.password = password; + } + public String name() { return name; } + public String password() { return password; } + } + public DummyWebSocketServer() { - this(defaultMapping()); + this(defaultMapping(), null, null); + } + + public DummyWebSocketServer(String username, String password) { + this(defaultMapping(), username, password); } - public DummyWebSocketServer(Function, List> mapping) { + public DummyWebSocketServer(BiFunction,Credentials,List> mapping, + String username, + String password) { requireNonNull(mapping); + Credentials credentials = username != null ? + new Credentials(username, password) : null; + thread = new Thread(() -> { try { while (!Thread.currentThread().isInterrupted()) { @@ -107,14 +128,23 @@ try { channel.setOption(StandardSocketOptions.TCP_NODELAY, true); channel.configureBlocking(true); - StringBuilder request = new StringBuilder(); - if (!readRequest(channel, request)) { - throw new IOException("Bad request:" + request); + while (true) { + StringBuilder request = new StringBuilder(); + if (!readRequest(channel, request)) { + throw new IOException("Bad request:[" + request + "]"); + } + List strings = asList(request.toString().split("\r\n")); + List response = mapping.apply(strings, credentials); + writeResponse(channel, response); + + if (response.get(0).startsWith("HTTP/1.1 401")) { + err.println("Sent 401 Authentication response " + channel); + continue; + } else { + serve(channel); + break; + } } - List strings = asList(request.toString().split("\r\n")); - List response = mapping.apply(strings); - writeResponse(channel, response); - serve(channel); } catch (IOException e) { err.println("Error in connection: " + channel + ", " + e); } finally { @@ -125,7 +155,7 @@ } } catch (ClosedByInterruptException ignored) { } catch (Exception e) { - err.println(e); + e.printStackTrace(err); } finally { close(ssc); err.println("Stopped at: " + getURI()); @@ -256,8 +286,8 @@ } } - private static Function, List> defaultMapping() { - return request -> { + private static BiFunction,Credentials,List> defaultMapping() { + return (request, credentials) -> { List response = new LinkedList<>(); Iterator iterator = request.iterator(); if (!iterator.hasNext()) { @@ -309,14 +339,57 @@ sha1.update(x.getBytes(ISO_8859_1)); String v = Base64.getEncoder().encodeToString(sha1.digest()); response.add("Sec-WebSocket-Accept: " + v); + + // check authorization credentials, if required by the server + if (credentials != null && !authorized(credentials, requestHeaders)) { + response.clear(); + response.add("HTTP/1.1 401 Unauthorized"); + response.add("Content-Length: 0"); + response.add("WWW-Authenticate: Basic realm=\"dummy server realm\""); + } + return response; }; } + // Checks credentials in the request against those allowable by the server. + private static boolean authorized(Credentials credentials, + Map> requestHeaders) { + List authorization = requestHeaders.get("Authorization"); + if (authorization == null) + return false; + + if (authorization.size() != 1) { + throw new IllegalStateException("Authorization unexpected count:" + authorization); + } + String header = authorization.get(0); + if (!header.startsWith("Basic ")) + throw new IllegalStateException("Authorization not Basic: " + header); + + header = header.substring("Basic ".length()); + String values = new String(Base64.getDecoder().decode(header), UTF_8); + int sep = values.indexOf(':'); + if (sep < 1) { + throw new IllegalStateException("Authorization not colon: " + values); + } + String name = values.substring(0, sep); + String password = values.substring(sep + 1); + + if (name.equals(credentials.name()) && password.equals(credentials.password())) + return true; + + return false; + } + protected static String expectHeader(Map> headers, String name, String value) { List v = headers.get(name); + if (v == null) { + throw new IllegalStateException( + format("Expected '%s' header, not present in %s", + name, headers)); + } if (!v.contains(value)) { throw new IllegalStateException( format("Expected '%s: %s', actual: '%s: %s'",