< prev index next >

src/jdk.jextract/share/classes/com/sun/tools/jextract/Context.java

Print this page

        

@@ -81,20 +81,23 @@
     private final List<String> libraryNames;
     // The list of library paths
     private final List<String> libraryPaths;
     // The list of library paths for link checks
     private final List<String> linkCheckPaths;
+    // Symbol patterns to be included
+    private final List<Pattern> includeSymbols;
     // Symbol patterns to be excluded
     private final List<Pattern> excludeSymbols;
     // generate static forwarder class or not?
     private boolean genStaticForwarder;
 
     final PrintWriter out;
     final PrintWriter err;
 
     private Predicate<String> symChecker;
-    private Predicate<String> symFilter;
+    private Predicate<String> includeSymFilter;
+    private Predicate<String> excludeSymFilter;
 
     private final Parser parser;
 
     private final static String defaultPkg = "jextract.dump";
     final Logger logger = Logger.getLogger(getClass().getPackage().getName());

@@ -106,10 +109,11 @@
         this.clangArgs = new ArrayList<>();
         this.sources = new TreeSet<>();
         this.libraryNames = new ArrayList<>();
         this.libraryPaths = new ArrayList<>();
         this.linkCheckPaths = new ArrayList<>();
+        this.includeSymbols = new ArrayList<>();
         this.excludeSymbols = new ArrayList<>();
         this.parser = new Parser(out, err, Main.INCLUDE_MACROS);
         this.out = out;
         this.err = err;
     }

@@ -140,10 +144,14 @@
 
     void addLinkCheckPath(String path) {
         linkCheckPaths.add(path);
     }
 
+    void addIncludeSymbols(String pattern) {
+        includeSymbols.add(Pattern.compile(pattern));
+    }
+
     void addExcludeSymbols(String pattern) {
         excludeSymbols.add(Pattern.compile(pattern));
     }
 
     void setGenStaticForwarder(boolean flag) {

@@ -216,24 +224,38 @@
 
     private boolean isSymbolFound(String name) {
         return symChecker == null? true : symChecker.test(name);
     }
 
-    private void initSymFilter() {
+    private void initSymFilters() {
+        if (!includeSymbols.isEmpty()) {
+            Pattern[] pats = includeSymbols.toArray(new Pattern[0]);
+            includeSymFilter = name -> {
+                return Arrays.stream(pats).filter(pat -> pat.matcher(name).matches()).
+                    findFirst().isPresent();
+            };
+        } else {
+            includeSymFilter = null;
+        }
+
         if (!excludeSymbols.isEmpty()) {
             Pattern[] pats = excludeSymbols.toArray(new Pattern[0]);
-            symFilter = name -> {
+            excludeSymFilter = name -> {
                 return Arrays.stream(pats).filter(pat -> pat.matcher(name).matches()).
                     findFirst().isPresent();
             };
         } else {
-            symFilter = null;
+            excludeSymFilter = null;
         }
     }
 
+    private boolean isSymbolIncluded(String name) {
+        return includeSymFilter == null? true : includeSymFilter.test(name);
+    }
+
     private boolean isSymbolExcluded(String name) {
-        return symFilter == null? false : symFilter.test(name);
+        return excludeSymFilter == null? false : excludeSymFilter.test(name);
     }
 
     /**
      * Setup a package name for a given folder.
      *

@@ -388,11 +410,11 @@
             new AsmCodeFactoryExt(this, header) : new AsmCodeFactory(this, header));
     }
 
     private boolean symbolFilter(Tree tree) {
          String name = tree.name();
-         if (isSymbolExcluded(name)) {
+         if (!isSymbolIncluded(name) || isSymbolExcluded(name)) {
              return false;
          }
 
          // check for function symbols in libraries & warn missing symbols
          if (tree instanceof FunctionTree && !isSymbolFound(name)) {

@@ -404,11 +426,11 @@
          return true;
     }
 
     public void parse(Function<HeaderFile, AsmCodeFactory> fn) {
         initSymChecker();
-        initSymFilter();
+        initSymFilters();
 
         List<HeaderTree> headers = parser.parse(sources, clangArgs);
         processHeaders(headers, fn);
     }
 
< prev index next >