--- old/src/java.base/share/classes/java/util/TreeMap.java 2017-03-16 11:34:24.319877114 -0700 +++ new/src/java.base/share/classes/java/util/TreeMap.java 2017-03-16 11:34:24.203877110 -0700 @@ -29,6 +29,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; /** * A Red-Black tree based {@link NavigableMap} implementation. @@ -531,15 +532,173 @@ * does not permit null keys */ public V put(K key, V value) { + return put(key, value, true); + } + + private void addEntry(K key, V value, int cmp, Entry parent) { + Entry e = new Entry<>(key, value, parent); + if (cmp < 0) + parent.left = e; + else + parent.right = e; + fixAfterInsertion(e); + size++; + modCount++; + } + + private void addEntryToEmptyMap(K key, V value) { + compare(key, key); // type (and possibly null) check + root = new Entry<>(key, value, null); + size = 1; + modCount++; + } + + private V put(K key, V value, boolean replaceOld) { Entry t = root; if (t == null) { - compare(key, key); // type (and possibly null) check + addEntryToEmptyMap(key, value); + return null; + } + int cmp; + Entry parent; + // split comparator and comparable paths + Comparator cpr = comparator; + if (cpr != null) { + do { + parent = t; + cmp = cpr.compare(key, t.key); + if (cmp < 0) + t = t.left; + else if (cmp > 0) + t = t.right; + else { + V oldValue = t.value; + if(replaceOld) { + t.value = value; + } + return oldValue; + } + } while (t != null); + } + else { + if (key == null) + throw new NullPointerException(); + @SuppressWarnings("unchecked") + Comparable k = (Comparable) key; + do { + parent = t; + cmp = k.compareTo(t.key); + if (cmp < 0) + t = t.left; + else if (cmp > 0) + t = t.right; + else { + V oldValue = t.value; + if(replaceOld) { + t.value = value; + } + return oldValue; + } + } while (t != null); + } + addEntry(key, value, cmp, parent); + return null; + } + + @Override + public V putIfAbsent(K key, V value) { + return put(key, value, false); + } - root = new Entry<>(key, value, null); - size = 1; - modCount++; + @Override + public V computeIfAbsent(K key, Function mappingFunction) { + Objects.requireNonNull(mappingFunction); + V newValue; + Entry t = root; + if (t == null) { + if ((newValue = mappingFunction.apply(key)) != null) { + addEntryToEmptyMap(key, newValue); + return newValue; + } else { + return null; + } + } + int cmp; + Entry parent; + // split comparator and comparable paths + Comparator cpr = comparator; + if (cpr != null) { + do { + parent = t; + cmp = cpr.compare(key, t.key); + if (cmp < 0) + t = t.left; + else if (cmp > 0) + t = t.right; + else + return t.value; + } while (t != null); + } + else { + if (key == null) + throw new NullPointerException(); + @SuppressWarnings("unchecked") + Comparable k = (Comparable) key; + do { + parent = t; + cmp = k.compareTo(t.key); + if (cmp < 0) + t = t.left; + else if (cmp > 0) + t = t.right; + else + return t.value; + } while (t != null); + } + if ((newValue = mappingFunction.apply(key)) != null) { + addEntry(key, newValue, cmp, parent); + return newValue; + } + return null; + } + + @Override + public V computeIfPresent(K key, BiFunction remappingFunction) { + Objects.requireNonNull(remappingFunction); + Entry oldEntry = getEntry(key); + if (oldEntry != null && oldEntry.value != null) { + return remapValue(oldEntry, key, remappingFunction); + } else { return null; } + } + + private V remapValue(Entry t, K key, BiFunction remappingFunction) { + V newValue = remappingFunction.apply(key, t.value); + if (newValue == null) { + deleteEntry(t); + return null; + } else { + // replace old mapping + t.value = newValue; + return newValue; + } + } + + @Override + public V compute(K key, BiFunction remappingFunction) { + Objects.requireNonNull(remappingFunction); + V newValue; + Entry t = root; + if (t == null) { + newValue = remappingFunction.apply(key, null); + if (newValue != null) { + addEntryToEmptyMap(key, newValue); + return newValue; + } else { + return null; + } + } int cmp; Entry parent; // split comparator and comparable paths @@ -553,14 +712,14 @@ else if (cmp > 0) t = t.right; else - return t.setValue(value); + return remapValue(t, key, remappingFunction); } while (t != null); } else { if (key == null) throw new NullPointerException(); @SuppressWarnings("unchecked") - Comparable k = (Comparable) key; + Comparable k = (Comparable) key; do { parent = t; cmp = k.compareTo(t.key); @@ -569,17 +728,13 @@ else if (cmp > 0) t = t.right; else - return t.setValue(value); + return remapValue(t, key, remappingFunction); } while (t != null); } - Entry e = new Entry<>(key, value, parent); - if (cmp < 0) - parent.left = e; - else - parent.right = e; - fixAfterInsertion(e); - size++; - modCount++; + if ((newValue = remappingFunction.apply(key, null)) != null) { + addEntry(key, newValue, cmp, parent); + return newValue; + } return null; }