1 /*
   2  * Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.internal.jshell.remote;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import jdk.jshell.spi.SPIResolutionException;

import static jdk.internal.jshell.remote.RemoteCodes.*;
  47 
  48 /**
  49  * The remote agent runs in the execution process (separate from the main JShell
  50  * process.  This agent loads code over a socket from the main JShell process,
  51  * executes the code, and other misc,
  52  * @author Robert Field
  53  */
  54 class RemoteAgent {
  55 
  56     private final RemoteClassLoader loader = new RemoteClassLoader();
  57     private final Map<String, Class<?>> klasses = new TreeMap<>();
  58 
  59     public static void main(String[] args) throws Exception {
  60         String loopBack = null;
  61         Socket socket = new Socket(loopBack, Integer.parseInt(args[0]));
  62         (new RemoteAgent()).commandLoop(socket);
  63     }
  64 
  65     void commandLoop(Socket socket) throws IOException {
  66         // in before out -- so we don't hang the controlling process
  67         ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
  68         OutputStream socketOut = socket.getOutputStream();
  69         System.setOut(new PrintStream(new MultiplexingOutputStream("out", socketOut), true));
  70         System.setErr(new PrintStream(new MultiplexingOutputStream("err", socketOut), true));
  71         ObjectOutputStream out = new ObjectOutputStream(new MultiplexingOutputStream("command", socketOut));
  72         while (true) {
  73             int cmd = in.readInt();
  74             switch (cmd) {
  75                 case CMD_EXIT:
  76                     // Terminate this process
  77                     return;
  78                 case CMD_LOAD:
  79                     // Load a generated class file over the wire
  80                     try {
  81                         int count = in.readInt();
  82                         List<String> names = new ArrayList<>(count);
  83                         for (int i = 0; i < count; ++i) {
  84                             String name = in.readUTF();
  85                             byte[] kb = (byte[]) in.readObject();
  86                             loader.delare(name, kb);
  87                             names.add(name);
  88                         }
  89                         for (String name : names) {
  90                             Class<?> klass = loader.loadClass(name);
  91                             klasses.put(name, klass);
  92                             // Get class loaded to the point of, at least, preparation
  93                             klass.getDeclaredMethods();
  94                         }
  95                         out.writeInt(RESULT_SUCCESS);
  96                         out.flush();
  97                     } catch (IOException | ClassNotFoundException | ClassCastException ex) {
  98                         debug("*** Load failure: %s\n", ex);
  99                         out.writeInt(RESULT_FAIL);
 100                         out.writeUTF(ex.toString());
 101                         out.flush();
 102                     }
 103                     break;
 104                 case CMD_INVOKE: {
 105                     // Invoke executable entry point in loaded code
 106                     String name = in.readUTF();
 107                     Class<?> klass = klasses.get(name);
 108                     if (klass == null) {
 109                         debug("*** Invoke failure: no such class loaded %s\n", name);
 110                         out.writeInt(RESULT_FAIL);
 111                         out.writeUTF("no such class loaded: " + name);
 112                         out.flush();
 113                         break;
 114                     }
 115                     String methodName = in.readUTF();
 116                     Method doitMethod;
 117                     try {
 118                         this.getClass().getModule().addExports(SPIResolutionException.class.getPackage().getName(), klass.getModule());
 119                         doitMethod = klass.getDeclaredMethod(methodName, new Class<?>[0]);
 120                         doitMethod.setAccessible(true);
 121                         Object res;
 122                         try {
 123                             clientCodeEnter();
 124                             res = doitMethod.invoke(null, new Object[0]);
 125                         } catch (InvocationTargetException ex) {
 126                             if (ex.getCause() instanceof StopExecutionException) {
 127                                 expectingStop = false;
 128                                 throw (StopExecutionException) ex.getCause();
 129                             }
 130                             throw ex;
 131                         } catch (StopExecutionException ex) {
 132                             expectingStop = false;
 133                             throw ex;
 134                         } finally {
 135                             clientCodeLeave();
 136                         }
 137                         out.writeInt(RESULT_SUCCESS);
 138                         out.writeUTF(valueString(res));
 139                         out.flush();
 140                     } catch (InvocationTargetException ex) {
 141                         Throwable cause = ex.getCause();
 142                         StackTraceElement[] elems = cause.getStackTrace();
 143                         if (cause instanceof SPIResolutionException) {
 144                             out.writeInt(RESULT_CORRALLED);
 145                             out.writeInt(((SPIResolutionException) cause).id());
 146                         } else {
 147                             out.writeInt(RESULT_EXCEPTION);
 148                             out.writeUTF(cause.getClass().getName());
 149                             out.writeUTF(cause.getMessage() == null ? "<none>" : cause.getMessage());
 150                         }
 151                         out.writeInt(elems.length);
 152                         for (StackTraceElement ste : elems) {
 153                             out.writeUTF(ste.getClassName());
 154                             out.writeUTF(ste.getMethodName());
 155                             out.writeUTF(ste.getFileName() == null ? "<none>" : ste.getFileName());
 156                             out.writeInt(ste.getLineNumber());
 157                         }
 158                         out.flush();
 159                     } catch (NoSuchMethodException | IllegalAccessException ex) {
 160                         debug("*** Invoke failure: %s -- %s\n", ex, ex.getCause());
 161                         out.writeInt(RESULT_FAIL);
 162                         out.writeUTF(ex.toString());
 163                         out.flush();
 164                     } catch (StopExecutionException ex) {
 165                         try {
 166                             out.writeInt(RESULT_KILLED);
 167                             out.flush();
 168                         } catch (IOException err) {
 169                             debug("*** Error writing killed result: %s -- %s\n", ex, ex.getCause());
 170                         }
 171                     }
 172                     System.out.flush();
 173                     break;
 174                 }
 175                 case CMD_VARVALUE: {
 176                     // Retrieve a variable value
 177                     String classname = in.readUTF();
 178                     String varname = in.readUTF();
 179                     Class<?> klass = klasses.get(classname);
 180                     if (klass == null) {
 181                         debug("*** Var value failure: no such class loaded %s\n", classname);
 182                         out.writeInt(RESULT_FAIL);
 183                         out.writeUTF("no such class loaded: " + classname);
 184                         out.flush();
 185                         break;
 186                     }
 187                     try {
 188                         Field var = klass.getDeclaredField(varname);
 189                         var.setAccessible(true);
 190                         Object res = var.get(null);
 191                         out.writeInt(RESULT_SUCCESS);
 192                         out.writeUTF(valueString(res));
 193                         out.flush();
 194                     } catch (Exception ex) {
 195                         debug("*** Var value failure: no such field %s.%s\n", classname, varname);
 196                         out.writeInt(RESULT_FAIL);
 197                         out.writeUTF("no such field loaded: " + varname + " in class: " + classname);
 198                         out.flush();
 199                     }
 200                     break;
 201                 }
 202                 case CMD_CLASSPATH: {
 203                     // Append to the claspath
 204                     String cp = in.readUTF();
 205                     for (String path : cp.split(File.pathSeparator)) {
 206                         loader.addURL(new File(path).toURI().toURL());
 207                     }
 208                     out.writeInt(RESULT_SUCCESS);
 209                     out.flush();
 210                     break;
 211                 }
 212                 default:
 213                     debug("*** Bad command code: %d\n", cmd);
 214                     break;
 215             }
 216         }
 217     }
 218 
 219     // These three variables are used by the main JShell process in interrupting
 220     // the running process.  Access is via JDI, so the reference is not visible
 221     // to code inspection.
 222     private boolean inClientCode; // Queried by the main process
 223     private boolean expectingStop; // Set by the main process
 224 
 225     // thrown by the main process via JDI:
 226     private final StopExecutionException stopException = new StopExecutionException();
 227 
 228     @SuppressWarnings("serial")             // serialVersionUID intentionally omitted
 229     private class StopExecutionException extends ThreadDeath {
 230         @Override public synchronized Throwable fillInStackTrace() {
 231             return this;
 232         }
 233     }
 234 
 235     void clientCodeEnter() {
 236         expectingStop = false;
 237         inClientCode = true;
 238     }
 239 
 240     void clientCodeLeave() {
 241         inClientCode = false;
 242         while (expectingStop) {
 243             try {
 244                 Thread.sleep(0);
 245             } catch (InterruptedException ex) {
 246                 debug("*** Sleep interrupted while waiting for stop exception: %s\n", ex);
 247             }
 248         }
 249     }
 250 
 251     private void debug(String format, Object... args) {
 252         System.err.printf("REMOTE: "+format, args);
 253     }
 254 
 255     static String valueString(Object value) {
 256         if (value == null) {
 257             return "null";
 258         } else if (value instanceof String) {
 259             return "\"" + (String)value + "\"";
 260         } else if (value instanceof Character) {
 261             return "'" + value + "'";
 262         } else {
 263             return value.toString();
 264         }
 265     }
 266 
 267     private static final class MultiplexingOutputStream extends OutputStream {
 268 
 269         private static final int PACKET_SIZE = 127;
 270 
 271         private final byte[] name;
 272         private final OutputStream delegate;
 273 
 274         public MultiplexingOutputStream(String name, OutputStream delegate) {
 275             try {
 276                 this.name = name.getBytes("UTF-8");
 277                 this.delegate = delegate;
 278             } catch (UnsupportedEncodingException ex) {
 279                 throw new IllegalStateException(ex); //should not happen
 280             }
 281         }
 282 
 283         @Override
 284         public void write(int b) throws IOException {
 285             synchronized (delegate) {
 286                 delegate.write(name.length); //assuming the len is small enough to fit into byte
 287                 delegate.write(name);
 288                 delegate.write(1);
 289                 delegate.write(b);
 290                 delegate.flush();
 291             }
 292         }
 293 
 294         @Override
 295         public void write(byte[] b, int off, int len) throws IOException {
 296             synchronized (delegate) {
 297                 int i = 0;
 298                 while (len > 0) {
 299                     int size = Math.min(PACKET_SIZE, len);
 300 
 301                     delegate.write(name.length); //assuming the len is small enough to fit into byte
 302                     delegate.write(name);
 303                     delegate.write(size);
 304                     delegate.write(b, off + i, size);
 305                     i += size;
 306                     len -= size;
 307                 }
 308 
 309                 delegate.flush();
 310             }
 311         }
 312 
 313         @Override
 314         public void flush() throws IOException {
 315             super.flush();
 316             delegate.flush();
 317         }
 318 
 319         @Override
 320         public void close() throws IOException {
 321             super.close();
 322             delegate.close();
 323         }
 324 
 325     }
 326 }