1 /*
   2  * Copyright (c) 2014, 2015, Oracle and/or its affiliates. All rights reserved.
   3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
   4  *
   5  * This code is free software; you can redistribute it and/or modify it
   6  * under the terms of the GNU General Public License version 2 only, as
   7  * published by the Free Software Foundation.  Oracle designates this
   8  * particular file as subject to the "Classpath" exception as provided
   9  * by Oracle in the LICENSE file that accompanied this code.
  10  *
  11  * This code is distributed in the hope that it will be useful, but WITHOUT
  12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  14  * version 2 for more details (a copy is included in the LICENSE file that
  15  * accompanied this code).
  16  *
  17  * You should have received a copy of the GNU General Public License version
  18  * 2 along with this work; if not, write to the Free Software Foundation,
  19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
  20  *
  21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
  22  * or visit www.oracle.com if you need additional information or have any
  23  * questions.
  24  */
  25 
  26 package jdk.jshell;
  27 
  28 import com.sun.source.tree.CompilationUnitTree;
  29 import com.sun.source.tree.Tree;
  30 import com.sun.source.util.Trees;
  31 import com.sun.tools.javac.api.JavacTaskImpl;
  32 import com.sun.tools.javac.api.JavacTool;
  33 import com.sun.tools.javac.util.Context;
  34 import java.util.ArrayList;
  35 import java.util.Arrays;
  36 import java.util.List;
  37 import javax.tools.Diagnostic;
  38 import javax.tools.DiagnosticCollector;
  39 import javax.tools.JavaCompiler;
  40 import javax.tools.JavaFileManager;
  41 import javax.tools.JavaFileObject;
  42 import javax.tools.ToolProvider;
  43 import static jdk.jshell.Util.*;
  44 import com.sun.source.tree.ImportTree;
  45 import com.sun.tools.javac.code.Types;
  46 import com.sun.tools.javac.util.JavacMessages;
  47 import jdk.jshell.MemoryFileManager.OutputMemoryJavaFileObject;
  48 import java.util.Collections;
  49 import java.util.Locale;
  50 import static javax.tools.StandardLocation.CLASS_OUTPUT;
  51 import static jdk.internal.jshell.debug.InternalDebugControl.DBG_GEN;
  52 import java.io.File;
  53 import java.util.Collection;
  54 import java.util.HashMap;
  55 import java.util.LinkedHashMap;
  56 import java.util.Map;
  57 import java.util.stream.Collectors;
  58 import static java.util.stream.Collectors.toList;
  59 import java.util.stream.Stream;
  60 import javax.lang.model.util.Elements;
  61 import javax.tools.FileObject;
  62 import jdk.jshell.MemoryFileManager.SourceMemoryJavaFileObject;
  63 import jdk.jshell.ClassTracker.ClassInfo;
  64 
  65 /**
  66  * The primary interface to the compiler API.  Parsing, analysis, and
  67  * compilation to class files (in memory).
  68  * @author Robert Field
  69  */
  70 class TaskFactory {
  71 
  72     private final JavaCompiler compiler;
  73     private final MemoryFileManager fileManager;
  74     private final JShell state;
  75     private String classpath = System.getProperty("java.class.path");
  76 
  77     TaskFactory(JShell state) {
  78         this.state = state;
  79         this.compiler = ToolProvider.getSystemJavaCompiler();
  80         if (compiler == null) {
  81             throw new UnsupportedOperationException("Compiler not available, must be run with full JDK 9.");
  82         }
  83         if (!System.getProperty("java.specification.version").equals("9"))  {
  84             throw new UnsupportedOperationException("Wrong compiler, must be run with full JDK 9.");
  85         }
  86         this.fileManager = new MemoryFileManager(
  87                 compiler.getStandardFileManager(null, null, null), state);
  88     }
  89 
  90     void addToClasspath(String path) {
  91         classpath = classpath + File.pathSeparator + path;
  92         List<String> args = new ArrayList<>();
  93         args.add(classpath);
  94         fileManager().handleOption("-classpath", args.iterator());
  95     }
  96 
  97     MemoryFileManager fileManager() {
  98         return fileManager;
  99     }
 100 
 101     private interface SourceHandler<T> {
 102 
 103         JavaFileObject sourceToFileObject(MemoryFileManager fm, T t);
 104 
 105         Diag diag(Diagnostic<? extends JavaFileObject> d);
 106     }
 107 
 108     private class StringSourceHandler implements SourceHandler<String> {
 109 
 110         @Override
 111         public JavaFileObject sourceToFileObject(MemoryFileManager fm, String src) {
 112             return fm.createSourceFileObject(src, "$NeverUsedName$", src);
 113         }
 114 
 115         @Override
 116         public Diag diag(final Diagnostic<? extends JavaFileObject> d) {
 117             return new Diag() {
 118 
 119                 @Override
 120                 public boolean isError() {
 121                     return d.getKind() == Diagnostic.Kind.ERROR;
 122                 }
 123 
 124                 @Override
 125                 public long getPosition() {
 126                     return d.getPosition();
 127                 }
 128 
 129                 @Override
 130                 public long getStartPosition() {
 131                     return d.getStartPosition();
 132                 }
 133 
 134                 @Override
 135                 public long getEndPosition() {
 136                     return d.getEndPosition();
 137                 }
 138 
 139                 @Override
 140                 public String getCode() {
 141                     return d.getCode();
 142                 }
 143 
 144                 @Override
 145                 public String getMessage(Locale locale) {
 146                     return expunge(d.getMessage(locale));
 147                 }
 148 
 149                 @Override
 150                 Unit unitOrNull() {
 151                     return null;
 152                 }
 153             };
 154         }
 155     }
 156 
 157     private class WrapSourceHandler implements SourceHandler<OuterWrap> {
 158 
 159         final OuterWrap wrap;
 160 
 161         WrapSourceHandler(OuterWrap wrap) {
 162             this.wrap = wrap;
 163         }
 164 
 165         @Override
 166         public JavaFileObject sourceToFileObject(MemoryFileManager fm, OuterWrap w) {
 167             return fm.createSourceFileObject(w, w.classFullName(), w.wrapped());
 168         }
 169 
 170         @Override
 171         public Diag diag(Diagnostic<? extends JavaFileObject> d) {
 172             return wrap.wrapDiag(d);
 173         }
 174     }
 175 
 176     private class UnitSourceHandler implements SourceHandler<Unit> {
 177 
 178         @Override
 179         public JavaFileObject sourceToFileObject(MemoryFileManager fm, Unit u) {
 180             return fm.createSourceFileObject(u,
 181                     state.maps.classFullName(u.snippet()),
 182                     u.snippet().outerWrap().wrapped());
 183         }
 184 
 185         @Override
 186         public Diag diag(Diagnostic<? extends JavaFileObject> d) {
 187             SourceMemoryJavaFileObject smjfo = (SourceMemoryJavaFileObject) d.getSource();
 188             Unit u = (Unit) smjfo.getOrigin();
 189             return u.snippet().outerWrap().wrapDiag(d);
 190         }
 191     }
 192 
 193     /**
 194      * Parse a snippet of code (as a String) using the parser subclass.  Return
 195      * the parse tree (and errors).
 196      */
 197     class ParseTask extends BaseTask {
 198 
 199         private final Iterable<? extends CompilationUnitTree> cuts;
 200         private final List<? extends Tree> units;
 201 
 202         ParseTask(final String source) {
 203             super(Stream.of(source),
 204                     new StringSourceHandler(),
 205                     "-XDallowStringFolding=false", "-proc:none");
 206             ReplParserFactory.instance(getContext());
 207             cuts = parse();
 208             units = Util.stream(cuts)
 209                     .flatMap(cut -> {
 210                         List<? extends ImportTree> imps = cut.getImports();
 211                         return (!imps.isEmpty() ? imps : cut.getTypeDecls()).stream();
 212                     })
 213                     .collect(toList());
 214         }
 215 
 216         private Iterable<? extends CompilationUnitTree> parse() {
 217             try {
 218                 return task.parse();
 219             } catch (Exception ex) {
 220                 throw new InternalError("Exception during parse - " + ex.getMessage(), ex);
 221             }
 222         }
 223 
 224         List<? extends Tree> units() {
 225             return units;
 226         }
 227 
 228         @Override
 229         Iterable<? extends CompilationUnitTree> cuTrees() {
 230             return cuts;
 231         }
 232     }
 233 
 234     /**
 235      * Run the normal "analyze()" pass of the compiler over the wrapped snippet.
 236      */
 237     class AnalyzeTask extends BaseTask {
 238 
 239         private final Iterable<? extends CompilationUnitTree> cuts;
 240 
 241         AnalyzeTask(final OuterWrap wrap) {
 242             this(Stream.of(wrap),
 243                     new WrapSourceHandler(wrap),
 244                     "-XDshouldStopPolicy=FLOW", "-proc:none");
 245         }
 246 
 247         AnalyzeTask(final Collection<Unit> units) {
 248             this(units.stream(), new UnitSourceHandler(),
 249                     "-XDshouldStopPolicy=FLOW", "-Xlint:unchecked", "-proc:none");
 250         }
 251 
 252         <T>AnalyzeTask(final Stream<T> stream, SourceHandler<T> sourceHandler,
 253                 String... extraOptions) {
 254             super(stream, sourceHandler, extraOptions);
 255             cuts = analyze();
 256         }
 257 
 258         private Iterable<? extends CompilationUnitTree> analyze() {
 259             try {
 260                 Iterable<? extends CompilationUnitTree> cuts = task.parse();
 261                 task.analyze();
 262                 return cuts;
 263             } catch (Exception ex) {
 264                 throw new InternalError("Exception during analyze - " + ex.getMessage(), ex);
 265             }
 266         }
 267 
 268         @Override
 269         Iterable<? extends CompilationUnitTree> cuTrees() {
 270             return cuts;
 271         }
 272 
 273         Elements getElements() {
 274             return task.getElements();
 275         }
 276 
 277         javax.lang.model.util.Types getTypes() {
 278             return task.getTypes();
 279         }
 280     }
 281 
 282     /**
 283      * Unit the wrapped snippet to class files.
 284      */
 285     class CompileTask extends BaseTask {
 286 
 287         private final Map<Unit, List<OutputMemoryJavaFileObject>> classObjs = new HashMap<>();
 288 
 289         CompileTask(Collection<Unit> units) {
 290             super(units.stream(), new UnitSourceHandler(),
 291                     "-Xlint:unchecked", "-proc:none");
 292         }
 293 
 294         boolean compile() {
 295             fileManager.registerClassFileCreationListener(this::listenForNewClassFile);
 296             boolean result = task.call();
 297             fileManager.registerClassFileCreationListener(null);
 298             return result;
 299         }
 300 
 301 
 302         List<ClassInfo> classInfoList(Unit u) {
 303             List<OutputMemoryJavaFileObject> l = classObjs.get(u);
 304             if (l == null) return Collections.emptyList();
 305             return l.stream()
 306                     .map(fo -> state.classTracker.classInfo(fo.getName(), fo.getBytes()))
 307                     .collect(Collectors.toList());
 308         }
 309 
 310         private void listenForNewClassFile(OutputMemoryJavaFileObject jfo, JavaFileManager.Location location,
 311                 String className, JavaFileObject.Kind kind, FileObject sibling) {
 312             //debug("listenForNewClassFile %s loc=%s kind=%s\n", className, location, kind);
 313             if (location == CLASS_OUTPUT) {
 314                 state.debug(DBG_GEN, "Compiler generating class %s\n", className);
 315                 Unit u = ((sibling instanceof SourceMemoryJavaFileObject)
 316                         && (((SourceMemoryJavaFileObject) sibling).getOrigin() instanceof Unit))
 317                         ? (Unit) ((SourceMemoryJavaFileObject) sibling).getOrigin()
 318                         : null;
 319                 classObjs.compute(u, (k, v) -> (v == null)? new ArrayList<>() : v)
 320                         .add(jfo);
 321             }
 322         }
 323 
 324         @Override
 325         Iterable<? extends CompilationUnitTree> cuTrees() {
 326             throw new UnsupportedOperationException("Not supported.");
 327         }
 328     }
 329 
 330     abstract class BaseTask {
 331 
 332         final DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<>();
 333         final JavacTaskImpl task;
 334         private DiagList diags = null;
 335         private final SourceHandler<?> sourceHandler;
 336         private final Context context = new Context();
 337         private Types types;
 338         private JavacMessages messages;
 339         private Trees trees;
 340 
 341         private <T>BaseTask(Stream<T> inputs,
 342                 //BiFunction<MemoryFileManager, T, JavaFileObject> sfoCreator,
 343                 SourceHandler<T> sh,
 344                 String... extraOptions) {
 345             this.sourceHandler = sh;
 346             List<String> options = Arrays.asList(extraOptions);
 347             Iterable<? extends JavaFileObject> compilationUnits = inputs
 348                             .map(in -> sh.sourceToFileObject(fileManager, in))
 349                             .collect(Collectors.toList());
 350             this.task = (JavacTaskImpl) ((JavacTool) compiler).getTask(null,
 351                     fileManager, diagnostics, options, null,
 352                     compilationUnits, context);
 353         }
 354 
 355         abstract Iterable<? extends CompilationUnitTree> cuTrees();
 356 
 357         CompilationUnitTree firstCuTree() {
 358             return cuTrees().iterator().next();
 359         }
 360 
 361         Diag diag(Diagnostic<? extends JavaFileObject> diag) {
 362             return sourceHandler.diag(diag);
 363         }
 364 
 365         Context getContext() {
 366             return context;
 367         }
 368 
 369         Types types() {
 370             if (types == null) {
 371                 types = Types.instance(context);
 372             }
 373             return types;
 374         }
 375 
 376         JavacMessages messages() {
 377             if (messages == null) {
 378                 messages = JavacMessages.instance(context);
 379             }
 380             return messages;
 381         }
 382 
 383         Trees trees() {
 384             if (trees == null) {
 385                 trees = Trees.instance(task);
 386             }
 387             return trees;
 388         }
 389 
 390         // ------------------ diags functionality
 391 
 392         DiagList getDiagnostics() {
 393             if (diags == null) {
 394                 LinkedHashMap<String, Diag> diagMap = new LinkedHashMap<>();
 395                 for (Diagnostic<? extends JavaFileObject> in : diagnostics.getDiagnostics()) {
 396                     Diag d = diag(in);
 397                     String uniqueKey = d.getCode() + ":" + d.getPosition() + ":" + d.getMessage(PARSED_LOCALE);
 398                     diagMap.put(uniqueKey, d);
 399                 }
 400                 diags = new DiagList(diagMap.values());
 401             }
 402             return diags;
 403         }
 404 
 405         boolean hasErrors() {
 406             return getDiagnostics().hasErrors();
 407         }
 408 
 409         String shortErrorMessage() {
 410             StringBuilder sb = new StringBuilder();
 411             for (Diag diag : getDiagnostics()) {
 412                 for (String line : diag.getMessage(PARSED_LOCALE).split("\\r?\\n")) {
 413                     if (!line.trim().startsWith("location:")) {
 414                         sb.append(line);
 415                     }
 416                 }
 417             }
 418             return sb.toString();
 419         }
 420 
 421         void debugPrintDiagnostics(String src) {
 422             for (Diag diag : getDiagnostics()) {
 423                 state.debug(DBG_GEN, "ERROR --\n");
 424                 for (String line : diag.getMessage(PARSED_LOCALE).split("\\r?\\n")) {
 425                     if (!line.trim().startsWith("location:")) {
 426                         state.debug(DBG_GEN, "%s\n", line);
 427                     }
 428                 }
 429                 int start = (int) diag.getStartPosition();
 430                 int end = (int) diag.getEndPosition();
 431                 if (src != null) {
 432                     String[] srcLines = src.split("\\r?\\n");
 433                     for (String line : srcLines) {
 434                         state.debug(DBG_GEN, "%s\n", line);
 435                     }
 436 
 437                     StringBuilder sb = new StringBuilder();
 438                     for (int i = 0; i < start; ++i) {
 439                         sb.append(' ');
 440                     }
 441                     sb.append('^');
 442                     if (end > start) {
 443                         for (int i = start + 1; i < end; ++i) {
 444                             sb.append('-');
 445                         }
 446                         sb.append('^');
 447                     }
 448                     state.debug(DBG_GEN, "%s\n", sb.toString());
 449                 }
 450                 state.debug(DBG_GEN, "printDiagnostics start-pos = %d ==> %d -- wrap = %s\n",
 451                         diag.getStartPosition(), start, this);
 452                 state.debug(DBG_GEN, "Code: %s\n", diag.getCode());
 453                 state.debug(DBG_GEN, "Pos: %d (%d - %d) -- %s\n", diag.getPosition(),
 454                         diag.getStartPosition(), diag.getEndPosition(), diag.getMessage(null));
 455             }
 456         }
 457     }
 458 
 459 }