< prev index next >

src/java.base/share/classes/java/util/ImmutableCollections.java

Print this page
rev 48077 : 8193128: Reduce number of implementation classes returned by List/Set/Map.of()
Reviewed-by: smarks

@@ -70,189 +70,437 @@
 
     static UnsupportedOperationException uoe() { return new UnsupportedOperationException(); }
 
     // ---------- List Implementations ----------
 
-    abstract static class AbstractImmutableList<E> extends AbstractList<E>
-                                                implements RandomAccess, Serializable {
+    static final List<?> EMPTY_LIST = new ListN<>(); // shared zero-element ListN instance
+
+    @SuppressWarnings("unchecked")
+    static <E> List<E> emptyList() { // unchecked cast is safe: the shared instance is empty and immutable
+        return (List<E>) EMPTY_LIST;
+    }
+
+    /**
+     * Common base of the immutable List implementations. Every mutator throws
+     * UnsupportedOperationException; iteration, sub-listing, equality and
+     * hashing are defined in terms of size() and get(int).
+     */
+    static abstract class AbstractImmutableList<E> extends AbstractCollection<E>
+            implements List<E>, RandomAccess {
+
+        // all mutating methods throw UnsupportedOperationException
         @Override public boolean add(E e) { throw uoe(); }
+        @Override public void    add(int index, E element) { throw uoe(); }
         @Override public boolean addAll(Collection<? extends E> c) { throw uoe(); }
         @Override public boolean addAll(int index, Collection<? extends E> c) { throw uoe(); }
         @Override public void    clear() { throw uoe(); }
         @Override public boolean remove(Object o) { throw uoe(); }
+        @Override public E       remove(int index) { throw uoe(); }
         @Override public boolean removeAll(Collection<?> c) { throw uoe(); }
         @Override public boolean removeIf(Predicate<? super E> filter) { throw uoe(); }
         @Override public void    replaceAll(UnaryOperator<E> operator) { throw uoe(); }
         @Override public boolean retainAll(Collection<?> c) { throw uoe(); }
+        @Override public E       set(int index, E element) { throw uoe(); }
         @Override public void    sort(Comparator<? super E> c) { throw uoe(); }
+
+        @Override
+        public boolean isEmpty() {
+            return size() == 0;
     }
 
-    static final class List0<E> extends AbstractImmutableList<E> {
-        private static final List0<?> INSTANCE = new List0<>();
+        @Override
+        public List<E> subList(int fromIndex, int toIndex) {
+            int size = size();
+            subListRangeCheck(fromIndex, toIndex, size);
+            return new SubList<E>(this, fromIndex, toIndex);
+        }
 
-        @SuppressWarnings("unchecked")
-        static <T> List0<T> instance() {
-            return (List0<T>) INSTANCE;
+        private static final class SubList<E> extends AbstractImmutableList<E> implements RandomAccess {
+            private final List<E> root;
+            final int offset;
+            final int size;
+
+            /**
+             * Constructs a sublist of an immutable list that is
+             * not a SubList itself.
+             */
+            SubList(List<E> root, int fromIndex, int toIndex) {
+                this.root = root;
+                this.offset = fromIndex;
+                this.size = toIndex - fromIndex;
         }
 
-        private List0() { }
+            /**
+             * Constructs a sublist of another SubList.
+             */
+            SubList(SubList<E> parent, int fromIndex, int toIndex) {
+                this.root = parent.root;
+                this.offset = parent.offset + fromIndex;
+                this.size = toIndex - fromIndex;
+            }
+
+            public E get(int index) {
+                Objects.checkIndex(index, size);
+                return root.get(offset + index);
+            }
 
-        @Override
         public int size() {
-            return 0;
+                return size;
+            }
+
+            public Iterator<E> iterator() {
+                return listIterator();
+            }
+
+            public ListIterator<E> listIterator(int index) {
+                rangeCheck(index);
+
+                ListIterator<E> i = root.listIterator(offset + index);
+
+                return new ListIterator<>() {
+
+                    public boolean hasNext() {
+                        return nextIndex() < size;
+                    }
+
+                    public E next() {
+                        if (hasNext()) {
+                            return i.next();
+                        } else {
+                            throw new NoSuchElementException();
+                        }
+                    }
+
+                    public boolean hasPrevious() {
+                        return previousIndex() >= 0;
+                    }
+
+                    public E previous() {
+                        if (hasPrevious()) {
+                            return i.previous();
+                        } else {
+                            throw new NoSuchElementException();
+                        }
+                    }
+
+                    public int nextIndex() {
+                        return i.nextIndex() - offset;
+                    }
+
+                    public int previousIndex() {
+                        return i.previousIndex() - offset;
+                    }
+
+                    public void remove() { throw uoe(); }
+                    public void set(E e) { throw uoe(); }
+                    public void add(E e) { throw uoe(); }
+                };
         }
 
         @Override
-        public E get(int index) {
-            Objects.checkIndex(index, 0); // always throws IndexOutOfBoundsException
-            return null;                  // but the compiler doesn't know this
+            public int indexOf(Object o) {
+                // Should input be required to be non-null? See JDK-8191418
+                if (o == null) {
+                    return -1;
+                }
+                ListIterator<E> it = listIterator();
+                while (it.hasNext()) {
+                    if (o.equals(it.next())) {
+                        return it.previousIndex();
+                    }
+                }
+                return -1;
         }
 
         @Override
-        public Iterator<E> iterator() {
-            return Collections.emptyIterator();
+            public int lastIndexOf(Object o) {
+                // Should input be required to be non-null? See JDK-8191418
+                if (o == null) {
+                    return -1;
         }
 
-        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-            throw new InvalidObjectException("not serial proxy");
+                ListIterator<E> it = listIterator(size());
+                while (it.hasPrevious()) {
+                    if (o.equals(it.previous())) {
+                        return it.nextIndex();
+                    }
+                }
+                return -1;
         }
 
-        private Object writeReplace() {
-            return new CollSer(CollSer.IMM_LIST);
+            public List<E> subList(int fromIndex, int toIndex) {
+                subListRangeCheck(fromIndex, toIndex, size);
+                return new SubList<>(this, fromIndex, toIndex);
+            }
+
+            private void rangeCheck(int index) {
+                if (index < 0 || index > size) {
+                    throw outOfBounds(index);
+                }
+            }
+        }
+
+        static void subListRangeCheck(int fromIndex, int toIndex, int size) {
+            if (fromIndex < 0)
+                throw new IndexOutOfBoundsException("fromIndex = " + fromIndex);
+            if (toIndex > size)
+                throw new IndexOutOfBoundsException("toIndex = " + toIndex);
+            if (fromIndex > toIndex)
+                throw new IllegalArgumentException("fromIndex(" + fromIndex +
+                        ") > toIndex(" + toIndex + ")");
         }
 
         @Override
-        public boolean contains(Object o) {
-            Objects.requireNonNull(o);
-            return false;
+        public Iterator<E> iterator() {
+            return new Itr(size());
         }
 
         @Override
-        public boolean containsAll(Collection<?> o) {
-            return o.isEmpty(); // implicit nullcheck of o
+        public ListIterator<E> listIterator() {
+            return listIterator(0);
         }
 
         @Override
-        public int hashCode() {
-            return 1;
+        public ListIterator<E> listIterator(final int index) {
+            int size = size();
+            if (index < 0 || index > size) {
+                throw outOfBounds(index);
         }
+            return new ListItr(index, size);
     }
 
-    static final class List1<E> extends AbstractImmutableList<E> {
-        @Stable
-        private final E e0;
+        private class Itr implements Iterator<E> {
 
-        List1(E e0) {
-            this.e0 = Objects.requireNonNull(e0);
+            int cursor;
+
+            private final int size;
+
+            Itr(int size) {
+                this.size = size;
         }
 
-        @Override
-        public int size() {
-            return 1;
+            public boolean hasNext() {
+                return cursor != size;
         }
 
-        @Override
-        public E get(int index) {
-            Objects.checkIndex(index, 1);
-            return e0;
+            public E next() {
+                try {
+                    int i = cursor;
+                    E next = get(i);
+                    cursor = i + 1;
+                    return next;
+                } catch (IndexOutOfBoundsException e) {
+                    throw new NoSuchElementException();
+                }
         }
 
-        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-            throw new InvalidObjectException("not serial proxy");
+            public void remove() {
+                throw uoe();
+            }
         }
 
-        private Object writeReplace() {
-            return new CollSer(CollSer.IMM_LIST, e0);
+        private class ListItr extends Itr implements ListIterator<E> {
+
+            ListItr(int index, int size) {
+                super(size);
+                cursor = index;
         }
 
-        @Override
-        public boolean contains(Object o) {
-            return o.equals(e0); // implicit nullcheck of o
+            public boolean hasPrevious() {
+                return cursor != 0;
+            }
+
+            public E previous() {
+                try {
+                    int i = cursor - 1;
+                    E previous = get(i);
+                    cursor = i;
+                    return previous;
+                } catch (IndexOutOfBoundsException e) {
+                    throw new NoSuchElementException();
+                }
+            }
+
+            public int nextIndex() {
+                return cursor;
+            }
+
+            public int previousIndex() {
+                return cursor - 1;
         }
 
+            public void set(E e) {
+                throw uoe();
+            }
+
+            public void add(E e) {
+                throw uoe();
+            }
+        }
+
+
         @Override
-        public int hashCode() {
-            return 31 + e0.hashCode();
+        public boolean equals(Object o) {
+            if (o == this) {
+                return true;
+            }
+
+            if (!(o instanceof List)) {
+                return false;
+            }
+
+            Iterator<?> e1 = iterator();
+            Iterator<?> e2 = ((List<?>) o).iterator();
+            while (e1.hasNext() && e2.hasNext()) {
+                Object o1 = e1.next(); // can't be null
+                Object o2 = e2.next();
+                if (!o1.equals(o2)) { // any mismatching pair makes the lists unequal
+                    return false;
+                }
+            }
+            return !(e1.hasNext() || e2.hasNext()); // equal only if both ran out together
+        }
+
+        @Override
+        public int hashCode() {
+            // Follows the List.hashCode() formula. AbstractCollection does not
+            // define hashCode(), so without this SubList would fall back to the
+            // identity hash and violate the equals/hashCode contract.
+            int hash = 1;
+            for (E e : this) {
+                hash = 31 * hash + e.hashCode(); // elements can't be null (see equals)
+            }
+            return hash;
+        }
+
+        IndexOutOfBoundsException outOfBounds(int index) {
+            return new IndexOutOfBoundsException("Index: " + index + " Size: " + size());
         }
     }
 
-    static final class List2<E> extends AbstractImmutableList<E> {
+    static final class List12<E> extends AbstractImmutableList<E> implements Serializable { // one element (e1 == null) or two
+
         @Stable
         private final E e0;
+
         @Stable
         private final E e1;
 
-        List2(E e0, E e1) {
+        List12(E e0) {
+            this.e0 = Objects.requireNonNull(e0);
+            this.e1 = null; // a null e1 marks the single-element form
+        }
+
+        List12(E e0, E e1) {
             this.e0 = Objects.requireNonNull(e0);
             this.e1 = Objects.requireNonNull(e1);
         }
 
         @Override
         public int size() {
-            return 2;
+            return e1 != null ? 2 : 1;
         }
 
         @Override
         public E get(int index) {
-            Objects.checkIndex(index, 2);
             if (index == 0) {
                 return e0;
-            } else { // index == 1
+            } else if (index == 1 && e1 != null) {
                 return e1;
             }
+            throw outOfBounds(index); // negative, >= 2, or 1 in the single-element form
         }
 
         @Override
         public boolean contains(Object o) {
-            return o.equals(e0) || o.equals(e1); // implicit nullcheck of o
+            return o.equals(e0) || o.equals(e1); // implicit null check of o
         }
 
         @Override
         public int hashCode() {
             int hash = 31 + e0.hashCode();
-            return 31 * hash + e1.hashCode();
+            if (e1 != null) {
+                hash = 31 * hash + e1.hashCode();
+            }
+            return hash;
         }
 
         private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
             throw new InvalidObjectException("not serial proxy");
         }
 
         private Object writeReplace() {
+            if (e1 == null) { // single-element form serializes as a one-element proxy
+                return new CollSer(CollSer.IMM_LIST, e0);
+            } else {
             return new CollSer(CollSer.IMM_LIST, e0, e1);
         }
     }
 
-    static final class ListN<E> extends AbstractImmutableList<E> {
+        @Override
+        public int indexOf(Object o) {
+            // Input should be checked for null, but this needs a CSR. See JDK-8191418
+            if (o == null) {
+                return -1;
+            }
+            // Objects.requireNonNull(o);
+            if (o.equals(e0)) {
+                return 0;
+            } else if (o.equals(e1)) {
+                return 1;
+            } else {
+                return -1;
+            }
+        }
+
+        @Override
+        public int lastIndexOf(Object o) {
+            // Input should be checked for null, but this needs a CSR. See JDK-8191418
+            if (o == null) {
+                return -1;
+            }
+            // Objects.requireNonNull(o);
+            if (o.equals(e1)) { // check the later slot first
+                return 1;
+            } else if (o.equals(e0)) {
+                return 0;
+            } else {
+                return -1;
+            }
+        }
+
+    }
+
+    static final class ListN<E> extends AbstractImmutableList<E> implements Serializable {
+
         @Stable
         private final E[] elements;
 
+        // The fixed-arity constructors null-check their arguments just like the
+        // varargs constructor below: equals() and hashCode() dereference the
+        // stored elements and assume none of them is null.
+        @SuppressWarnings("unchecked")
+        ListN(E e0) {
+            elements = (E[])new Object[] { Objects.requireNonNull(e0) };
+        }
+
+        @SuppressWarnings("unchecked")
+        ListN(E e0, E e1) {
+            elements = (E[])new Object[] { Objects.requireNonNull(e0), Objects.requireNonNull(e1) };
+        }
+
         @SafeVarargs
         ListN(E... input) {
             // copy and check manually to avoid TOCTOU
             @SuppressWarnings("unchecked")
             E[] tmp = (E[])new Object[input.length]; // implicit nullcheck of input
             for (int i = 0; i < input.length; i++) {
                 tmp[i] = Objects.requireNonNull(input[i]);
             }
-            this.elements = tmp;
+            elements = tmp;
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return size() == 0;
         }
 
         @Override
         public int size() {
             return elements.length;
         }
 
         @Override
         public E get(int index) {
-            Objects.checkIndex(index, elements.length);
             return elements[index];
         }
 
         @Override
         public boolean contains(Object o) {
+            Objects.requireNonNull(o); // explicit: an empty array would never dereference o
             for (E e : elements) {
-                if (o.equals(e)) { // implicit nullcheck of o
+                if (o.equals(e)) {
                     return true;
                 }
             }
             return false;
         }

@@ -264,120 +512,80 @@
                 hash = 31 * hash + e.hashCode();
             }
             return hash;
         }
 
-        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-            throw new InvalidObjectException("not serial proxy");
+        @Override
+        public int indexOf(Object o) {
+            // Input should be checked for null, but this needs a CSR. See JDK-8191418
+            if (o == null) {
+                return -1;
         }
-
-        private Object writeReplace() {
-            return new CollSer(CollSer.IMM_LIST, elements);
+            // Objects.requireNonNull(o);
+            for (int i = 0; i < elements.length; i++) { // forward scan: the first match wins
+                if (o.equals(elements[i])) {
+                    return i;
         }
     }
-
-    // ---------- Set Implementations ----------
-
-    abstract static class AbstractImmutableSet<E> extends AbstractSet<E> implements Serializable {
-        @Override public boolean add(E e) { throw uoe(); }
-        @Override public boolean addAll(Collection<? extends E> c) { throw uoe(); }
-        @Override public void    clear() { throw uoe(); }
-        @Override public boolean remove(Object o) { throw uoe(); }
-        @Override public boolean removeAll(Collection<?> c) { throw uoe(); }
-        @Override public boolean removeIf(Predicate<? super E> filter) { throw uoe(); }
-        @Override public boolean retainAll(Collection<?> c) { throw uoe(); }
+            return -1; // no element equals o
     }
 
-    static final class Set0<E> extends AbstractImmutableSet<E> {
-        private static final Set0<?> INSTANCE = new Set0<>();
-
-        @SuppressWarnings("unchecked")
-        static <T> Set0<T> instance() {
-            return (Set0<T>) INSTANCE;
-        }
-
-        private Set0() { }
-
         @Override
-        public int size() {
-            return 0;
+        public int lastIndexOf(Object o) {
+            // Input should be checked for null, but this needs a CSR. See JDK-8191418
+            if (o == null) {
+                return -1;
         }
-
-        @Override
-        public boolean contains(Object o) {
-            Objects.requireNonNull(o);
-            return false;
+            // Objects.requireNonNull(o);
+            for (int i = elements.length - 1; i >= 0; i--) { // backward scan; i >= 0 so index 0 is examined too
+                if (o.equals(elements[i])) {
+                    return i;
         }
-
-        @Override
-        public boolean containsAll(Collection<?> o) {
-            return o.isEmpty(); // implicit nullcheck of o
         }
-
-        @Override
-        public Iterator<E> iterator() {
-            return Collections.emptyIterator();
+            return -1;
         }
 
         private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
             throw new InvalidObjectException("not serial proxy");
         }
 
         private Object writeReplace() {
-            return new CollSer(CollSer.IMM_SET);
-        }
-
-        @Override
-        public int hashCode() {
-            return 0;
-        }
-    }
-
-    static final class Set1<E> extends AbstractImmutableSet<E> {
-        @Stable
-        private final E e0;
-
-        Set1(E e0) {
-            this.e0 = Objects.requireNonNull(e0);
-        }
-
-        @Override
-        public int size() {
-            return 1;
+            return new CollSer(CollSer.IMM_LIST, elements); // serialize via the CollSer proxy, never directly
         }
-
-        @Override
-        public boolean contains(Object o) {
-            return o.equals(e0); // implicit nullcheck of o
         }
 
-        @Override
-        public Iterator<E> iterator() {
-            return Collections.singletonIterator(e0);
-        }
+    // ---------- Set Implementations ----------
 
-        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-            throw new InvalidObjectException("not serial proxy");
-        }
+    static final Set<?> EMPTY_SET = new SetN<>(); // shared zero-element SetN instance
 
-        private Object writeReplace() {
-            return new CollSer(CollSer.IMM_SET, e0);
+    @SuppressWarnings("unchecked")
+    static <E> Set<E> emptySet() { // unchecked cast is safe: the shared instance is empty and immutable
+        return (Set<E>) EMPTY_SET;
         }
 
-        @Override
-        public int hashCode() {
-            return e0.hashCode();
-        }
+    abstract static class AbstractImmutableSet<E> extends AbstractSet<E> implements Serializable { // all mutators throw UnsupportedOperationException
+        @Override public boolean add(E e) { throw uoe(); }
+        @Override public boolean addAll(Collection<? extends E> c) { throw uoe(); }
+        @Override public void    clear() { throw uoe(); }
+        @Override public boolean remove(Object o) { throw uoe(); }
+        @Override public boolean removeAll(Collection<?> c) { throw uoe(); }
+        @Override public boolean removeIf(Predicate<? super E> filter) { throw uoe(); }
+        @Override public boolean retainAll(Collection<?> c) { throw uoe(); }
     }
 
-    static final class Set2<E> extends AbstractImmutableSet<E> {
+    static final class Set12<E> extends AbstractImmutableSet<E> { // one element (e1 == null) or two distinct ones
         @Stable
         final E e0;
         @Stable
         final E e1;
 
-        Set2(E e0, E e1) {
+        Set12(E e0) {
+            this.e0 = Objects.requireNonNull(e0);
+            this.e1 = null; // a null e1 marks the single-element form
+        }
+
+        Set12(E e0, E e1) {
             if (e0.equals(Objects.requireNonNull(e1))) { // implicit nullcheck of e0
                 throw new IllegalArgumentException("duplicate element: " + e0);
             }
 
             if (SALT >= 0) {

@@ -389,40 +597,40 @@
             }
         }
 
         @Override
         public int size() {
-            return 2;
+            return (e1 == null) ? 1 : 2; // null e1 means single-element form
         }
 
         @Override
         public boolean contains(Object o) {
             return o.equals(e0) || o.equals(e1); // implicit nullcheck of o
         }
 
         @Override
         public int hashCode() {
-            return e0.hashCode() + e1.hashCode();
+            return e0.hashCode() + (e1 == null ? 0 : e1.hashCode()); // Set hash: sum of element hashes
         }
 
         @Override
         public Iterator<E> iterator() {
             return new Iterator<E>() {
-                private int idx = 0;
+                private int idx = size();
 
                 @Override
                 public boolean hasNext() {
-                    return idx < 2;
+                    return idx > 0;
                 }
 
                 @Override
                 public E next() {
-                    if (idx == 0) {
-                        idx = 1;
+                    if (idx == 1) {
+                        idx = 0;
                         return e0;
-                    } else if (idx == 1) {
-                        idx = 2;
+                    } else if (idx == 2) {
+                        idx = 1;
                         return e1;
                     } else {
                         throw new NoSuchElementException();
                     }
                 }

@@ -432,13 +640,17 @@
         private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
             throw new InvalidObjectException("not serial proxy");
         }
 
         private Object writeReplace() {
+            if (e1 == null) { // single-element form serializes as a one-element proxy
+                return new CollSer(CollSer.IMM_SET, e0);
+            } else {
             return new CollSer(CollSer.IMM_SET, e0, e1);
         }
     }
+    }
 
     /**
      * An array-based Set implementation. The element array must be strictly
      * larger than the size (the number of contained elements) so that at
      * least one null is always present.

@@ -472,11 +684,12 @@
             return size;
         }
 
         @Override
         public boolean contains(Object o) {
-            return probe(o) >= 0; // implicit nullcheck of o
+            Objects.requireNonNull(o);
+            return size > 0 && probe(o) >= 0; // null already rejected above; probe() is skipped when the set is empty
         }
 
         @Override
         public Iterator<E> iterator() {
             return new Iterator<E>() {

@@ -547,10 +760,17 @@
         }
     }
 
     // ---------- Map Implementations ----------
 
+    static final Map<?,?> EMPTY_MAP = new MapN<>(); // shared zero-pair MapN instance
+
+    @SuppressWarnings("unchecked")
+    static <K,V> Map<K,V> emptyMap() { // unchecked cast is safe: the shared instance is empty and immutable
+        return (Map<K,V>) EMPTY_MAP;
+    }
+
     abstract static class AbstractImmutableMap<K,V> extends AbstractMap<K,V> implements Serializable {
         @Override public void clear() { throw uoe(); }
         @Override public V compute(K key, BiFunction<? super K,? super V,? extends V> rf) { throw uoe(); }
         @Override public V computeIfAbsent(K key, Function<? super K,? extends V> mf) { throw uoe(); }
         @Override public V computeIfPresent(K key, BiFunction<? super K,? super V,? extends V> rf) { throw uoe(); }

@@ -563,51 +783,10 @@
         @Override public V replace(K key, V value) { throw uoe(); }
         @Override public boolean replace(K key, V oldValue, V newValue) { throw uoe(); }
         @Override public void replaceAll(BiFunction<? super K,? super V,? extends V> f) { throw uoe(); }
     }
 
-    static final class Map0<K,V> extends AbstractImmutableMap<K,V> {
-        private static final Map0<?,?> INSTANCE = new Map0<>();
-
-        @SuppressWarnings("unchecked")
-        static <K,V> Map0<K,V> instance() {
-            return (Map0<K,V>) INSTANCE;
-        }
-
-        private Map0() { }
-
-        @Override
-        public Set<Map.Entry<K,V>> entrySet() {
-            return Set.of();
-        }
-
-        @Override
-        public boolean containsKey(Object o) {
-            Objects.requireNonNull(o);
-            return false;
-        }
-
-        @Override
-        public boolean containsValue(Object o) {
-            Objects.requireNonNull(o);
-            return false;
-        }
-
-        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
-            throw new InvalidObjectException("not serial proxy");
-        }
-
-        private Object writeReplace() {
-            return new CollSer(CollSer.IMM_MAP);
-        }
-
-        @Override
-        public int hashCode() {
-            return 0;
-        }
-    }
-
     static final class Map1<K,V> extends AbstractImmutableMap<K,V> {
         @Stable
         private final K k0;
         @Stable
         private final V v0;

@@ -656,10 +835,11 @@
      * @param <V> the value type
      */
     static final class MapN<K,V> extends AbstractImmutableMap<K,V> {
         @Stable
         final Object[] table; // pairs of key, value
+
         @Stable
         final int size; // number of pairs
 
         MapN(Object... input) {
             if ((input.length & 1) != 0) { // implicit nullcheck of input

@@ -687,11 +867,12 @@
             }
         }
 
         @Override
         public boolean containsKey(Object o) {
-            return probe(o) >= 0; // implicit nullcheck of o
+            Objects.requireNonNull(o); // explicit: the size guard below may skip probe()
+            return size > 0 && probe(o) >= 0;
         }
 
         @Override
         public boolean containsValue(Object o) {
             for (int i = 1; i < table.length; i += 2) {

@@ -716,10 +897,13 @@
         }
 
         @Override
         @SuppressWarnings("unchecked")
         public V get(Object o) {
+            if (size == 0) {
+                return null;
+            }
             int i = probe(o);
             if (i >= 0) {
                 return (V)table[i+1];
             } else {
                 return null;

@@ -946,11 +1130,11 @@
                     return List.of(array);
                 case IMM_SET:
                     return Set.of(array);
                 case IMM_MAP:
                     if (array.length == 0) {
-                        return ImmutableCollections.Map0.instance();
+                        return ImmutableCollections.emptyMap();
                     } else if (array.length == 2) {
                         return new ImmutableCollections.Map1<>(array[0], array[1]);
                     } else {
                         return new ImmutableCollections.MapN<>(array);
                     }
< prev index next >