001/*
002 * Copyright (C) 2010 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect.testing;
018
019import com.google.common.annotations.GwtIncompatible;
020import com.google.errorprone.annotations.CanIgnoreReturnValue;
021import java.io.Serializable;
022import java.util.AbstractSet;
023import java.util.Collection;
024import java.util.Comparator;
025import java.util.Iterator;
026import java.util.Map;
027import java.util.NavigableMap;
028import java.util.NavigableSet;
029import java.util.Set;
030import java.util.SortedMap;
031import java.util.TreeMap;
032import org.checkerframework.checker.nullness.qual.Nullable;
033
034/**
035 * A wrapper around {@code TreeMap} that aggressively checks to see if keys are mutually comparable.
036 * This implementation passes the navigable map test suites.
037 *
038 * @author Louis Wasserman
039 */
040@GwtIncompatible
041public final class SafeTreeMap<K, V> implements Serializable, NavigableMap<K, V> {
042  @SuppressWarnings("unchecked")
043  private static final Comparator<Object> NATURAL_ORDER =
044      new Comparator<Object>() {
045        @Override
046        public int compare(Object o1, Object o2) {
047          return ((Comparable<Object>) o1).compareTo(o2);
048        }
049      };
050
051  private final NavigableMap<K, V> delegate;
052
053  public SafeTreeMap() {
054    this(new TreeMap<K, V>());
055  }
056
057  public SafeTreeMap(Comparator<? super K> comparator) {
058    this(new TreeMap<K, V>(comparator));
059  }
060
061  public SafeTreeMap(Map<? extends K, ? extends V> map) {
062    this(new TreeMap<K, V>(map));
063  }
064
065  public SafeTreeMap(SortedMap<K, ? extends V> map) {
066    this(new TreeMap<K, V>(map));
067  }
068
069  private SafeTreeMap(NavigableMap<K, V> delegate) {
070    this.delegate = delegate;
071    if (delegate == null) {
072      throw new NullPointerException();
073    }
074    for (K k : keySet()) {
075      checkValid(k);
076    }
077  }
078
079  @Override
080  public @Nullable Entry<K, V> ceilingEntry(K key) {
081    return delegate.ceilingEntry(checkValid(key));
082  }
083
084  @Override
085  public @Nullable K ceilingKey(K key) {
086    return delegate.ceilingKey(checkValid(key));
087  }
088
089  @Override
090  public void clear() {
091    delegate.clear();
092  }
093
094  @Override
095  public Comparator<? super K> comparator() {
096    Comparator<? super K> comparator = delegate.comparator();
097    if (comparator == null) {
098      comparator = (Comparator<? super K>) NATURAL_ORDER;
099    }
100    return comparator;
101  }
102
103  @Override
104  public boolean containsKey(Object key) {
105    try {
106      return delegate.containsKey(checkValid(key));
107    } catch (NullPointerException | ClassCastException e) {
108      return false;
109    }
110  }
111
112  @Override
113  public boolean containsValue(Object value) {
114    return delegate.containsValue(value);
115  }
116
117  @Override
118  public NavigableSet<K> descendingKeySet() {
119    return delegate.descendingKeySet();
120  }
121
122  @Override
123  public NavigableMap<K, V> descendingMap() {
124    return new SafeTreeMap<>(delegate.descendingMap());
125  }
126
127  @Override
128  public Set<Entry<K, V>> entrySet() {
129    return new AbstractSet<Entry<K, V>>() {
130      private Set<Entry<K, V>> delegate() {
131        return delegate.entrySet();
132      }
133
134      @Override
135      public boolean contains(Object object) {
136        try {
137          return delegate().contains(object);
138        } catch (NullPointerException | ClassCastException e) {
139          return false;
140        }
141      }
142
143      @Override
144      public Iterator<Entry<K, V>> iterator() {
145        return delegate().iterator();
146      }
147
148      @Override
149      public int size() {
150        return delegate().size();
151      }
152
153      @Override
154      public boolean remove(Object o) {
155        return delegate().remove(o);
156      }
157
158      @Override
159      public void clear() {
160        delegate().clear();
161      }
162    };
163  }
164
165  @Override
166  public @Nullable Entry<K, V> firstEntry() {
167    return delegate.firstEntry();
168  }
169
170  @Override
171  public K firstKey() {
172    return delegate.firstKey();
173  }
174
175  @Override
176  public @Nullable Entry<K, V> floorEntry(K key) {
177    return delegate.floorEntry(checkValid(key));
178  }
179
180  @Override
181  public @Nullable K floorKey(K key) {
182    return delegate.floorKey(checkValid(key));
183  }
184
185  @Override
186  public @Nullable V get(Object key) {
187    return delegate.get(checkValid(key));
188  }
189
190  @Override
191  public SortedMap<K, V> headMap(K toKey) {
192    return headMap(toKey, false);
193  }
194
195  @Override
196  public NavigableMap<K, V> headMap(K toKey, boolean inclusive) {
197    return new SafeTreeMap<>(delegate.headMap(checkValid(toKey), inclusive));
198  }
199
200  @Override
201  public @Nullable Entry<K, V> higherEntry(K key) {
202    return delegate.higherEntry(checkValid(key));
203  }
204
205  @Override
206  public @Nullable K higherKey(K key) {
207    return delegate.higherKey(checkValid(key));
208  }
209
210  @Override
211  public boolean isEmpty() {
212    return delegate.isEmpty();
213  }
214
215  @Override
216  public NavigableSet<K> keySet() {
217    return navigableKeySet();
218  }
219
220  @Override
221  public @Nullable Entry<K, V> lastEntry() {
222    return delegate.lastEntry();
223  }
224
225  @Override
226  public K lastKey() {
227    return delegate.lastKey();
228  }
229
230  @Override
231  public @Nullable Entry<K, V> lowerEntry(K key) {
232    return delegate.lowerEntry(checkValid(key));
233  }
234
235  @Override
236  public @Nullable K lowerKey(K key) {
237    return delegate.lowerKey(checkValid(key));
238  }
239
240  @Override
241  public NavigableSet<K> navigableKeySet() {
242    return delegate.navigableKeySet();
243  }
244
245  @Override
246  public @Nullable Entry<K, V> pollFirstEntry() {
247    return delegate.pollFirstEntry();
248  }
249
250  @Override
251  public @Nullable Entry<K, V> pollLastEntry() {
252    return delegate.pollLastEntry();
253  }
254
255  @Override
256  public @Nullable V put(K key, V value) {
257    return delegate.put(checkValid(key), value);
258  }
259
260  @Override
261  public void putAll(Map<? extends K, ? extends V> map) {
262    for (K key : map.keySet()) {
263      checkValid(key);
264    }
265    delegate.putAll(map);
266  }
267
268  @Override
269  public @Nullable V remove(Object key) {
270    return delegate.remove(checkValid(key));
271  }
272
273  @Override
274  public int size() {
275    return delegate.size();
276  }
277
278  @Override
279  public NavigableMap<K, V> subMap(K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) {
280    return new SafeTreeMap<>(
281        delegate.subMap(checkValid(fromKey), fromInclusive, checkValid(toKey), toInclusive));
282  }
283
284  @Override
285  public SortedMap<K, V> subMap(K fromKey, K toKey) {
286    return subMap(fromKey, true, toKey, false);
287  }
288
289  @Override
290  public SortedMap<K, V> tailMap(K fromKey) {
291    return tailMap(fromKey, true);
292  }
293
294  @Override
295  public NavigableMap<K, V> tailMap(K fromKey, boolean inclusive) {
296    return new SafeTreeMap<>(delegate.tailMap(checkValid(fromKey), inclusive));
297  }
298
299  @Override
300  public Collection<V> values() {
301    return delegate.values();
302  }
303
304  @CanIgnoreReturnValue
305  private <T> T checkValid(T t) {
306    // a ClassCastException is what's supposed to happen!
307    @SuppressWarnings("unchecked")
308    K k = (K) t;
309    int unused = comparator().compare(k, k);
310    return t;
311  }
312
313  @Override
314  public boolean equals(@Nullable Object obj) {
315    return delegate.equals(obj);
316  }
317
318  @Override
319  public int hashCode() {
320    return delegate.hashCode();
321  }
322
323  @Override
324  public String toString() {
325    return delegate.toString();
326  }
327
328  private static final long serialVersionUID = 0L;
329}