001/*
002 * Copyright (C) 2010 The Guava Authors
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package com.google.common.collect.testing;
018
019import com.google.common.annotations.GwtIncompatible;
020import com.google.errorprone.annotations.CanIgnoreReturnValue;
021import java.io.Serializable;
022import java.util.Collection;
023import java.util.Comparator;
024import java.util.Iterator;
025import java.util.NavigableSet;
026import java.util.SortedSet;
027import java.util.TreeSet;
028import org.jspecify.annotations.Nullable;
029
030/**
031 * A wrapper around {@code TreeSet} that aggressively checks to see if elements are mutually
032 * comparable. This implementation passes the navigable set test suites.
033 *
034 * @author Louis Wasserman
035 */
036@GwtIncompatible
037public final class SafeTreeSet<E> implements Serializable, NavigableSet<E> {
038  @SuppressWarnings("unchecked")
039  private static final Comparator<Object> NATURAL_ORDER =
040      new Comparator<Object>() {
041        @Override
042        public int compare(Object o1, Object o2) {
043          return ((Comparable<Object>) o1).compareTo(o2);
044        }
045      };
046
047  private final NavigableSet<E> delegate;
048
049  public SafeTreeSet() {
050    this(new TreeSet<E>());
051  }
052
053  public SafeTreeSet(Collection<? extends E> collection) {
054    this(new TreeSet<E>(collection));
055  }
056
057  public SafeTreeSet(Comparator<? super E> comparator) {
058    this(new TreeSet<E>(comparator));
059  }
060
061  public SafeTreeSet(SortedSet<E> set) {
062    this(new TreeSet<E>(set));
063  }
064
065  private SafeTreeSet(NavigableSet<E> delegate) {
066    this.delegate = delegate;
067    for (E e : this) {
068      checkValid(e);
069    }
070  }
071
072  @Override
073  public boolean add(E element) {
074    return delegate.add(checkValid(element));
075  }
076
077  @Override
078  public boolean addAll(Collection<? extends E> collection) {
079    for (E e : collection) {
080      checkValid(e);
081    }
082    return delegate.addAll(collection);
083  }
084
085  @Override
086  public @Nullable E ceiling(E e) {
087    return delegate.ceiling(checkValid(e));
088  }
089
090  @Override
091  public void clear() {
092    delegate.clear();
093  }
094
095  @Override
096  public Comparator<? super E> comparator() {
097    Comparator<? super E> comparator = delegate.comparator();
098    if (comparator == null) {
099      comparator = (Comparator<? super E>) NATURAL_ORDER;
100    }
101    return comparator;
102  }
103
104  @Override
105  public boolean contains(Object object) {
106    return delegate.contains(checkValid(object));
107  }
108
109  @Override
110  public boolean containsAll(Collection<?> c) {
111    return delegate.containsAll(c);
112  }
113
114  @Override
115  public Iterator<E> descendingIterator() {
116    return delegate.descendingIterator();
117  }
118
119  @Override
120  public NavigableSet<E> descendingSet() {
121    return new SafeTreeSet<>(delegate.descendingSet());
122  }
123
124  @Override
125  public E first() {
126    return delegate.first();
127  }
128
129  @Override
130  public @Nullable E floor(E e) {
131    return delegate.floor(checkValid(e));
132  }
133
134  @Override
135  public SortedSet<E> headSet(E toElement) {
136    return headSet(toElement, false);
137  }
138
139  @Override
140  public NavigableSet<E> headSet(E toElement, boolean inclusive) {
141    return new SafeTreeSet<>(delegate.headSet(checkValid(toElement), inclusive));
142  }
143
144  @Override
145  public @Nullable E higher(E e) {
146    return delegate.higher(checkValid(e));
147  }
148
149  @Override
150  public boolean isEmpty() {
151    return delegate.isEmpty();
152  }
153
154  @Override
155  public Iterator<E> iterator() {
156    return delegate.iterator();
157  }
158
159  @Override
160  public E last() {
161    return delegate.last();
162  }
163
164  @Override
165  public @Nullable E lower(E e) {
166    return delegate.lower(checkValid(e));
167  }
168
169  @Override
170  public @Nullable E pollFirst() {
171    return delegate.pollFirst();
172  }
173
174  @Override
175  public @Nullable E pollLast() {
176    return delegate.pollLast();
177  }
178
179  @Override
180  public boolean remove(Object object) {
181    return delegate.remove(checkValid(object));
182  }
183
184  @Override
185  public boolean removeAll(Collection<?> c) {
186    return delegate.removeAll(c);
187  }
188
189  @Override
190  public boolean retainAll(Collection<?> c) {
191    return delegate.retainAll(c);
192  }
193
194  @Override
195  public int size() {
196    return delegate.size();
197  }
198
199  @Override
200  public NavigableSet<E> subSet(
201      E fromElement, boolean fromInclusive, E toElement, boolean toInclusive) {
202    return new SafeTreeSet<>(
203        delegate.subSet(
204            checkValid(fromElement), fromInclusive, checkValid(toElement), toInclusive));
205  }
206
207  @Override
208  public SortedSet<E> subSet(E fromElement, E toElement) {
209    return subSet(fromElement, true, toElement, false);
210  }
211
212  @Override
213  public SortedSet<E> tailSet(E fromElement) {
214    return tailSet(fromElement, true);
215  }
216
217  @Override
218  public NavigableSet<E> tailSet(E fromElement, boolean inclusive) {
219    return new SafeTreeSet<>(delegate.tailSet(checkValid(fromElement), inclusive));
220  }
221
222  @Override
223  public Object[] toArray() {
224    return delegate.toArray();
225  }
226
227  @Override
228  public <T> T[] toArray(T[] a) {
229    return delegate.toArray(a);
230  }
231
232  @CanIgnoreReturnValue
233  private <T> T checkValid(T t) {
234    // a ClassCastException is what's supposed to happen!
235    @SuppressWarnings("unchecked")
236    E e = (E) t;
237    int unused = comparator().compare(e, e);
238    return t;
239  }
240
241  @Override
242  public boolean equals(@Nullable Object obj) {
243    return delegate.equals(obj);
244  }
245
246  @Override
247  public int hashCode() {
248    return delegate.hashCode();
249  }
250
251  @Override
252  public String toString() {
253    return delegate.toString();
254  }
255
256  private static final long serialVersionUID = 0L;
257}