001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.stat.clustering;
019    
020    import java.util.ArrayList;
021    import java.util.Collection;
022    import java.util.List;
023    import java.util.Random;
024    
025    import org.apache.commons.math.exception.ConvergenceException;
026    import org.apache.commons.math.exception.util.LocalizedFormats;
027    import org.apache.commons.math.stat.descriptive.moment.Variance;
028    
029    /**
030     * Clustering algorithm based on David Arthur and Sergei Vassilvitskii's k-means++ algorithm.
031     * @param <T> type of the points to cluster
032     * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
033     * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $
034     * @since 2.0
035     */
036    public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
037    
038        /** Strategies to use for replacing an empty cluster. */
039        public static enum EmptyClusterStrategy {
040    
041            /** Split the cluster with largest distance variance. */
042            LARGEST_VARIANCE,
043    
044            /** Split the cluster with largest number of points. */
045            LARGEST_POINTS_NUMBER,
046    
047            /** Create a cluster around the point farthest from its centroid. */
048            FARTHEST_POINT,
049    
050            /** Generate an error. */
051            ERROR
052    
053        }
054    
055        /** Random generator for choosing initial centers. */
056        private final Random random;
057    
058        /** Selected strategy for empty clusters. */
059        private final EmptyClusterStrategy emptyStrategy;
060    
061        /** Build a clusterer.
062         * <p>
063         * The default strategy for handling empty clusters that may appear during
064         * algorithm iterations is to split the cluster with largest distance variance.
065         * </p>
066         * @param random random generator to use for choosing initial centers
067         */
068        public KMeansPlusPlusClusterer(final Random random) {
069            this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
070        }
071    
072        /** Build a clusterer.
073         * @param random random generator to use for choosing initial centers
074         * @param emptyStrategy strategy to use for handling empty clusters that
075         * may appear during algorithm iterations
076         * @since 2.2
077         */
078        public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
079            this.random        = random;
080            this.emptyStrategy = emptyStrategy;
081        }
082    
083        /**
084         * Runs the K-means++ clustering algorithm.
085         *
086         * @param points the points to cluster
087         * @param k the number of clusters to split the data into
088         * @param maxIterations the maximum number of iterations to run the algorithm
089         *     for.  If negative, no maximum will be used
090         * @return a list of clusters containing the points
091         */
092        public List<Cluster<T>> cluster(final Collection<T> points,
093                                        final int k, final int maxIterations) {
094            // create the initial clusters
095            List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
096            assignPointsToClusters(clusters, points);
097    
098            // iterate through updating the centers until we're done
099            final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
100            for (int count = 0; count < max; count++) {
101                boolean clusteringChanged = false;
102                List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
103                for (final Cluster<T> cluster : clusters) {
104                    final T newCenter;
105                    if (cluster.getPoints().isEmpty()) {
106                        switch (emptyStrategy) {
107                            case LARGEST_VARIANCE :
108                                newCenter = getPointFromLargestVarianceCluster(clusters);
109                                break;
110                            case LARGEST_POINTS_NUMBER :
111                                newCenter = getPointFromLargestNumberCluster(clusters);
112                                break;
113                            case FARTHEST_POINT :
114                                newCenter = getFarthestPoint(clusters);
115                                break;
116                            default :
117                                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
118                        }
119                        clusteringChanged = true;
120                    } else {
121                        newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
122                        if (!newCenter.equals(cluster.getCenter())) {
123                            clusteringChanged = true;
124                        }
125                    }
126                    newClusters.add(new Cluster<T>(newCenter));
127                }
128                if (!clusteringChanged) {
129                    return clusters;
130                }
131                assignPointsToClusters(newClusters, points);
132                clusters = newClusters;
133            }
134            return clusters;
135        }
136    
137        /**
138         * Adds the given points to the closest {@link Cluster}.
139         *
140         * @param <T> type of the points to cluster
141         * @param clusters the {@link Cluster}s to add the points to
142         * @param points the points to add to the given {@link Cluster}s
143         */
144        private static <T extends Clusterable<T>> void
145            assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
146            for (final T p : points) {
147                Cluster<T> cluster = getNearestCluster(clusters, p);
148                cluster.addPoint(p);
149            }
150        }
151    
152        /**
153         * Use K-means++ to choose the initial centers.
154         *
155         * @param <T> type of the points to cluster
156         * @param points the points to choose the initial centers from
157         * @param k the number of centers to choose
158         * @param random random generator to use
159         * @return the initial centers
160         */
161        private static <T extends Clusterable<T>> List<Cluster<T>>
162            chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
163    
164            final List<T> pointSet = new ArrayList<T>(points);
165            final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
166    
167            // Choose one center uniformly at random from among the data points.
168            final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
169            resultSet.add(new Cluster<T>(firstPoint));
170    
171            final double[] dx2 = new double[pointSet.size()];
172            while (resultSet.size() < k) {
173                // For each data point x, compute D(x), the distance between x and
174                // the nearest center that has already been chosen.
175                int sum = 0;
176                for (int i = 0; i < pointSet.size(); i++) {
177                    final T p = pointSet.get(i);
178                    final Cluster<T> nearest = getNearestCluster(resultSet, p);
179                    final double d = p.distanceFrom(nearest.getCenter());
180                    sum += d * d;
181                    dx2[i] = sum;
182                }
183    
184                // Add one new data point as a center. Each point x is chosen with
185                // probability proportional to D(x)2
186                final double r = random.nextDouble() * sum;
187                for (int i = 0 ; i < dx2.length; i++) {
188                    if (dx2[i] >= r) {
189                        final T p = pointSet.remove(i);
190                        resultSet.add(new Cluster<T>(p));
191                        break;
192                    }
193                }
194            }
195    
196            return resultSet;
197    
198        }
199    
200        /**
201         * Get a random point from the {@link Cluster} with the largest distance variance.
202         *
203         * @param clusters the {@link Cluster}s to search
204         * @return a random point from the selected cluster
205         */
206        private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
207    
208            double maxVariance = Double.NEGATIVE_INFINITY;
209            Cluster<T> selected = null;
210            for (final Cluster<T> cluster : clusters) {
211                if (!cluster.getPoints().isEmpty()) {
212    
213                    // compute the distance variance of the current cluster
214                    final T center = cluster.getCenter();
215                    final Variance stat = new Variance();
216                    for (final T point : cluster.getPoints()) {
217                        stat.increment(point.distanceFrom(center));
218                    }
219                    final double variance = stat.getResult();
220    
221                    // select the cluster with the largest variance
222                    if (variance > maxVariance) {
223                        maxVariance = variance;
224                        selected = cluster;
225                    }
226    
227                }
228            }
229    
230            // did we find at least one non-empty cluster ?
231            if (selected == null) {
232                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
233            }
234    
235            // extract a random point from the cluster
236            final List<T> selectedPoints = selected.getPoints();
237            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
238    
239        }
240    
241        /**
242         * Get a random point from the {@link Cluster} with the largest number of points
243         *
244         * @param clusters the {@link Cluster}s to search
245         * @return a random point from the selected cluster
246         */
247        private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
248    
249            int maxNumber = 0;
250            Cluster<T> selected = null;
251            for (final Cluster<T> cluster : clusters) {
252    
253                // get the number of points of the current cluster
254                final int number = cluster.getPoints().size();
255    
256                // select the cluster with the largest number of points
257                if (number > maxNumber) {
258                    maxNumber = number;
259                    selected = cluster;
260                }
261    
262            }
263    
264            // did we find at least one non-empty cluster ?
265            if (selected == null) {
266                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
267            }
268    
269            // extract a random point from the cluster
270            final List<T> selectedPoints = selected.getPoints();
271            return selectedPoints.remove(random.nextInt(selectedPoints.size()));
272    
273        }
274    
275        /**
276         * Get the point farthest to its cluster center
277         *
278         * @param clusters the {@link Cluster}s to search
279         * @return point farthest to its cluster center
280         */
281        private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
282    
283            double maxDistance = Double.NEGATIVE_INFINITY;
284            Cluster<T> selectedCluster = null;
285            int selectedPoint = -1;
286            for (final Cluster<T> cluster : clusters) {
287    
288                // get the farthest point
289                final T center = cluster.getCenter();
290                final List<T> points = cluster.getPoints();
291                for (int i = 0; i < points.size(); ++i) {
292                    final double distance = points.get(i).distanceFrom(center);
293                    if (distance > maxDistance) {
294                        maxDistance     = distance;
295                        selectedCluster = cluster;
296                        selectedPoint   = i;
297                    }
298                }
299    
300            }
301    
302            // did we find at least one non-empty cluster ?
303            if (selectedCluster == null) {
304                throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
305            }
306    
307            return selectedCluster.getPoints().remove(selectedPoint);
308    
309        }
310    
311        /**
312         * Returns the nearest {@link Cluster} to the given point
313         *
314         * @param <T> type of the points to cluster
315         * @param clusters the {@link Cluster}s to search
316         * @param point the point to find the nearest {@link Cluster} for
317         * @return the nearest {@link Cluster} to the given point
318         */
319        private static <T extends Clusterable<T>> Cluster<T>
320            getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
321            double minDistance = Double.MAX_VALUE;
322            Cluster<T> minCluster = null;
323            for (final Cluster<T> c : clusters) {
324                final double distance = point.distanceFrom(c.getCenter());
325                if (distance < minDistance) {
326                    minDistance = distance;
327                    minCluster = c;
328                }
329            }
330            return minCluster;
331        }
332    
333    }