001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.regression;
018    
019    import org.apache.commons.math.MathRuntimeException;
020    import org.apache.commons.math.exception.util.LocalizedFormats;
021    import org.apache.commons.math.linear.RealMatrix;
022    import org.apache.commons.math.linear.Array2DRowRealMatrix;
023    import org.apache.commons.math.linear.RealVector;
024    import org.apache.commons.math.linear.ArrayRealVector;
025    import org.apache.commons.math.stat.descriptive.moment.Variance;
026    import org.apache.commons.math.util.FastMath;
027    
028    /**
029     * Abstract base class for implementations of MultipleLinearRegression.
030     * @version $Revision: 1073459 $ $Date: 2011-02-22 20:18:12 +0100 (mar. 22 f??vr. 2011) $
031     * @since 2.0
032     */
033    public abstract class AbstractMultipleLinearRegression implements
034            MultipleLinearRegression {
035    
036        /** X sample data. */
037        protected RealMatrix X;
038    
039        /** Y sample data. */
040        protected RealVector Y;
041    
042        /** Whether or not the regression model includes an intercept.  True means no intercept. */
043        private boolean noIntercept = false;
044    
045        /**
046         * @return true if the model has no intercept term; false otherwise
047         * @since 2.2
048         */
049        public boolean isNoIntercept() {
050            return noIntercept;
051        }
052    
053        /**
054         * @param noIntercept true means the model is to be estimated without an intercept term
055         * @since 2.2
056         */
057        public void setNoIntercept(boolean noIntercept) {
058            this.noIntercept = noIntercept;
059        }
060    
061        /**
062         * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
063         * </p>
064         * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
065         * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
066         * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
067         * independent variables, as below:
068         * <pre>
069         *   y   x[0]  x[1]
070         *   --------------
071         *   1     2     3
072         *   4     5     6
073         *   7     8     9
074         * </pre>
075         * </p>
076         * <p>Note that there is no need to add an initial unitary column (column of 1's) when
077         * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
078         * the X matrix will be created without an initial column of "1"s; otherwise this column will
079         * be added.
080         * </p>
081         * <p>Throws IllegalArgumentException if any of the following preconditions fail:
082         * <ul><li><code>data</code> cannot be null</li>
083         * <li><code>data.length = nobs * (nvars + 1)</li>
084         * <li><code>nobs > nvars</code></li></ul>
085         * </p>
086         *
087         * @param data input data array
088         * @param nobs number of observations (rows)
089         * @param nvars number of independent variables (columns, not counting y)
090         * @throws IllegalArgumentException if the preconditions are not met
091         */
092        public void newSampleData(double[] data, int nobs, int nvars) {
093            if (data == null) {
094                throw MathRuntimeException.createIllegalArgumentException(
095                        LocalizedFormats.NULL_NOT_ALLOWED);
096            }
097            if (data.length != nobs * (nvars + 1)) {
098                throw MathRuntimeException.createIllegalArgumentException(
099                        LocalizedFormats.INVALID_REGRESSION_ARRAY, data.length, nobs, nvars);
100            }
101            if (nobs <= nvars) {
102                throw MathRuntimeException.createIllegalArgumentException(
103                        LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS);
104            }
105            double[] y = new double[nobs];
106            final int cols = noIntercept ? nvars: nvars + 1;
107            double[][] x = new double[nobs][cols];
108            int pointer = 0;
109            for (int i = 0; i < nobs; i++) {
110                y[i] = data[pointer++];
111                if (!noIntercept) {
112                    x[i][0] = 1.0d;
113                }
114                for (int j = noIntercept ? 0 : 1; j < cols; j++) {
115                    x[i][j] = data[pointer++];
116                }
117            }
118            this.X = new Array2DRowRealMatrix(x);
119            this.Y = new ArrayRealVector(y);
120        }
121    
122        /**
123         * Loads new y sample data, overriding any previous data.
124         *
125         * @param y the array representing the y sample
126         * @throws IllegalArgumentException if y is null or empty
127         */
128        protected void newYSampleData(double[] y) {
129            if (y == null) {
130                throw MathRuntimeException.createIllegalArgumentException(
131                        LocalizedFormats.NULL_NOT_ALLOWED);
132            }
133            if (y.length == 0) {
134                throw MathRuntimeException.createIllegalArgumentException(
135                        LocalizedFormats.NO_DATA);
136            }
137            this.Y = new ArrayRealVector(y);
138        }
139    
140        /**
141         * <p>Loads new x sample data, overriding any previous data.
142         * </p>
143         * The input <code>x</code> array should have one row for each sample
144         * observation, with columns corresponding to independent variables.
145         * For example, if <pre>
146         * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
147         * then <code>setXSampleData(x) </code> results in a model with two independent
148         * variables and 3 observations:
149         * <pre>
150         *   x[0]  x[1]
151         *   ----------
152         *     1    2
153         *     3    4
154         *     5    6
155         * </pre>
156         * </p>
157         * <p>Note that there is no need to add an initial unitary column (column of 1's) when
158         * specifying a model including an intercept term.
159         * </p>
160         * @param x the rectangular array representing the x sample
161         * @throws IllegalArgumentException if x is null, empty or not rectangular
162         */
163        protected void newXSampleData(double[][] x) {
164            if (x == null) {
165                throw MathRuntimeException.createIllegalArgumentException(
166                        LocalizedFormats.NULL_NOT_ALLOWED);
167            }
168            if (x.length == 0) {
169                throw MathRuntimeException.createIllegalArgumentException(
170                        LocalizedFormats.NO_DATA);
171            }
172            if (noIntercept) {
173                this.X = new Array2DRowRealMatrix(x, true);
174            } else { // Augment design matrix with initial unitary column
175                final int nVars = x[0].length;
176                final double[][] xAug = new double[x.length][nVars + 1];
177                for (int i = 0; i < x.length; i++) {
178                    if (x[i].length != nVars) {
179                        throw MathRuntimeException.createIllegalArgumentException(
180                                LocalizedFormats.DIFFERENT_ROWS_LENGTHS,
181                                x[i].length, nVars);
182                    }
183                    xAug[i][0] = 1.0d;
184                    System.arraycopy(x[i], 0, xAug[i], 1, nVars);
185                }
186                this.X = new Array2DRowRealMatrix(xAug, false);
187            }
188        }
189    
190        /**
191         * Validates sample data.  Checks that
192         * <ul><li>Neither x nor y is null or empty;</li>
193         * <li>The length (i.e. number of rows) of x equals the length of y</li>
194         * <li>x has at least one more row than it has columns (i.e. there is
195         * sufficient data to estimate regression coefficients for each of the
196         * columns in x plus an intercept.</li>
197         * </ul>
198         *
199         * @param x the [n,k] array representing the x data
200         * @param y the [n,1] array representing the y data
201         * @throws IllegalArgumentException if any of the checks fail
202         *
203         */
204        protected void validateSampleData(double[][] x, double[] y) {
205            if ((x == null) || (y == null) || (x.length != y.length)) {
206                throw MathRuntimeException.createIllegalArgumentException(
207                      LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE,
208                      (x == null) ? 0 : x.length,
209                      (y == null) ? 0 : y.length);
210            }
211            if (x.length == 0) {  // Must be no y data either
212                throw MathRuntimeException.createIllegalArgumentException(
213                        LocalizedFormats.NO_DATA);
214            }
215            if (x[0].length + 1 > x.length) {
216                throw MathRuntimeException.createIllegalArgumentException(
217                      LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
218                      x.length, x[0].length);
219            }
220        }
221    
222        /**
223         * Validates that the x data and covariance matrix have the same
224         * number of rows and that the covariance matrix is square.
225         *
226         * @param x the [n,k] array representing the x sample
227         * @param covariance the [n,n] array representing the covariance matrix
228         * @throws IllegalArgumentException if the number of rows in x is not equal
229         * to the number of rows in covariance or covariance is not square.
230         */
231        protected void validateCovarianceData(double[][] x, double[][] covariance) {
232            if (x.length != covariance.length) {
233                throw MathRuntimeException.createIllegalArgumentException(
234                     LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, x.length, covariance.length);
235            }
236            if (covariance.length > 0 && covariance.length != covariance[0].length) {
237                throw MathRuntimeException.createIllegalArgumentException(
238                      LocalizedFormats.NON_SQUARE_MATRIX,
239                      covariance.length, covariance[0].length);
240            }
241        }
242    
243        /**
244         * {@inheritDoc}
245         */
246        public double[] estimateRegressionParameters() {
247            RealVector b = calculateBeta();
248            return b.getData();
249        }
250    
251        /**
252         * {@inheritDoc}
253         */
254        public double[] estimateResiduals() {
255            RealVector b = calculateBeta();
256            RealVector e = Y.subtract(X.operate(b));
257            return e.getData();
258        }
259    
260        /**
261         * {@inheritDoc}
262         */
263        public double[][] estimateRegressionParametersVariance() {
264            return calculateBetaVariance().getData();
265        }
266    
267        /**
268         * {@inheritDoc}
269         */
270        public double[] estimateRegressionParametersStandardErrors() {
271            double[][] betaVariance = estimateRegressionParametersVariance();
272            double sigma = calculateErrorVariance();
273            int length = betaVariance[0].length;
274            double[] result = new double[length];
275            for (int i = 0; i < length; i++) {
276                result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
277            }
278            return result;
279        }
280    
281        /**
282         * {@inheritDoc}
283         */
284        public double estimateRegressandVariance() {
285            return calculateYVariance();
286        }
287    
288        /**
289         * Estimates the variance of the error.
290         *
291         * @return estimate of the error variance
292         * @since 2.2
293         */
294        public double estimateErrorVariance() {
295            return calculateErrorVariance();
296    
297        }
298    
299        /**
300         * Estimates the standard error of the regression.
301         *
302         * @return regression standard error
303         * @since 2.2
304         */
305        public double estimateRegressionStandardError() {
306            return Math.sqrt(estimateErrorVariance());
307        }
308    
309        /**
310         * Calculates the beta of multiple linear regression in matrix notation.
311         *
312         * @return beta
313         */
314        protected abstract RealVector calculateBeta();
315    
316        /**
317         * Calculates the beta variance of multiple linear regression in matrix
318         * notation.
319         *
320         * @return beta variance
321         */
322        protected abstract RealMatrix calculateBetaVariance();
323    
324    
325        /**
326         * Calculates the variance of the y values.
327         *
328         * @return Y variance
329         */
330        protected double calculateYVariance() {
331            return new Variance().evaluate(Y.getData());
332        }
333    
334        /**
335         * <p>Calculates the variance of the error term.</p>
336         * Uses the formula <pre>
337         * var(u) = u &middot; u / (n - k)
338         * </pre>
339         * where n and k are the row and column dimensions of the design
340         * matrix X.
341         *
342         * @return error variance estimate
343         * @since 2.2
344         */
345        protected double calculateErrorVariance() {
346            RealVector residuals = calculateResiduals();
347            return residuals.dotProduct(residuals) /
348                   (X.getRowDimension() - X.getColumnDimension());
349        }
350    
351        /**
352         * Calculates the residuals of multiple linear regression in matrix
353         * notation.
354         *
355         * <pre>
356         * u = y - X * b
357         * </pre>
358         *
359         * @return The residuals [n,1] matrix
360         */
361        protected RealVector calculateResiduals() {
362            RealVector b = calculateBeta();
363            return Y.subtract(X.operate(b));
364        }
365    
366    }