< prev index next >
test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java
Print this page
@@ -1,7 +1,7 @@
/*
- * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
@@ -44,17 +44,18 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Function;
+import java.util.function.BiFunction;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.lang.String.format;
import static java.lang.System.err;
import static java.nio.charset.StandardCharsets.ISO_8859_1;
+import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
/**
* Dummy WebSocket Server.
@@ -90,44 +91,73 @@
private volatile ServerSocketChannel ssc;
private volatile InetSocketAddress address;
private ByteBuffer read = ByteBuffer.allocate(16384);
private final CountDownLatch readReady = new CountDownLatch(1);
+ /**
+  * Immutable pair of user name and password that the server compares
+  * against the credentials presented in a request.
+  */
+ private static class Credentials {
+     private final String name;
+     private final String password;
+
+     private Credentials(String userName, String secret) {
+         name = userName;
+         password = secret;
+     }
+
+     /** Returns the user name held by this object. */
+     public String name() {
+         return name;
+     }
+
+     /** Returns the password held by this object. */
+     public String password() {
+         return password;
+     }
+ }
+
public DummyWebSocketServer() {
- this(defaultMapping());
+ this(defaultMapping(), null, null); // default mapping; null user name and password, so no Credentials are created
}
- public DummyWebSocketServer(Function<List<String>, List<String>> mapping) {
+ public DummyWebSocketServer(String username, String password) { // server using the default mapping and the given credentials
+ this(defaultMapping(), username, password); // Credentials are only created when username is non-null
+ }
+
+ public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping,
+ String username,
+ String password) {
requireNonNull(mapping);
+ Credentials credentials = username != null ?
+ new Credentials(username, password) : null;
+
thread = new Thread(() -> {
try {
while (!Thread.currentThread().isInterrupted()) {
err.println("Accepting next connection at: " + ssc);
SocketChannel channel = ssc.accept();
err.println("Accepted: " + channel);
try {
channel.setOption(StandardSocketOptions.TCP_NODELAY, true);
channel.configureBlocking(true);
+ while (true) {
StringBuilder request = new StringBuilder();
if (!readRequest(channel, request)) {
- throw new IOException("Bad request:" + request);
+ throw new IOException("Bad request:[" + request + "]");
}
List<String> strings = asList(request.toString().split("\r\n"));
- List<String> response = mapping.apply(strings);
+ List<String> response = mapping.apply(strings, credentials);
writeResponse(channel, response);
+
+ if (response.get(0).startsWith("HTTP/1.1 401")) {
+ err.println("Sent 401 Authentication response " + channel);
+ continue;
+ } else {
serve(channel);
+ break;
+ }
+ }
} catch (IOException e) {
err.println("Error in connection: " + channel + ", " + e);
} finally {
err.println("Closed: " + channel);
close(channel);
readReady.countDown();
}
}
} catch (ClosedByInterruptException ignored) {
} catch (Exception e) {
- err.println(e);
+ e.printStackTrace(err);
} finally {
close(ssc);
err.println("Stopped at: " + getURI());
}
});
@@ -254,12 +284,12 @@
while (encoded.hasRemaining()) {
channel.write(encoded);
}
}
- private static Function<List<String>, List<String>> defaultMapping() {
- return request -> {
+ private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() {
+ return (request, credentials) -> {
List<String> response = new LinkedList<>();
Iterator<String> iterator = request.iterator();
if (!iterator.hasNext()) {
throw new IllegalStateException("The request is empty");
}
@@ -307,18 +337,61 @@
}
String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
sha1.update(x.getBytes(ISO_8859_1));
String v = Base64.getEncoder().encodeToString(sha1.digest());
response.add("Sec-WebSocket-Accept: " + v);
+
+ // check authorization credentials, if required by the server
+ if (credentials != null && !authorized(credentials, requestHeaders)) {
+ response.clear();
+ response.add("HTTP/1.1 401 Unauthorized");
+ response.add("Content-Length: 0");
+ response.add("WWW-Authenticate: Basic realm=\"dummy server realm\"");
+ }
+
return response;
};
}
+ /**
+  * Checks the Basic credentials in the request against those allowable by
+  * the server.
+  *
+  * @param credentials    the user name and password the server accepts
+  * @param requestHeaders the request headers, keyed by header name
+  * @return {@code true} if the request carries exactly the expected user
+  *         name and password; {@code false} if it carries no
+  *         Authorization header or the values do not match
+  * @throws IllegalStateException if the Authorization header is repeated,
+  *         does not use the Basic scheme, is not valid Base64, or does
+  *         not decode to {@code name:password} with a non-empty name
+  */
+ private static boolean authorized(Credentials credentials,
+                                   Map<String,List<String>> requestHeaders) {
+     List<String> authorization = requestHeaders.get("Authorization");
+     if (authorization == null) {
+         // No credentials offered yet: report "not authorized" rather than fail
+         return false;
+     }
+     if (authorization.size() != 1) {
+         throw new IllegalStateException("Authorization unexpected count:" + authorization);
+     }
+
+     String header = authorization.get(0);
+     String scheme = "Basic ";
+     // The auth-scheme token is case-insensitive, so "basic " must be accepted too
+     if (!header.regionMatches(true, 0, scheme, 0, scheme.length())) {
+         throw new IllegalStateException("Authorization not Basic: " + header);
+     }
+
+     byte[] decoded;
+     try {
+         decoded = Base64.getDecoder().decode(header.substring(scheme.length()));
+     } catch (IllegalArgumentException e) {
+         // Report malformed Base64 the same way as every other malformed header
+         throw new IllegalStateException("Authorization not Base64: " + header, e);
+     }
+
+     String values = new String(decoded, UTF_8);
+     int sep = values.indexOf(':');
+     if (sep < 1) {
+         // Either no colon at all, or an empty user name before it
+         throw new IllegalStateException("Authorization not colon: " + values);
+     }
+     String name = values.substring(0, sep);
+     String password = values.substring(sep + 1);
+
+     return name.equals(credentials.name()) && password.equals(credentials.password());
+ }
+
protected static String expectHeader(Map<String, List<String>> headers,
String name,
String value) {
List<String> v = headers.get(name);
+ if (v == null) {
+ throw new IllegalStateException(
+ format("Expected '%s' header, not present in %s",
+ name, headers));
+ }
if (!v.contains(value)) {
throw new IllegalStateException(
format("Expected '%s: %s', actual: '%s: %s'",
name, value, name, v)
);
< prev index next >