java.lang.Object
  org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization
public class MultivariateNormalMixtureExpectationMaximization
Expectation-Maximization algorithm for fitting the parameters of multivariate normal mixture model distributions. This implementation is pure original code based on EM Demystified: An Expectation-Maximization Tutorial by Yihua Chen and Maya R. Gupta, Department of Electrical Engineering, University of Washington, Seattle, WA 98195. It was verified using external tools like CRAN Mixtools (see the JUnit test cases) but it is not based on Mixtools code at all. The discussion of the origin of this class can be seen in the comments of the MATH-817 JIRA issue.
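As a rough illustration of the alternating E-step and M-step that such a fit performs, here is a minimal sketch for a one-dimensional mixture. It is not the Commons Math implementation (which handles multivariate data with full covariance matrices); the class `EmSketch`, its fields, and the stopping rule on the mean log likelihood are assumptions made for the example.

```java
/** Hypothetical 1-D EM sketch; NOT the Commons Math implementation. */
final class EmSketch {
    final double[] weights;
    final double[] means;
    final double[] variances;
    /** Mean log likelihood computed in the most recent E-step. */
    double logLikelihood;

    EmSketch(double[] weights, double[] means, double[] variances) {
        this.weights = weights.clone();
        this.means = means.clone();
        this.variances = variances.clone();
    }

    /** Univariate normal density. */
    static double density(double x, double mean, double variance) {
        double d = x - mean;
        return Math.exp(-d * d / (2 * variance)) / Math.sqrt(2 * Math.PI * variance);
    }

    /** Alternates E- and M-steps until the log likelihood changes by at most threshold. */
    void fit(double[] data, int maxIterations, double threshold) {
        int n = data.length;
        int k = weights.length;
        double[][] gamma = new double[n][k]; // responsibilities of each component for each point
        double previous = Double.NEGATIVE_INFINITY;
        for (int iter = 0; iter < maxIterations; iter++) {
            // E-step: weighted densities, normalized per data point.
            double sumLog = 0;
            for (int i = 0; i < n; i++) {
                double total = 0;
                for (int j = 0; j < k; j++) {
                    gamma[i][j] = weights[j] * density(data[i], means[j], variances[j]);
                    total += gamma[i][j];
                }
                for (int j = 0; j < k; j++) {
                    gamma[i][j] /= total;
                }
                sumLog += Math.log(total);
            }
            logLikelihood = sumLog / n;
            if (Math.abs(logLikelihood - previous) <= threshold) {
                break; // converged
            }
            previous = logLikelihood;
            // M-step: re-estimate weight, mean, and variance of each component.
            for (int j = 0; j < k; j++) {
                double nj = 0;
                double sum = 0;
                for (int i = 0; i < n; i++) {
                    nj += gamma[i][j];
                    sum += gamma[i][j] * data[i];
                }
                double mean = sum / nj;
                double sq = 0;
                for (int i = 0; i < n; i++) {
                    double d = data[i] - mean;
                    sq += gamma[i][j] * d * d;
                }
                weights[j] = nj / n;
                means[j] = mean;
                variances[j] = sq / nj;
            }
        }
    }

    public static void main(String[] args) {
        double[] data = {0.8, 0.9, 1.0, 1.1, 1.2, 4.8, 4.9, 5.0, 5.1, 5.2};
        EmSketch em = new EmSketch(new double[] {0.5, 0.5},
                                   new double[] {0.0, 6.0},
                                   new double[] {1.0, 1.0});
        em.fit(data, 1000, 1e-9);
        System.out.println(java.util.Arrays.toString(em.means));
    }
}
```

The sketch omits safeguards a real implementation needs, such as handling a component whose variance collapses to zero (the one-dimensional analogue of a singular covariance matrix).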
Constructor Summary | |
---|---|
MultivariateNormalMixtureExpectationMaximization(double[][] data) | Creates an object to fit a multivariate normal mixture model to data. |
Method Summary | | |
---|---|---|
static MixtureMultivariateNormalDistribution | estimate(double[][] data, int numComponents) | Helper method to create a multivariate normal mixture model which can be used to initialize fit(MixtureMultivariateNormalDistribution). |
void | fit(MixtureMultivariateNormalDistribution initialMixture) | Fit a mixture model to the data supplied to the constructor. |
void | fit(MixtureMultivariateNormalDistribution initialMixture, int maxIterations, double threshold) | Fit a mixture model to the data supplied to the constructor. |
MixtureMultivariateNormalDistribution | getFittedModel() | Gets the fitted model. |
double | getLogLikelihood() | Gets the log likelihood of the data under the fitted model. |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Constructor Detail |
---|
public MultivariateNormalMixtureExpectationMaximization(double[][] data) throws NotStrictlyPositiveException, DimensionMismatchException, NumberIsTooSmallException
Creates an object to fit a multivariate normal mixture model to data.
Parameters:
data - Data to use in fitting procedure
Throws:
NotStrictlyPositiveException - if data has no rows
DimensionMismatchException - if rows of data have different numbers of columns
NumberIsTooSmallException - if the number of columns in the data is less than 2
Method Detail |
---|
public void fit(MixtureMultivariateNormalDistribution initialMixture, int maxIterations, double threshold) throws SingularMatrixException, NotStrictlyPositiveException, DimensionMismatchException
Fit a mixture model to the data supplied to the constructor.
Parameters:
initialMixture - Model containing initial values of weights and multivariate normals
maxIterations - Maximum iterations allowed for fit
threshold - Convergence threshold computed as difference in logLikelihoods between successive iterations
Throws:
SingularMatrixException - if any component's covariance matrix is singular during fitting
NotStrictlyPositiveException - if numComponents is less than one or threshold is less than Double.MIN_VALUE
DimensionMismatchException - if initialMixture mean vector and data number of columns are not equal
public void fit(MixtureMultivariateNormalDistribution initialMixture) throws SingularMatrixException, NotStrictlyPositiveException
Fit a mixture model to the data supplied to the constructor.
Parameters:
initialMixture - Model containing initial values of weights and multivariate normals
Throws:
SingularMatrixException - if any component's covariance matrix is singular during fitting
NotStrictlyPositiveException - if numComponents is less than one or threshold is less than Double.MIN_VALUE
public static MixtureMultivariateNormalDistribution estimate(double[][] data, int numComponents) throws NotStrictlyPositiveException, DimensionMismatchException
Helper method to create a multivariate normal mixture model which can be used to initialize fit(MixtureMultivariateNormalDistribution).
This method uses the data passed as an argument to try to determine a good mixture model at which to start the fit, but it is not guaranteed to supply a model which will find the optimal solution or even converge.
Parameters:
data - Data to estimate distribution
numComponents - Number of components for estimated mixture
Throws:
NumberIsTooLargeException - if numComponents is greater than the number of data rows
NumberIsTooSmallException - if numComponents < 2
NotStrictlyPositiveException - if data has fewer than 2 rows
DimensionMismatchException - if rows of data have different numbers of columns
public double getLogLikelihood()
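Gets the log likelihood of the data under the fitted model.
public MixtureMultivariateNormalDistribution getFittedModel()
Gets the fitted model.
Returns:
the fitted model, or null if no fit has been performed yet.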
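The preconditions documented for the constructor (at least one row, rows of equal length, at least two columns) can be sketched as a stand-alone check. This is only an illustration: the class `DataChecks` and its `validate` method are hypothetical, and `IllegalArgumentException` stands in for the Commons Math exception types named above so that the snippet needs no external library.

```java
/** Hypothetical stand-in for the constructor's documented input checks. */
final class DataChecks {
    static void validate(double[][] data) {
        // Documented as NotStrictlyPositiveException: data has no rows.
        if (data.length < 1) {
            throw new IllegalArgumentException("data has no rows");
        }
        int columns = data[0].length;
        // Documented as DimensionMismatchException: rows differ in length.
        for (double[] row : data) {
            if (row.length != columns) {
                throw new IllegalArgumentException("rows have different numbers of columns");
            }
        }
        // Documented as NumberIsTooSmallException: fewer than 2 columns.
        if (columns < 2) {
            throw new IllegalArgumentException("data has fewer than 2 columns");
        }
    }

    public static void main(String[] args) {
        validate(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
        System.out.println("valid");
    }
}
```

The order of the checks here is an assumption; the documentation above lists the conditions but does not say which is tested first.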