001/* 002 * Copyright (C) 2010 The Guava Authors 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package com.google.common.collect.testing; 018 019import com.google.common.annotations.GwtIncompatible; 020import java.io.Serializable; 021import java.util.Collection; 022import java.util.Comparator; 023import java.util.Iterator; 024import java.util.NavigableSet; 025import java.util.SortedSet; 026import java.util.TreeSet; 027 028/** 029 * A wrapper around {@code TreeSet} that aggressively checks to see if elements are mutually 030 * comparable. This implementation passes the navigable set test suites. 031 * 032 * @author Louis Wasserman 033 */ 034@GwtIncompatible 035public final class SafeTreeSet<E> implements Serializable, NavigableSet<E> { 036 @SuppressWarnings("unchecked") 037 private static final Comparator<Object> NATURAL_ORDER = 038 new Comparator<Object>() { 039 @Override 040 public int compare(Object o1, Object o2) { 041 return ((Comparable<Object>) o1).compareTo(o2); 042 } 043 }; 044 045 private final NavigableSet<E> delegate; 046 047 public SafeTreeSet() { 048 this(new TreeSet<E>()); 049 } 050 051 public SafeTreeSet(Collection<? extends E> collection) { 052 this(new TreeSet<E>(collection)); 053 } 054 055 public SafeTreeSet(Comparator<? super E> comparator) { 056 this(new TreeSet<E>(comparator)); 057 } 058 059 public SafeTreeSet(SortedSet<E> set) { 060 this(new TreeSet<E>(set)); 061 } 062 063 private SafeTreeSet(NavigableSet<E> delegate) { 064 this.delegate = delegate; 065 for (E e : this) { 066 checkValid(e); 067 } 068 } 069 070 @Override 071 public boolean add(E element) { 072 return delegate.add(checkValid(element)); 073 } 074 075 @Override 076 public boolean addAll(Collection<? extends E> collection) { 077 for (E e : collection) { 078 checkValid(e); 079 } 080 return delegate.addAll(collection); 081 } 082 083 @Override 084 public E ceiling(E e) { 085 return delegate.ceiling(checkValid(e)); 086 } 087 088 @Override 089 public void clear() { 090 delegate.clear(); 091 } 092 093 @SuppressWarnings("unchecked") 094 @Override 095 public Comparator<? super E> comparator() { 096 Comparator<? super E> comparator = delegate.comparator(); 097 if (comparator == null) { 098 comparator = (Comparator<? super E>) NATURAL_ORDER; 099 } 100 return comparator; 101 } 102 103 @Override 104 public boolean contains(Object object) { 105 return delegate.contains(checkValid(object)); 106 } 107 108 @Override 109 public boolean containsAll(Collection<?> c) { 110 return delegate.containsAll(c); 111 } 112 113 @Override 114 public Iterator<E> descendingIterator() { 115 return delegate.descendingIterator(); 116 } 117 118 @Override 119 public NavigableSet<E> descendingSet() { 120 return new SafeTreeSet<>(delegate.descendingSet()); 121 } 122 123 @Override 124 public E first() { 125 return delegate.first(); 126 } 127 128 @Override 129 public E floor(E e) { 130 return delegate.floor(checkValid(e)); 131 } 132 133 @Override 134 public SortedSet<E> headSet(E toElement) { 135 return headSet(toElement, false); 136 } 137 138 @Override 139 public NavigableSet<E> headSet(E toElement, boolean inclusive) { 140 return new SafeTreeSet<>(delegate.headSet(checkValid(toElement), inclusive)); 141 } 142 143 @Override 144 public E higher(E e) { 145 return delegate.higher(checkValid(e)); 146 } 147 148 @Override 149 public boolean isEmpty() { 150 return delegate.isEmpty(); 151 } 152 153 @Override 154 public Iterator<E> iterator() { 155 return delegate.iterator(); 156 } 157 158 @Override 159 public E last() { 160 return delegate.last(); 161 } 162 163 @Override 164 public E lower(E e) { 165 return delegate.lower(checkValid(e)); 166 } 167 168 @Override 169 public E pollFirst() { 170 return delegate.pollFirst(); 171 } 172 173 @Override 174 public E pollLast() { 175 return delegate.pollLast(); 176 } 177 178 @Override 179 public boolean remove(Object object) { 180 return delegate.remove(checkValid(object)); 181 } 182 183 @Override 184 public boolean removeAll(Collection<?> c) { 185 return delegate.removeAll(c); 186 } 187 188 @Override 189 public boolean retainAll(Collection<?> c) { 190 return delegate.retainAll(c); 191 } 192 193 @Override 194 public int size() { 195 return delegate.size(); 196 } 197 198 @Override 199 public NavigableSet<E> subSet( 200 E fromElement, boolean fromInclusive, E toElement, boolean toInclusive) { 201 return new SafeTreeSet<>( 202 delegate.subSet( 203 checkValid(fromElement), fromInclusive, checkValid(toElement), toInclusive)); 204 } 205 206 @Override 207 public SortedSet<E> subSet(E fromElement, E toElement) { 208 return subSet(fromElement, true, toElement, false); 209 } 210 211 @Override 212 public SortedSet<E> tailSet(E fromElement) { 213 return tailSet(fromElement, true); 214 } 215 216 @Override 217 public NavigableSet<E> tailSet(E fromElement, boolean inclusive) { 218 return new SafeTreeSet<>(delegate.tailSet(checkValid(fromElement), inclusive)); 219 } 220 221 @Override 222 public Object[] toArray() { 223 return delegate.toArray(); 224 } 225 226 @Override 227 public <T> T[] toArray(T[] a) { 228 return delegate.toArray(a); 229 } 230 231 private <T> T checkValid(T t) { 232 // a ClassCastException is what's supposed to happen! 233 @SuppressWarnings("unchecked") 234 E e = (E) t; 235 comparator().compare(e, e); 236 return t; 237 } 238 239 @Override 240 public boolean equals(Object obj) { 241 return delegate.equals(obj); 242 } 243 244 @Override 245 public int hashCode() { 246 return delegate.hashCode(); 247 } 248 249 @Override 250 public String toString() { 251 return delegate.toString(); 252 } 253 254 private static final long serialVersionUID = 0L; 255}