< prev index next >

test/jdk/java/net/httpclient/websocket/DummyWebSocketServer.java

Print this page

        

*** 1,7 **** /* ! * Copyright (c) 2016, 2018, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation. --- 1,7 ---- /* ! * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License version 2 only, as * published by the Free Software Foundation.
*** 44,60 **** import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; ! import java.util.function.Function; import java.util.regex.Pattern; import java.util.stream.Collectors; import static java.lang.String.format; import static java.lang.System.err; import static java.nio.charset.StandardCharsets.ISO_8859_1; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; /** * Dummy WebSocket Server. --- 44,61 ---- import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; ! import java.util.function.BiFunction; import java.util.regex.Pattern; import java.util.stream.Collectors; import static java.lang.String.format; import static java.lang.System.err; import static java.nio.charset.StandardCharsets.ISO_8859_1; + import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.asList; import static java.util.Objects.requireNonNull; /** * Dummy WebSocket Server.
*** 90,133 **** private volatile ServerSocketChannel ssc; private volatile InetSocketAddress address; private ByteBuffer read = ByteBuffer.allocate(16384); private final CountDownLatch readReady = new CountDownLatch(1); public DummyWebSocketServer() { ! this(defaultMapping()); } ! public DummyWebSocketServer(Function<List<String>, List<String>> mapping) { requireNonNull(mapping); thread = new Thread(() -> { try { while (!Thread.currentThread().isInterrupted()) { err.println("Accepting next connection at: " + ssc); SocketChannel channel = ssc.accept(); err.println("Accepted: " + channel); try { channel.setOption(StandardSocketOptions.TCP_NODELAY, true); channel.configureBlocking(true); StringBuilder request = new StringBuilder(); if (!readRequest(channel, request)) { ! throw new IOException("Bad request:" + request); } List<String> strings = asList(request.toString().split("\r\n")); ! List<String> response = mapping.apply(strings); writeResponse(channel, response); serve(channel); } catch (IOException e) { err.println("Error in connection: " + channel + ", " + e); } finally { err.println("Closed: " + channel); close(channel); readReady.countDown(); } } } catch (ClosedByInterruptException ignored) { } catch (Exception e) { ! err.println(e); } finally { close(ssc); err.println("Stopped at: " + getURI()); } }); --- 91,163 ---- private volatile ServerSocketChannel ssc; private volatile InetSocketAddress address; private ByteBuffer read = ByteBuffer.allocate(16384); private final CountDownLatch readReady = new CountDownLatch(1); + private static class Credentials { + private final String name; + private final String password; + private Credentials(String name, String password) { + this.name = name; + this.password = password; + } + public String name() { return name; } + public String password() { return password; } + } + public DummyWebSocketServer() { ! this(defaultMapping(), null, null); } ! public DummyWebSocketServer(String username, String password) { ! this(defaultMapping(), username, password); ! } ! ! public DummyWebSocketServer(BiFunction<List<String>,Credentials,List<String>> mapping, ! String username, ! String password) { requireNonNull(mapping); + Credentials credentials = username != null ? + new Credentials(username, password) : null; + thread = new Thread(() -> { try { while (!Thread.currentThread().isInterrupted()) { err.println("Accepting next connection at: " + ssc); SocketChannel channel = ssc.accept(); err.println("Accepted: " + channel); try { channel.setOption(StandardSocketOptions.TCP_NODELAY, true); channel.configureBlocking(true); + while (true) { StringBuilder request = new StringBuilder(); if (!readRequest(channel, request)) { ! throw new IOException("Bad request:[" + request + "]"); } List<String> strings = asList(request.toString().split("\r\n")); ! List<String> response = mapping.apply(strings, credentials); writeResponse(channel, response); + + if (response.get(0).startsWith("HTTP/1.1 401")) { + err.println("Sent 401 Authentication response " + channel); + continue; + } else { serve(channel); + break; + } + } } catch (IOException e) { err.println("Error in connection: " + channel + ", " + e); } finally { err.println("Closed: " + channel); close(channel); readReady.countDown(); } } } catch (ClosedByInterruptException ignored) { } catch (Exception e) { ! e.printStackTrace(err); } finally { close(ssc); err.println("Stopped at: " + getURI()); } });
*** 254,265 **** while (encoded.hasRemaining()) { channel.write(encoded); } } ! private static Function<List<String>, List<String>> defaultMapping() { ! return request -> { List<String> response = new LinkedList<>(); Iterator<String> iterator = request.iterator(); if (!iterator.hasNext()) { throw new IllegalStateException("The request is empty"); } --- 284,295 ---- while (encoded.hasRemaining()) { channel.write(encoded); } } ! private static BiFunction<List<String>,Credentials,List<String>> defaultMapping() { ! return (request, credentials) -> { List<String> response = new LinkedList<>(); Iterator<String> iterator = request.iterator(); if (!iterator.hasNext()) { throw new IllegalStateException("The request is empty"); }
*** 307,324 **** --- 337,397 ---- } String x = key.get(0) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; sha1.update(x.getBytes(ISO_8859_1)); String v = Base64.getEncoder().encodeToString(sha1.digest()); response.add("Sec-WebSocket-Accept: " + v); + + // check authorization credentials, if required by the server + if (credentials != null && !authorized(credentials, requestHeaders)) { + response.clear(); + response.add("HTTP/1.1 401 Unauthorized"); + response.add("Content-Length: 0"); + response.add("WWW-Authenticate: Basic realm=\"dummy server realm\""); + } + return response; }; } + // Checks credentials in the request against those allowable by the server. + private static boolean authorized(Credentials credentials, + Map<String,List<String>> requestHeaders) { + List<String> authorization = requestHeaders.get("Authorization"); + if (authorization == null) + return false; + + if (authorization.size() != 1) { + throw new IllegalStateException("Authorization unexpected count:" + authorization); + } + String header = authorization.get(0); + if (!header.startsWith("Basic ")) + throw new IllegalStateException("Authorization not Basic: " + header); + + header = header.substring("Basic ".length()); + String values = new String(Base64.getDecoder().decode(header), UTF_8); + int sep = values.indexOf(':'); + if (sep < 1) { + throw new IllegalStateException("Authorization not colon: " + values); + } + String name = values.substring(0, sep); + String password = values.substring(sep + 1); + + if (name.equals(credentials.name()) && password.equals(credentials.password())) + return true; + + return false; + } + protected static String expectHeader(Map<String, List<String>> headers, String name, String value) { List<String> v = headers.get(name); + if (v == null) { + throw new IllegalStateException( + format("Expected '%s' header, not present in %s", + name, headers)); + } if (!v.contains(value)) { throw new IllegalStateException( format("Expected '%s: %s', actual: '%s: %s'", name, value, name, v) );
< prev index next >