public class MultivariateGaussianVertex extends DoubleVertex implements Differentiable, ProbabilisticDouble, SamplableWithManyScalars<DoubleTensor>, LogProbGraphSupplier
Constructor | Description |
---|---|
MultivariateGaussianVertex(double mu, double covariance) | |
MultivariateGaussianVertex(DoubleVertex mu, double covariance) | Matches a mu to a Multivariate Gaussian. |
MultivariateGaussianVertex(DoubleVertex mu, DoubleVertex covariance) | Matches a mu and covariance of some shape to a Multivariate Gaussian. |
MultivariateGaussianVertex(long[] shape, DoubleVertex mu, DoubleVertex covariance) | Multivariate Gaussian distribution. |
Modifier and Type | Method | Description |
---|---|---|
java.util.Map<Vertex,DoubleTensor> | dLogProb(DoubleTensor value, java.util.Set<? extends Vertex> withRespectTo) | The partial derivatives of the natural log prob. |
DoubleVertex | getCovariance() | |
DoubleVertex | getMu() | |
double | logProb(DoubleTensor value) | This is the natural log of the probability at the supplied value. |
LogProbGraph | logProbGraph() | |
DoubleTensor | sampleWithShape(long[] shape, KeanuRandom random) | |
abs, acos, asin, atan, atan2, ceil, concat, cos, div, div, divideBy, divideBy, equalTo, exp, floor, getValue, greaterThan, greaterThanOrEqualTo, lambda, lambda, lessThan, lessThanOrEqualTo, loadValue, log, logGamma, matrixDeterminant, matrixInverse, matrixMultiply, max, min, minus, minus, multiply, multiply, notEqualTo, observe, observe, permute, plus, plus, pow, pow, reshape, reverseDiv, reverseMinus, round, saveValue, setAndCascade, setAndCascade, setValue, setValue, setWithMask, setWithMask, sigmoid, sin, slice, sum, sum, take, tan, times, times, toGreaterThanMask, toGreaterThanMask, toGreaterThanOrEqualToMask, toGreaterThanOrEqualToMask, toInteger, toLessThanMask, toLessThanMask, toLessThanOrEqualToMask, toLessThanOrEqualToMask, unaryMinus
addChild, addParent, addParents, equals, eval, getChildren, getConnectedGraph, getDegree, getId, getIndentation, getLabel, getObservedValue, getParents, getRank, getReference, getShape, getState, getValue, hashCode, hasValue, isDifferentiable, isObserved, isProbabilistic, lazyEval, observe, observeOwnValue, print, print, removeLabel, save, setAndCascade, setLabel, setLabel, setParents, setParents, setState, setValue, toString, unobserve
clone, finalize, getClass, notify, notifyAll, wait, wait, wait
forwardModeAutoDifferentiation, reverseModeAutoDifferentiation, withRespectToSelf
dLogPdf, dLogPdf, dLogPdf, dLogPdf, dLogPdf, dLogPdf, logPdf, logPdf, logPdf
dLogProb, dLogProbAtValue, dLogProbAtValue, getValue, keepOnlyProbabilisticVertices, logProbAtValue
getObservedValue, isObserved, observe, unobserve
sample, sampleManyScalars, sampleManyScalars
sampleWithShape
public MultivariateGaussianVertex(long[] shape, DoubleVertex mu, DoubleVertex covariance)
shape - the desired shape of the vertex
mu - the mu of the Multivariate Gaussian
covariance - the covariance matrix of the Multivariate Gaussian
public MultivariateGaussianVertex(DoubleVertex mu, DoubleVertex covariance)
mu - the mu of the Multivariate Gaussian
covariance - the covariance matrix of the Multivariate Gaussian
public MultivariateGaussianVertex(DoubleVertex mu, double covariance)
mu - the mu of the Multivariate Gaussian
covariance - the scale of the identity matrix
public MultivariateGaussianVertex(double mu, double covariance)
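The `double covariance` overloads describe their argument as the scale of the identity matrix. As a rough illustration of what that means numerically, the plain-Java sketch below expands a scalar into the matrix scale * I. It does not use Keanu; the class and method names are hypothetical, and how the library itself performs this expansion is not stated on this page.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Keanu: expands a scalar covariance into
// the matrix scale * I described by the double-covariance constructors.
final class IsotropicCovarianceSketch {

    // Returns a dimension x dimension matrix with `scale` on the diagonal
    // and zeros everywhere else.
    static double[][] scaledIdentity(int dimension, double scale) {
        if (dimension <= 0) {
            throw new IllegalArgumentException("dimension must be positive");
        }
        if (scale <= 0.0) {
            throw new IllegalArgumentException("scale must be positive for a valid covariance");
        }
        double[][] covariance = new double[dimension][dimension];
        for (int i = 0; i < dimension; i++) {
            covariance[i][i] = scale;
        }
        return covariance;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(scaledIdentity(3, 2.0)));
    }
}
```

With a covariance of this form the components are independent and share one variance, so a single scalar is enough to specify the whole matrix.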
public DoubleVertex getMu()
public DoubleVertex getCovariance()
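public double logProb(DoubleTensor value)
Description copied from interface: Probabilistic
This is the natural log of the probability at the supplied value.
Specified by: logProb in interface Probabilistic<DoubleTensor>
value - The supplied value.
public LogProbGraph logProbGraph()
Specified by: logProbGraph in interface LogProbGraphSupplier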
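For reference, the natural log of a d-dimensional Gaussian density at x is -0.5 * (d * log(2π) + log det Σ + (x - mu)^T Σ^-1 (x - mu)). The sketch below evaluates that textbook formula in plain Java through a Cholesky factorisation. It is not Keanu's implementation and does not use `DoubleTensor`; the class and method names are hypothetical, and the covariance is assumed to be symmetric positive definite.

```java
// Hypothetical illustration, not Keanu code: log density of a multivariate
// Gaussian using plain arrays and a Cholesky factor of the covariance.
final class MultivariateGaussianLogDensitySketch {

    // Lower-triangular Cholesky factor L such that L * L^T equals a.
    static double[][] cholesky(double[][] a) {
        int n = a.length;
        double[][] l = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double sum = a[i][j];
                for (int k = 0; k < j; k++) {
                    sum -= l[i][k] * l[j][k];
                }
                if (i == j) {
                    if (sum <= 0.0) {
                        throw new IllegalArgumentException("covariance must be positive definite");
                    }
                    l[i][i] = Math.sqrt(sum);
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }
        return l;
    }

    // Natural log of the Gaussian density at x for the given mean and covariance.
    static double logDensity(double[] x, double[] mu, double[][] covariance) {
        int n = x.length;
        double[][] l = cholesky(covariance);
        double[] z = new double[n];
        double logDet = 0.0;
        double quadratic = 0.0;
        for (int i = 0; i < n; i++) {
            // Forward substitution solves L * z = (x - mu).
            double sum = x[i] - mu[i];
            for (int k = 0; k < i; k++) {
                sum -= l[i][k] * z[k];
            }
            z[i] = sum / l[i][i];
            quadratic += z[i] * z[i];          // accumulates (x - mu)^T Sigma^-1 (x - mu)
            logDet += 2.0 * Math.log(l[i][i]); // log det Sigma = 2 * sum of log L_ii
        }
        return -0.5 * (n * Math.log(2.0 * Math.PI) + logDet + quadratic);
    }

    public static void main(String[] args) {
        double[][] identity = {{1.0, 0.0}, {0.0, 1.0}};
        System.out.println(logDensity(new double[] {0.0, 0.0}, new double[] {0.0, 0.0}, identity));
    }
}
```

The Cholesky route avoids forming an explicit inverse or determinant, which tends to be the numerically safer choice for this calculation.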
public java.util.Map<Vertex,DoubleTensor> dLogProb(DoubleTensor value, java.util.Set<? extends Vertex> withRespectTo)
Description copied from interface: Probabilistic
The partial derivatives of the natural log prob.
Specified by: dLogProb in interface Probabilistic<DoubleTensor>
value - at a given value
withRespectTo - list of parents to differentiate with respect to
public DoubleTensor sampleWithShape(long[] shape, KeanuRandom random)
Specified by: sampleWithShape in interface SamplableWithShape<DoubleTensor>
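This page does not say how samples are drawn. One common technique, shown here only as a sketch, is to draw a standard normal vector z and return mu + L * z, where L is the Cholesky factor of the covariance. The code uses `java.util.Random` in place of `KeanuRandom` and plain arrays in place of `DoubleTensor`; all class and method names are hypothetical.

```java
import java.util.Random;

// Hypothetical illustration, not Keanu code: sampling a multivariate Gaussian
// by transforming independent standard normal draws.
final class MultivariateGaussianSamplingSketch {

    // Lower-triangular Cholesky factor L such that L * L^T equals a.
    static double[][] cholesky(double[][] a) {
        int n = a.length;
        double[][] l = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= i; j++) {
                double sum = a[i][j];
                for (int k = 0; k < j; k++) {
                    sum -= l[i][k] * l[j][k];
                }
                if (i == j) {
                    if (sum <= 0.0) {
                        throw new IllegalArgumentException("covariance must be positive definite");
                    }
                    l[i][i] = Math.sqrt(sum);
                } else {
                    l[i][j] = sum / l[j][j];
                }
            }
        }
        return l;
    }

    // Maps a standard normal vector z to mu + L * z.
    static double[] affineTransform(double[] mu, double[][] l, double[] z) {
        int n = mu.length;
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            double sum = mu[i];
            for (int k = 0; k <= i; k++) {
                sum += l[i][k] * z[k];
            }
            out[i] = sum;
        }
        return out;
    }

    // Draws one sample with the given mean and covariance.
    static double[] sample(double[] mu, double[][] covariance, Random random) {
        double[] z = new double[mu.length];
        for (int i = 0; i < z.length; i++) {
            z[i] = random.nextGaussian();
        }
        return affineTransform(mu, cholesky(covariance), z);
    }

    public static void main(String[] args) {
        double[] draw = sample(
            new double[] {0.0, 1.0}, new double[][] {{2.0, 1.0}, {1.0, 2.0}}, new Random(42L));
        System.out.println(draw[0] + ", " + draw[1]);
    }
}
```

The sketch recomputes the Cholesky factor on every call for brevity; code that draws many samples would normally factor the covariance once and reuse L.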