001/*
002 * Copyright (C) 2010 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect.testing;
018
019import com.google.common.annotations.GwtIncompatible;
020import java.io.Serializable;
021import java.util.AbstractSet;
022import java.util.Collection;
023import java.util.Comparator;
024import java.util.Iterator;
025import java.util.Map;
026import java.util.NavigableMap;
027import java.util.NavigableSet;
028import java.util.Set;
029import java.util.SortedMap;
030import java.util.TreeMap;
031
032/**
033 * A wrapper around {@code TreeMap} that aggressively checks to see if keys are mutually comparable.
034 * This implementation passes the navigable map test suites.
035 *
036 * @author Louis Wasserman
037 */
038@GwtIncompatible
039public final class SafeTreeMap<K, V> implements Serializable, NavigableMap<K, V> {
040  @SuppressWarnings("unchecked")
041  private static final Comparator<Object> NATURAL_ORDER =
042      new Comparator<Object>() {
043        @Override
044        public int compare(Object o1, Object o2) {
045          return ((Comparable<Object>) o1).compareTo(o2);
046        }
047      };
048
049  private final NavigableMap<K, V> delegate;
050
051  public SafeTreeMap() {
052    this(new TreeMap<K, V>());
053  }
054
055  public SafeTreeMap(Comparator<? super K> comparator) {
056    this(new TreeMap<K, V>(comparator));
057  }
058
059  public SafeTreeMap(Map<? extends K, ? extends V> map) {
060    this(new TreeMap<K, V>(map));
061  }
062
063  public SafeTreeMap(SortedMap<K, ? extends V> map) {
064    this(new TreeMap<K, V>(map));
065  }
066
067  private SafeTreeMap(NavigableMap<K, V> delegate) {
068    this.delegate = delegate;
069    if (delegate == null) {
070      throw new NullPointerException();
071    }
072    for (K k : keySet()) {
073      checkValid(k);
074    }
075  }
076
077  @Override
078  public Entry<K, V> ceilingEntry(K key) {
079    return delegate.ceilingEntry(checkValid(key));
080  }
081
082  @Override
083  public K ceilingKey(K key) {
084    return delegate.ceilingKey(checkValid(key));
085  }
086
087  @Override
088  public void clear() {
089    delegate.clear();
090  }
091
092  @SuppressWarnings("unchecked")
093  @Override
094  public Comparator<? super K> comparator() {
095    Comparator<? super K> comparator = delegate.comparator();
096    if (comparator == null) {
097      comparator = (Comparator<? super K>) NATURAL_ORDER;
098    }
099    return comparator;
100  }
101
102  @Override
103  public boolean containsKey(Object key) {
104    try {
105      return delegate.containsKey(checkValid(key));
106    } catch (NullPointerException | ClassCastException e) {
107      return false;
108    }
109  }
110
111  @Override
112  public boolean containsValue(Object value) {
113    return delegate.containsValue(value);
114  }
115
116  @Override
117  public NavigableSet<K> descendingKeySet() {
118    return delegate.descendingKeySet();
119  }
120
121  @Override
122  public NavigableMap<K, V> descendingMap() {
123    return new SafeTreeMap<>(delegate.descendingMap());
124  }
125
126  @Override
127  public Set<Entry<K, V>> entrySet() {
128    return new AbstractSet<Entry<K, V>>() {
129      private Set<Entry<K, V>> delegate() {
130        return delegate.entrySet();
131      }
132
133      @Override
134      public boolean contains(Object object) {
135        try {
136          return delegate().contains(object);
137        } catch (NullPointerException | ClassCastException e) {
138          return false;
139        }
140      }
141
142      @Override
143      public Iterator<Entry<K, V>> iterator() {
144        return delegate().iterator();
145      }
146
147      @Override
148      public int size() {
149        return delegate().size();
150      }
151
152      @Override
153      public boolean remove(Object o) {
154        return delegate().remove(o);
155      }
156
157      @Override
158      public void clear() {
159        delegate().clear();
160      }
161    };
162  }
163
164  @Override
165  public Entry<K, V> firstEntry() {
166    return delegate.firstEntry();
167  }
168
169  @Override
170  public K firstKey() {
171    return delegate.firstKey();
172  }
173
174  @Override
175  public Entry<K, V> floorEntry(K key) {
176    return delegate.floorEntry(checkValid(key));
177  }
178
179  @Override
180  public K floorKey(K key) {
181    return delegate.floorKey(checkValid(key));
182  }
183
184  @Override
185  public V get(Object key) {
186    return delegate.get(checkValid(key));
187  }
188
189  @Override
190  public SortedMap<K, V> headMap(K toKey) {
191    return headMap(toKey, false);
192  }
193
194  @Override
195  public NavigableMap<K, V> headMap(K toKey, boolean inclusive) {
196    return new SafeTreeMap<>(delegate.headMap(checkValid(toKey), inclusive));
197  }
198
199  @Override
200  public Entry<K, V> higherEntry(K key) {
201    return delegate.higherEntry(checkValid(key));
202  }
203
204  @Override
205  public K higherKey(K key) {
206    return delegate.higherKey(checkValid(key));
207  }
208
209  @Override
210  public boolean isEmpty() {
211    return delegate.isEmpty();
212  }
213
214  @Override
215  public NavigableSet<K> keySet() {
216    return navigableKeySet();
217  }
218
219  @Override
220  public Entry<K, V> lastEntry() {
221    return delegate.lastEntry();
222  }
223
224  @Override
225  public K lastKey() {
226    return delegate.lastKey();
227  }
228
229  @Override
230  public Entry<K, V> lowerEntry(K key) {
231    return delegate.lowerEntry(checkValid(key));
232  }
233
234  @Override
235  public K lowerKey(K key) {
236    return delegate.lowerKey(checkValid(key));
237  }
238
239  @Override
240  public NavigableSet<K> navigableKeySet() {
241    return delegate.navigableKeySet();
242  }
243
244  @Override
245  public Entry<K, V> pollFirstEntry() {
246    return delegate.pollFirstEntry();
247  }
248
249  @Override
250  public Entry<K, V> pollLastEntry() {
251    return delegate.pollLastEntry();
252  }
253
254  @Override
255  public V put(K key, V value) {
256    return delegate.put(checkValid(key), value);
257  }
258
259  @Override
260  public void putAll(Map<? extends K, ? extends V> map) {
261    for (K key : map.keySet()) {
262      checkValid(key);
263    }
264    delegate.putAll(map);
265  }
266
267  @Override
268  public V remove(Object key) {
269    return delegate.remove(checkValid(key));
270  }
271
272  @Override
273  public int size() {
274    return delegate.size();
275  }
276
277  @Override
278  public NavigableMap<K, V> subMap(K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) {
279    return new SafeTreeMap<>(
280        delegate.subMap(checkValid(fromKey), fromInclusive, checkValid(toKey), toInclusive));
281  }
282
283  @Override
284  public SortedMap<K, V> subMap(K fromKey, K toKey) {
285    return subMap(fromKey, true, toKey, false);
286  }
287
288  @Override
289  public SortedMap<K, V> tailMap(K fromKey) {
290    return tailMap(fromKey, true);
291  }
292
293  @Override
294  public NavigableMap<K, V> tailMap(K fromKey, boolean inclusive) {
295    return new SafeTreeMap<>(delegate.tailMap(checkValid(fromKey), inclusive));
296  }
297
298  @Override
299  public Collection<V> values() {
300    return delegate.values();
301  }
302
303  private <T> T checkValid(T t) {
304    // a ClassCastException is what's supposed to happen!
305    @SuppressWarnings("unchecked")
306    K k = (K) t;
307    comparator().compare(k, k);
308    return t;
309  }
310
311  @Override
312  public boolean equals(Object obj) {
313    return delegate.equals(obj);
314  }
315
316  @Override
317  public int hashCode() {
318    return delegate.hashCode();
319  }
320
321  @Override
322  public String toString() {
323    return delegate.toString();
324  }
325
326  private static final long serialVersionUID = 0L;
327}