001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math3.stat.regression; 018 019 import org.apache.commons.math3.exception.DimensionMismatchException; 020 import org.apache.commons.math3.exception.MathIllegalArgumentException; 021 import org.apache.commons.math3.exception.NoDataException; 022 import org.apache.commons.math3.exception.NullArgumentException; 023 import org.apache.commons.math3.exception.NumberIsTooSmallException; 024 import org.apache.commons.math3.exception.util.LocalizedFormats; 025 import org.apache.commons.math3.linear.NonSquareMatrixException; 026 import org.apache.commons.math3.linear.RealMatrix; 027 import org.apache.commons.math3.linear.Array2DRowRealMatrix; 028 import org.apache.commons.math3.linear.RealVector; 029 import org.apache.commons.math3.linear.ArrayRealVector; 030 import org.apache.commons.math3.stat.descriptive.moment.Variance; 031 import org.apache.commons.math3.util.FastMath; 032 033 /** 034 * Abstract base class for implementations of MultipleLinearRegression. 035 * @version $Id: AbstractMultipleLinearRegression.java 1416643 2012-12-03 19:37:14Z tn $ 036 * @since 2.0 037 */ 038 public abstract class AbstractMultipleLinearRegression implements 039 MultipleLinearRegression { 040 041 /** X sample data. */ 042 private RealMatrix xMatrix; 043 044 /** Y sample data. */ 045 private RealVector yVector; 046 047 /** Whether or not the regression model includes an intercept. True means no intercept. */ 048 private boolean noIntercept = false; 049 050 /** 051 * @return the X sample data. 052 */ 053 protected RealMatrix getX() { 054 return xMatrix; 055 } 056 057 /** 058 * @return the Y sample data. 059 */ 060 protected RealVector getY() { 061 return yVector; 062 } 063 064 /** 065 * @return true if the model has no intercept term; false otherwise 066 * @since 2.2 067 */ 068 public boolean isNoIntercept() { 069 return noIntercept; 070 } 071 072 /** 073 * @param noIntercept true means the model is to be estimated without an intercept term 074 * @since 2.2 075 */ 076 public void setNoIntercept(boolean noIntercept) { 077 this.noIntercept = noIntercept; 078 } 079 080 /** 081 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample. 082 * </p> 083 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input 084 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with 085 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two 086 * independent variables, as below: 087 * <pre> 088 * y x[0] x[1] 089 * -------------- 090 * 1 2 3 091 * 4 5 6 092 * 7 8 9 093 * </pre> 094 * </p> 095 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 096 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>, 097 * the X matrix will be created without an initial column of "1"s; otherwise this column will 098 * be added. 099 * </p> 100 * <p>Throws IllegalArgumentException if any of the following preconditions fail: 101 * <ul><li><code>data</code> cannot be null</li> 102 * <li><code>data.length = nobs * (nvars + 1)</li> 103 * <li><code>nobs > nvars</code></li></ul> 104 * </p> 105 * 106 * @param data input data array 107 * @param nobs number of observations (rows) 108 * @param nvars number of independent variables (columns, not counting y) 109 * @throws NullArgumentException if the data array is null 110 * @throws DimensionMismatchException if the length of the data array is not equal 111 * to <code>nobs * (nvars + 1)</code> 112 * @throws NumberIsTooSmallException if <code>nobs</code> is smaller than 113 * <code>nvars</code> 114 */ 115 public void newSampleData(double[] data, int nobs, int nvars) { 116 if (data == null) { 117 throw new NullArgumentException(); 118 } 119 if (data.length != nobs * (nvars + 1)) { 120 throw new DimensionMismatchException(data.length, nobs * (nvars + 1)); 121 } 122 if (nobs <= nvars) { 123 throw new NumberIsTooSmallException(nobs, nvars, false); 124 } 125 double[] y = new double[nobs]; 126 final int cols = noIntercept ? nvars: nvars + 1; 127 double[][] x = new double[nobs][cols]; 128 int pointer = 0; 129 for (int i = 0; i < nobs; i++) { 130 y[i] = data[pointer++]; 131 if (!noIntercept) { 132 x[i][0] = 1.0d; 133 } 134 for (int j = noIntercept ? 0 : 1; j < cols; j++) { 135 x[i][j] = data[pointer++]; 136 } 137 } 138 this.xMatrix = new Array2DRowRealMatrix(x); 139 this.yVector = new ArrayRealVector(y); 140 } 141 142 /** 143 * Loads new y sample data, overriding any previous data. 144 * 145 * @param y the array representing the y sample 146 * @throws NullArgumentException if y is null 147 * @throws NoDataException if y is empty 148 */ 149 protected void newYSampleData(double[] y) { 150 if (y == null) { 151 throw new NullArgumentException(); 152 } 153 if (y.length == 0) { 154 throw new NoDataException(); 155 } 156 this.yVector = new ArrayRealVector(y); 157 } 158 159 /** 160 * <p>Loads new x sample data, overriding any previous data. 161 * </p> 162 * The input <code>x</code> array should have one row for each sample 163 * observation, with columns corresponding to independent variables. 164 * For example, if <pre> 165 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre> 166 * then <code>setXSampleData(x) </code> results in a model with two independent 167 * variables and 3 observations: 168 * <pre> 169 * x[0] x[1] 170 * ---------- 171 * 1 2 172 * 3 4 173 * 5 6 174 * </pre> 175 * </p> 176 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 177 * specifying a model including an intercept term. 178 * </p> 179 * @param x the rectangular array representing the x sample 180 * @throws NullArgumentException if x is null 181 * @throws NoDataException if x is empty 182 * @throws DimensionMismatchException if x is not rectangular 183 */ 184 protected void newXSampleData(double[][] x) { 185 if (x == null) { 186 throw new NullArgumentException(); 187 } 188 if (x.length == 0) { 189 throw new NoDataException(); 190 } 191 if (noIntercept) { 192 this.xMatrix = new Array2DRowRealMatrix(x, true); 193 } else { // Augment design matrix with initial unitary column 194 final int nVars = x[0].length; 195 final double[][] xAug = new double[x.length][nVars + 1]; 196 for (int i = 0; i < x.length; i++) { 197 if (x[i].length != nVars) { 198 throw new DimensionMismatchException(x[i].length, nVars); 199 } 200 xAug[i][0] = 1.0d; 201 System.arraycopy(x[i], 0, xAug[i], 1, nVars); 202 } 203 this.xMatrix = new Array2DRowRealMatrix(xAug, false); 204 } 205 } 206 207 /** 208 * Validates sample data. Checks that 209 * <ul><li>Neither x nor y is null or empty;</li> 210 * <li>The length (i.e. number of rows) of x equals the length of y</li> 211 * <li>x has at least one more row than it has columns (i.e. there is 212 * sufficient data to estimate regression coefficients for each of the 213 * columns in x plus an intercept.</li> 214 * </ul> 215 * 216 * @param x the [n,k] array representing the x data 217 * @param y the [n,1] array representing the y data 218 * @throws NullArgumentException if {@code x} or {@code y} is null 219 * @throws DimensionMismatchException if {@code x} and {@code y} do not 220 * have the same length 221 * @throws NoDataException if {@code x} or {@code y} are zero-length 222 * @throws MathIllegalArgumentException if the number of rows of {@code x} 223 * is not larger than the number of columns + 1 224 */ 225 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException { 226 if ((x == null) || (y == null)) { 227 throw new NullArgumentException(); 228 } 229 if (x.length != y.length) { 230 throw new DimensionMismatchException(y.length, x.length); 231 } 232 if (x.length == 0) { // Must be no y data either 233 throw new NoDataException(); 234 } 235 if (x[0].length + 1 > x.length) { 236 throw new MathIllegalArgumentException( 237 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, 238 x.length, x[0].length); 239 } 240 } 241 242 /** 243 * Validates that the x data and covariance matrix have the same 244 * number of rows and that the covariance matrix is square. 245 * 246 * @param x the [n,k] array representing the x sample 247 * @param covariance the [n,n] array representing the covariance matrix 248 * @throws DimensionMismatchException if the number of rows in x is not equal 249 * to the number of rows in covariance 250 * @throws NonSquareMatrixException if the covariance matrix is not square 251 */ 252 protected void validateCovarianceData(double[][] x, double[][] covariance) { 253 if (x.length != covariance.length) { 254 throw new DimensionMismatchException(x.length, covariance.length); 255 } 256 if (covariance.length > 0 && covariance.length != covariance[0].length) { 257 throw new NonSquareMatrixException(covariance.length, covariance[0].length); 258 } 259 } 260 261 /** 262 * {@inheritDoc} 263 */ 264 public double[] estimateRegressionParameters() { 265 RealVector b = calculateBeta(); 266 return b.toArray(); 267 } 268 269 /** 270 * {@inheritDoc} 271 */ 272 public double[] estimateResiduals() { 273 RealVector b = calculateBeta(); 274 RealVector e = yVector.subtract(xMatrix.operate(b)); 275 return e.toArray(); 276 } 277 278 /** 279 * {@inheritDoc} 280 */ 281 public double[][] estimateRegressionParametersVariance() { 282 return calculateBetaVariance().getData(); 283 } 284 285 /** 286 * {@inheritDoc} 287 */ 288 public double[] estimateRegressionParametersStandardErrors() { 289 double[][] betaVariance = estimateRegressionParametersVariance(); 290 double sigma = calculateErrorVariance(); 291 int length = betaVariance[0].length; 292 double[] result = new double[length]; 293 for (int i = 0; i < length; i++) { 294 result[i] = FastMath.sqrt(sigma * betaVariance[i][i]); 295 } 296 return result; 297 } 298 299 /** 300 * {@inheritDoc} 301 */ 302 public double estimateRegressandVariance() { 303 return calculateYVariance(); 304 } 305 306 /** 307 * Estimates the variance of the error. 308 * 309 * @return estimate of the error variance 310 * @since 2.2 311 */ 312 public double estimateErrorVariance() { 313 return calculateErrorVariance(); 314 315 } 316 317 /** 318 * Estimates the standard error of the regression. 319 * 320 * @return regression standard error 321 * @since 2.2 322 */ 323 public double estimateRegressionStandardError() { 324 return Math.sqrt(estimateErrorVariance()); 325 } 326 327 /** 328 * Calculates the beta of multiple linear regression in matrix notation. 329 * 330 * @return beta 331 */ 332 protected abstract RealVector calculateBeta(); 333 334 /** 335 * Calculates the beta variance of multiple linear regression in matrix 336 * notation. 337 * 338 * @return beta variance 339 */ 340 protected abstract RealMatrix calculateBetaVariance(); 341 342 343 /** 344 * Calculates the variance of the y values. 345 * 346 * @return Y variance 347 */ 348 protected double calculateYVariance() { 349 return new Variance().evaluate(yVector.toArray()); 350 } 351 352 /** 353 * <p>Calculates the variance of the error term.</p> 354 * Uses the formula <pre> 355 * var(u) = u · u / (n - k) 356 * </pre> 357 * where n and k are the row and column dimensions of the design 358 * matrix X. 359 * 360 * @return error variance estimate 361 * @since 2.2 362 */ 363 protected double calculateErrorVariance() { 364 RealVector residuals = calculateResiduals(); 365 return residuals.dotProduct(residuals) / 366 (xMatrix.getRowDimension() - xMatrix.getColumnDimension()); 367 } 368 369 /** 370 * Calculates the residuals of multiple linear regression in matrix 371 * notation. 372 * 373 * <pre> 374 * u = y - X * b 375 * </pre> 376 * 377 * @return The residuals [n,1] matrix 378 */ 379 protected RealVector calculateResiduals() { 380 RealVector b = calculateBeta(); 381 return yVector.subtract(xMatrix.operate(b)); 382 } 383 384 }