1 /*
   2  * Copyright (c) 2003, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.
   8  *
   9  * This code is distributed in the hope that it will be useful, but WITHOUT
  10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  12  * version 2 for more details (a copy is included in the LICENSE file that
  13  * accompanied this code).
  14  *
  15  * You should have received a copy of the GNU General Public License version
  16  * 2 along with this work; if not, write to the Free Software Foundation,
  17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  18  *
  19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  20  * or visit www.oracle.com if you need additional information or have any
  21  * questions.
  22  */
  23 
  24 import java.io.*;
  25 import java.net.*;
  26 import java.util.*;
  27 import sun.net.www.protocol.http.*;
  28 
  29 public class B4933582 implements HttpCallback {
  30 
  31     static int count = 0;
  32     static String authstring;
  33 
  34     void errorReply (HttpTransaction req, String reply) throws IOException {
  35         req.addResponseHeader ("Connection", "close");
  36         req.addResponseHeader ("WWW-Authenticate", reply);
  37         req.sendResponse (401, "Unauthorized");
  38         req.orderlyClose();
  39     }
  40 
  41     void okReply (HttpTransaction req) throws IOException {
  42         req.setResponseEntityBody ("Hello .");
  43         req.sendResponse (200, "Ok");
  44         req.orderlyClose();
  45     }
  46 
  47     static boolean firstTime = true;
  48 
  49     public void request (HttpTransaction req) {
  50         try {
  51             authstring = req.getRequestHeader ("Authorization");
  52             if (firstTime) {
  53                 switch (count) {
  54                 case 0:
  55                     errorReply (req, "Basic realm=\"wallyworld\"");
  56                     break;
  57                 case 1:
  58                     /* client stores a username/pw for wallyworld
  59                      */
  60                     save (authstring);
  61                     okReply (req);
  62                     break;
  63                 }
  64             } else {
  65                 /* check the auth string is premptively set from last time */
  66                 String savedauth = retrieve();
  67                 if (savedauth.equals (authstring)) {
  68                     okReply (req);
  69                 } else {
  70                     System.out.println ("savedauth = " + savedauth);
  71                     System.out.println ("authstring = " + authstring);
  72                     errorReply (req, "Basic realm=\"wallyworld\"");
  73                 }
  74             }
  75             count ++;
  76         } catch (IOException e) {
  77             e.printStackTrace();
  78         }
  79     }
  80 
  81     void save (String s) {
  82         try {
  83             FileOutputStream f = new FileOutputStream ("auth.save");
  84             ObjectOutputStream os = new ObjectOutputStream (f);
  85             os.writeObject (s);
  86         } catch (IOException e) {
  87             assert false;
  88         }
  89     }
  90 
  91     String retrieve () {
  92         String s = null;
  93         try {
  94             FileInputStream f = new FileInputStream ("auth.save");
  95             ObjectInputStream is = new ObjectInputStream (f);
  96             s = (String) is.readObject();
  97         } catch (Exception e) {
  98             assert false;
  99         }
 100         return s;
 101     }
 102 
 103     static void read (InputStream is) throws IOException {
 104         int c;
 105         System.out.println ("reading");
 106         while ((c=is.read()) != -1) {
 107             System.out.write (c);
 108         }
 109         System.out.println ("");
 110         System.out.println ("finished reading");
 111     }
 112 
 113     static void client (String u) throws Exception {
 114         URL url = new URL (u);
 115         System.out.println ("client opening connection to: " + u);
 116         URLConnection urlc = url.openConnection ();
 117         InputStream is = urlc.getInputStream ();
 118         read (is);
 119         is.close();
 120     }
 121 
 122     static HttpServer server;
 123 
 124     public static void main (String[] args) throws Exception {
 125         firstTime = args[0].equals ("first");
 126         MyAuthenticator auth = new MyAuthenticator ();
 127         Authenticator.setDefault (auth);
 128         CacheImpl cache;
 129         try {
 130             if (firstTime) {
 131                 server = new HttpServer (new B4933582(), 1, 10, 0);
 132                 cache = new CacheImpl (server.getLocalPort());
 133             } else {
 134                 cache = new CacheImpl ();
 135                 server = new HttpServer(new B4933582(), 1, 10, cache.getPort());
 136             }
 137             AuthCacheValue.setAuthCache (cache);
 138             System.out.println ("Server: listening on port: " + server.getLocalPort());
 139             client ("http://localhost:"+server.getLocalPort()+"/d1/foo.html");
 140         } catch (Exception e) {
 141             if (server != null) {
 142                 server.terminate();
 143             }
 144             throw e;
 145         }
 146         int f = auth.getCount();
 147         if (firstTime && f != 1) {
 148             except ("Authenticator was called "+f+" times. Should be 1");
 149         }
 150         if (!firstTime && f != 0) {
 151             except ("Authenticator was called "+f+" times. Should be 0");
 152         }
 153         server.terminate();
 154     }
 155 
 156     public static void except (String s) {
 157         server.terminate();
 158         throw new RuntimeException (s);
 159     }
 160 
 161     static class MyAuthenticator extends Authenticator {
 162         MyAuthenticator () {
 163             super ();
 164         }
 165 
 166         int count = 0;
 167 
 168         public PasswordAuthentication getPasswordAuthentication () {
 169             PasswordAuthentication pw;
 170             pw = new PasswordAuthentication ("user", "pass1".toCharArray());
 171             count ++;
 172             return pw;
 173         }
 174 
 175         public int getCount () {
 176             return (count);
 177         }
 178     }
 179 
 180     static class CacheImpl extends AuthCacheImpl {
 181         HashMap map;
 182         int port; // need to store the port number the server is using
 183 
 184         CacheImpl () throws IOException {
 185             this (-1);
 186         }
 187 
 188         CacheImpl (int port) throws IOException {
 189             super();
 190             this.port = port;
 191             File src = new File ("cache.ser");
 192             if (src.exists()) {
 193                 ObjectInputStream is = new ObjectInputStream (
 194                     new FileInputStream (src)
 195                 );
 196                 try {
 197                     map = (HashMap)is.readObject ();
 198                     this.port = (Integer)is.readObject ();
 199                     System.out.println ("read port from file " + port);
 200                 } catch (ClassNotFoundException e) {
 201                     assert false;
 202                 }
 203                 is.close();
 204                 System.out.println ("setMap from cache.ser");
 205             } else {
 206                 map = new HashMap();
 207             }
 208             setMap (map);
 209         }
 210 
 211         int getPort () {
 212             return port;
 213         }
 214 
 215         private void writeMap () {
 216             try {
 217                 File dst = new File ("cache.ser");
 218                 dst.delete();
 219                 if (!dst.createNewFile()) {
 220                     return;
 221                 }
 222                 ObjectOutputStream os = new ObjectOutputStream (
 223                         new FileOutputStream (dst)
 224                 );
 225                 os.writeObject(map);
 226                 os.writeObject(port);
 227                 System.out.println ("wrote port " + port);
 228                 os.close();
 229             } catch (IOException e) {}
 230         }
 231 
 232         public void put (String pkey, AuthCacheValue value) {
 233             System.out.println ("put: " + pkey + " " + value);
 234             super.put (pkey, value);
 235             writeMap();
 236         }
 237 
 238         public AuthCacheValue get (String pkey, String skey) {
 239             System.out.println ("get: " + pkey + " " + skey);
 240             AuthCacheValue i = super.get (pkey, skey);
 241             System.out.println ("---> " + i);
 242             return i;
 243         }
 244 
 245         public void remove (String pkey, AuthCacheValue value) {
 246             System.out.println ("remove: " + pkey + " " + value);
 247             super.remove (pkey, value);
 248             writeMap();
 249         }
 250     }
 251 }