package test; import java.io.PrintStream; import java.io.Serializable; import java.lang.reflect.Modifier; import java.math.BigDecimal; import java.math.BigInteger; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; public class TypeSwitch implements Function { private final List> compiledCases = new ArrayList<>(); private Function defaultFn; static class Case { final Class type; final Function fn; @SuppressWarnings("unchecked") Case(Class type, Function fn) { this.type = type; this.fn = (Function) fn; } } // compile-time // compiledCases.get(i) invariant: distinct cases, specific 1st general later public TypeSwitch addCase(Class type, Function fn) { for (Case caze : compiledCases) { if (Types.isAssignable(caze.type, type)) { throw new IllegalArgumentException( "Compilation error: case " + type.getName() + " can not go after: case " + caze.type.getName()); } } compiledCases.add(new Case<>(type, Objects.requireNonNull(fn))); return this; } public TypeSwitch addDefault(Function defaultFn) { if (this.defaultFn != null) { throw new IllegalArgumentException( "Compilation error: duplicate default(s) are not possible"); } this.defaultFn = Objects.requireNonNull(defaultFn); return this; } // link-time // final types get a hash table private Map, Integer> finalCases; // non-final types get a decision tree private Nodes nonfinalTree; public TypeSwitch link() { System.out.println("\nCompiled cases:\n"); for (int i = 0; i < compiledCases.size(); i++) { System.out.println(compiledCases.get(i).type.getSimpleName() + " -> " + i); } // split cases into final and non-final finalCases = new HashMap<>(); List nonfinalIndexes = new ArrayList<>(); for (int i = 0; i < compiledCases.size(); i++) { Class type = compiledCases.get(i).type; if (Modifier.isFinal(type.getModifiers())) { finalCases.put(type, i); } else { nonfinalIndexes.add(i); } } System.out.println("\nFinal types hash table:\n"); finalCases .entrySet() .stream() .sorted(Comparator.comparing(Map.Entry::getValue)) .forEach(e -> System.out.println(e.getKey().getSimpleName() + " -> " + e.getValue())); // stable sort nonfinalIndex int[] comparisons = {0}; nonfinalIndexes.sort(Comparator.comparing( i -> compiledCases.get(i).type, (type1, type2) -> { comparisons[0]++; return Types.COMPARATOR.compare(type1, type2); } )); nonfinalTree = new Nodes(); // add to decision tree from most general to most specific (in reverse) for (int i = nonfinalIndexes.size() - 1; i >= 0; i--) { int r = nonfinalIndexes.get(i); Class type = compiledCases.get(r).type; nonfinalTree.add(type, r); } System.out.println("\nNon-final types re-sorted with " + comparisons[0] + " comparisons and put into decision tree:\n"); nonfinalTree.dump(System.out); return this; } // run-time @Override public R apply(Object o) { int i = mapToCaseIndex(o == null ? null : o.getClass()); if (i >= 0) { return compiledCases.get(i).fn.apply(o); } else { return defaultFn.apply(o); } } private int mapToCaseIndex(Class type) { Integer i = finalCases.get(type); if (i != null) return i; return nonfinalTree.get(type); } static class Nodes extends ArrayList { void add(Class type, int i) { for (Node node : this) { if (Types.isAssignable(node.type, type)) { node.add(type, i); return; } } // insert to beginning of list since we add in reverse order add(0, new Node(type, i)); } // return case index form given type int get(Class type) { if (type == null) return -1; // will map to default for (Node n : this) { if (Types.isAssignable(n.type, type)) { int r = n.subnodes == null ? -1 : n.subnodes.get(type); if (r >= 0) return r; return n.i; } } return -1; } void dump(PrintStream out) { dump(0, out); } void dump(int level, PrintStream out) { for (Node n : this) { n.dump(level, out); } } } static class Node { final Class type; final int i; Nodes subnodes; Node(Class type, int i) { this.type = type; this.i = i; } void add(Class subtype, int i) { if (subnodes == null) { subnodes = new Nodes(); } subnodes.add(subtype, i); } void dump(int level, PrintStream out) { out.println(" ".substring(0, level * 2) + type.getSimpleName() + " -> " + i); if (subnodes != null) { subnodes.dump(level + 1, out); } } } // simulate type injection between compile-time and link-time (separate compilation) static class Types { static final Map, Set>> isa = new HashMap<>(); static boolean isAssignable(Class target, Class source) { return target.isAssignableFrom(source) || isa .entrySet() .stream() .filter(e -> e.getKey().isAssignableFrom(source)) .flatMap(e -> e.getValue().stream()) .anyMatch(superType -> isAssignable(target, superType)); } static void injectType(Class subType, Class superType) { if (!isAssignable(superType, subType)) { isa.computeIfAbsent(subType, _sub -> new HashSet<>()).add(superType); } } static final Comparator> COMPARATOR = (type1, type2) -> { if (isAssignable(type1, type2)) { // type1 is more general than type2, type1 has to go after type2 return 1; } else if (isAssignable(type2, type1)) { // type1 is more specific than type2, type1 has to go before type2 return -1; } else { // unrelated types, stable sort will not change their relative order } return 0; }; } interface Top {} interface LightSource extends Top {} interface HeatSource extends Top {} static class LightBulb implements LightSource, HeatSource {} static class Led implements LightSource {} static class Owen implements HeatSource {} public static void main(String[] args) { TypeSwitch tsw = new TypeSwitch() .addCase(LightBulb.class, lb -> "LightBulb:" + lb) .addCase(Led.class, ld -> "Led:" + ld) .addCase(LightSource.class, ls -> "LightSource:" + ls) .addCase(Owen.class, ow -> "Owen:" + ow) .addCase(HeatSource.class, hs -> "HeatSource:" + hs) .addCase(Top.class, top -> "Top:" + top) .addCase(String.class, s -> "String:" + s) .addCase(Short.class, i -> "Short:" + i) .addCase(Integer.class, i -> "Integer:" + i) .addCase(Long.class, i -> "Long:" + i) .addCase(Float.class, i -> "Float:" + i) .addCase(Double.class, i -> "Double:" + i) .addCase(BigInteger.class, i -> "BigInteger:" + i) .addCase(BigDecimal.class, i -> "BigDecimal:" + i) .addCase(Number.class, n -> "Number:" + n) .addCase(CharSequence.class, cs -> "CharSequence:" + cs) .addCase(Serializable.class, ser -> "Serializable:" + ser) .addDefault(d -> "default:" + d); // // make LightSource extend Serializable // Types.injectType(LightSource.class, Serializable.class); // // make HeatSource extend Serializable // Types.injectType(HeatSource.class, Serializable.class); tsw.link(); System.out.println(); System.out.println(tsw.apply("HELLO 1")); System.out.println(tsw.apply(new StringBuilder("HELLO 2"))); System.out.println(tsw.apply(new Led())); } }