1 /*
   2  * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 /*
  25  * @test
  26  * @bug 8217429
  27  * @summary WebSocket proxy tunneling tests
  28  * @compile DummyWebSocketServer.java ../ProxyServer.java
  29  * @run testng/othervm
  30  *         -Djdk.http.auth.tunneling.disabledSchemes=
  31  *         WebSocketProxyTest
  32  */
  33 
  34 import java.io.IOException;
  35 import java.io.UncheckedIOException;
  36 import java.net.Authenticator;
  37 import java.net.InetAddress;
  38 import java.net.InetSocketAddress;
  39 import java.net.PasswordAuthentication;
  40 import java.net.ProxySelector;
  41 import java.net.http.HttpResponse;
  42 import java.net.http.WebSocket;
  43 import java.net.http.WebSocketHandshakeException;
  44 import java.nio.ByteBuffer;
  45 import java.nio.charset.StandardCharsets;
  46 import java.util.ArrayList;
  47 import java.util.Base64;
  48 import java.util.List;
  49 import java.util.concurrent.CompletableFuture;
  50 import java.util.concurrent.CompletionException;
  51 import java.util.concurrent.CompletionStage;
  52 import java.util.function.Function;
  53 import java.util.function.Supplier;
  54 import java.util.stream.Collectors;
  55 import org.testng.annotations.DataProvider;
  56 import org.testng.annotations.Test;
  57 import static java.net.http.HttpClient.newBuilder;
  58 import static java.nio.charset.StandardCharsets.UTF_8;
  59 import static org.testng.Assert.assertEquals;
  60 import static org.testng.FileAssert.fail;
  61 
  62 public class WebSocketProxyTest {
  63 
  64     // Used to verify a proxy/websocket server requiring Authentication
  65     private static final String USERNAME = "wally";
  66     private static final String PASSWORD = "xyz987";
  67 
  68     static class WSAuthenticator extends Authenticator {
  69         @Override
  70         protected PasswordAuthentication getPasswordAuthentication() {
  71             return new PasswordAuthentication(USERNAME, PASSWORD.toCharArray());
  72         }
  73     }
  74 
  75     static final Function<int[],DummyWebSocketServer> SERVER_WITH_CANNED_DATA =
  76         new Function<>() {
  77             @Override public DummyWebSocketServer apply(int[] data) {
  78                 return Support.serverWithCannedData(data); }
  79             @Override public String toString() { return "SERVER_WITH_CANNED_DATA"; }
  80         };
  81 
  82     static final Function<int[],DummyWebSocketServer> AUTH_SERVER_WITH_CANNED_DATA =
  83         new Function<>() {
  84             @Override public DummyWebSocketServer apply(int[] data) {
  85                 return Support.serverWithCannedDataAndAuthentication(USERNAME, PASSWORD, data); }
  86             @Override public String toString() { return "AUTH_SERVER_WITH_CANNED_DATA"; }
  87         };
  88 
  89     static final Supplier<ProxyServer> TUNNELING_PROXY_SERVER =
  90         new Supplier<>() {
  91             @Override public ProxyServer get() {
  92                 try { return new ProxyServer(0, true);}
  93                 catch(IOException e) { throw new UncheckedIOException(e); } }
  94             @Override public String toString() { return "TUNNELING_PROXY_SERVER"; }
  95         };
  96     static final Supplier<ProxyServer> AUTH_TUNNELING_PROXY_SERVER =
  97         new Supplier<>() {
  98             @Override public ProxyServer get() {
  99                 try { return new ProxyServer(0, true, USERNAME, PASSWORD);}
 100                 catch(IOException e) { throw new UncheckedIOException(e); } }
 101             @Override public String toString() { return "AUTH_TUNNELING_PROXY_SERVER"; }
 102         };
 103 
 104     @DataProvider(name = "servers")
 105     public Object[][] servers() {
 106         return new Object[][] {
 107             { SERVER_WITH_CANNED_DATA,      TUNNELING_PROXY_SERVER      },
 108             { SERVER_WITH_CANNED_DATA,      AUTH_TUNNELING_PROXY_SERVER },
 109             { AUTH_SERVER_WITH_CANNED_DATA, TUNNELING_PROXY_SERVER      },
 110         };
 111     }
 112 
 113     @Test(dataProvider = "servers")
 114     public void simpleAggregatingBinaryMessages
 115             (Function<int[],DummyWebSocketServer> serverSupplier,
 116              Supplier<ProxyServer> proxyServerSupplier)
 117         throws IOException
 118     {
 119         List<byte[]> expected = List.of("hello", "chegar")
 120                 .stream()
 121                 .map(s -> s.getBytes(StandardCharsets.US_ASCII))
 122                 .collect(Collectors.toList());
 123         int[] binary = new int[]{
 124                 0x82, 0x05, 0x68, 0x65, 0x6C, 0x6C, 0x6F,       // hello
 125                 0x82, 0x06, 0x63, 0x68, 0x65, 0x67, 0x61, 0x72, // chegar
 126                 0x88, 0x00                                      // <CLOSE>
 127         };
 128         CompletableFuture<List<byte[]>> actual = new CompletableFuture<>();
 129 
 130         try (var proxyServer = proxyServerSupplier.get();
 131              var server = serverSupplier.apply(binary)) {
 132 
 133             InetSocketAddress proxyAddress = new InetSocketAddress(
 134                     InetAddress.getLoopbackAddress(), proxyServer.getPort());
 135             server.open();
 136 
 137             WebSocket.Listener listener = new WebSocket.Listener() {
 138 
 139                 List<byte[]> collectedBytes = new ArrayList<>();
 140                 ByteBuffer buffer = ByteBuffer.allocate(1024);
 141 
 142                 @Override
 143                 public CompletionStage<?> onBinary(WebSocket webSocket,
 144                                                    ByteBuffer message,
 145                                                    boolean last) {
 146                     System.out.printf("onBinary(%s, %s)%n", message, last);
 147                     webSocket.request(1);
 148 
 149                     append(message);
 150                     if (last) {
 151                         buffer.flip();
 152                         byte[] bytes = new byte[buffer.remaining()];
 153                         buffer.get(bytes);
 154                         buffer.clear();
 155                         processWholeBinary(bytes);
 156                     }
 157                     return null;
 158                 }
 159 
 160                 private void append(ByteBuffer message) {
 161                     if (buffer.remaining() < message.remaining()) {
 162                         assert message.remaining() > 0;
 163                         int cap = (buffer.capacity() + message.remaining()) * 2;
 164                         ByteBuffer b = ByteBuffer.allocate(cap);
 165                         b.put(buffer.flip());
 166                         buffer = b;
 167                     }
 168                     buffer.put(message);
 169                 }
 170 
 171                 private void processWholeBinary(byte[] bytes) {
 172                     String stringBytes = new String(bytes, UTF_8);
 173                     System.out.println("processWholeBinary: " + stringBytes);
 174                     collectedBytes.add(bytes);
 175                 }
 176 
 177                 @Override
 178                 public CompletionStage<?> onClose(WebSocket webSocket,
 179                                                   int statusCode,
 180                                                   String reason) {
 181                     actual.complete(collectedBytes);
 182                     return null;
 183                 }
 184 
 185                 @Override
 186                 public void onError(WebSocket webSocket, Throwable error) {
 187                     actual.completeExceptionally(error);
 188                 }
 189             };
 190 
 191             var webSocket = newBuilder()
 192                     .proxy(ProxySelector.of(proxyAddress))
 193                     .authenticator(new WSAuthenticator())
 194                     .build().newWebSocketBuilder()
 195                     .buildAsync(server.getURI(), listener)
 196                     .join();
 197 
 198             List<byte[]> a = actual.join();
 199             assertEquals(a, expected);
 200         }
 201     }
 202 
 203     // -- authentication specific tests
 204 
 205     /*
 206      * Ensures authentication succeeds when an Authenticator set on client builder.
 207      */
 208     @Test
 209     public void clientAuthenticate() throws IOException  {
 210         try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
 211              var server = new DummyWebSocketServer()){
 212             server.open();
 213             InetSocketAddress proxyAddress = new InetSocketAddress(
 214                     InetAddress.getLoopbackAddress(), proxyServer.getPort());
 215 
 216             var webSocket = newBuilder()
 217                     .proxy(ProxySelector.of(proxyAddress))
 218                     .authenticator(new WSAuthenticator())
 219                     .build()
 220                     .newWebSocketBuilder()
 221                     .buildAsync(server.getURI(), new WebSocket.Listener() { })
 222                     .join();
 223         }
 224     }
 225 
 226     /*
 227      * Ensures authentication succeeds when an `Authorization` header is explicitly set.
 228      */
 229     @Test
 230     public void explicitAuthenticate() throws IOException  {
 231         try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
 232              var server = new DummyWebSocketServer()) {
 233             server.open();
 234             InetSocketAddress proxyAddress = new InetSocketAddress(
 235                     InetAddress.getLoopbackAddress(), proxyServer.getPort());
 236 
 237             String hv = "Basic " + Base64.getEncoder().encodeToString(
 238                     (USERNAME + ":" + PASSWORD).getBytes(UTF_8));
 239 
 240             var webSocket = newBuilder()
 241                     .proxy(ProxySelector.of(proxyAddress)).build()
 242                     .newWebSocketBuilder()
 243                     .header("Proxy-Authorization", hv)
 244                     .buildAsync(server.getURI(), new WebSocket.Listener() { })
 245                     .join();
 246         }
 247     }
 248 
 249     /*
 250      * Ensures authentication does not succeed when no authenticator is present.
 251      */
 252     @Test
 253     public void failNoAuthenticator() throws IOException  {
 254         try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
 255              var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
 256             server.open();
 257             InetSocketAddress proxyAddress = new InetSocketAddress(
 258                     InetAddress.getLoopbackAddress(), proxyServer.getPort());
 259 
 260             CompletableFuture<WebSocket> cf = newBuilder()
 261                     .proxy(ProxySelector.of(proxyAddress)).build()
 262                     .newWebSocketBuilder()
 263                     .buildAsync(server.getURI(), new WebSocket.Listener() { });
 264 
 265             try {
 266                 var webSocket = cf.join();
 267                 fail("Expected exception not thrown");
 268             } catch (CompletionException expected) {
 269                 WebSocketHandshakeException e = (WebSocketHandshakeException)expected.getCause();
 270                 HttpResponse<?> response = e.getResponse();
 271                 assertEquals(response.statusCode(), 407);
 272             }
 273         }
 274     }
 275 
 276     /*
 277      * Ensures authentication does not succeed when the authenticator presents
 278      * unauthorized credentials.
 279      */
 280     @Test
 281     public void failBadCredentials() throws IOException  {
 282         try (var proxyServer = AUTH_TUNNELING_PROXY_SERVER.get();
 283              var server = new DummyWebSocketServer(USERNAME, PASSWORD)) {
 284             server.open();
 285             InetSocketAddress proxyAddress = new InetSocketAddress(
 286                     InetAddress.getLoopbackAddress(), proxyServer.getPort());
 287 
 288             Authenticator authenticator = new Authenticator() {
 289                 @Override protected PasswordAuthentication getPasswordAuthentication() {
 290                     return new PasswordAuthentication("BAD"+USERNAME, "".toCharArray());
 291                 }
 292             };
 293 
 294             CompletableFuture<WebSocket> cf = newBuilder()
 295                     .proxy(ProxySelector.of(proxyAddress))
 296                     .authenticator(authenticator)
 297                     .build()
 298                     .newWebSocketBuilder()
 299                     .buildAsync(server.getURI(), new WebSocket.Listener() { });
 300 
 301             try {
 302                 var webSocket = cf.join();
 303                 fail("Expected exception not thrown");
 304             } catch (CompletionException expected) {
 305                 System.out.println("caught expected exception:" + expected);
 306             }
 307         }
 308     }
 309 }