public class LossFunctionWrapper extends Object implements ReconstructionDistribution
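LossFunctionWrapper allows training of a VAE model with a standard (possibly deterministic) neural network loss function for the reconstruction, instead of using a ReconstructionDistribution as would normally be done with a VAE model.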
Note: most functionality is supported, but the reconstruction log probability cannot be calculated when using LossFunctionWrapper, because ILossFunction instances have neither (a) a probabilistic interpretation nor (b) a means of calculating the negative log probability.
Constructor and Description |
---|
LossFunctionWrapper(Activation activation, ILossFunction lossFunction) |
LossFunctionWrapper(IActivation activationFn, ILossFunction lossFunction) |
Modifier and Type | Method and Description |
---|---|
int | distributionInputSize(int dataSize): Get the number of distribution parameters for the given input data size. |
INDArray | exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams): Calculate the negative log probability for each example individually. |
INDArray | generateAtMean(INDArray preOutDistributionParams): Generate a sample from P(x\|z), where x = E[P(x\|z)], i.e., return the mean value for the distribution. |
INDArray | generateRandom(INDArray preOutDistributionParams): Randomly sample from P(x\|z) using the specified distribution parameters. |
INDArray | gradient(INDArray x, INDArray preOutDistributionParams): Calculate the gradient of the negative log probability with respect to the preOutDistributionParams. |
boolean | hasLossFunction(): Does this reconstruction distribution have a standard neural network loss function (such as mean squared error, which is deterministic), or is it a standard VAE with a probabilistic reconstruction distribution? |
double | negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average): Calculate the negative log probability (summed or averaged over each example in the minibatch). |
String | toString() |
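The relationship between the per-example score and the minibatch score (summed or averaged) can be illustrated with a minimal plain-Java sketch. This is not the DL4J implementation: the class `MinibatchScoreSketch`, its methods, and the choice of mean squared error on plain `double` arrays are all assumptions made for illustration, and the reconstruction is taken as already activated.

```java
import java.util.Arrays;

/** Hypothetical illustration only; not part of DL4J. */
public class MinibatchScoreSketch {

    /** One score per example: mean squared error over that example's features. */
    public static double[] exampleScores(double[][] x, double[][] reconstruction) {
        double[] scores = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < x[i].length; j++) {
                double diff = x[i][j] - reconstruction[i][j];
                sum += diff * diff;
            }
            scores[i] = sum / x[i].length;
        }
        return scores;
    }

    /** Minibatch score: the per-example scores, summed or averaged over examples. */
    public static double minibatchScore(double[][] x, double[][] reconstruction, boolean average) {
        double total = Arrays.stream(exampleScores(x, reconstruction)).sum();
        return average ? total / x.length : total;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 2.0}, {3.0, 4.0}};
        double[][] reconstruction = {{1.0, 0.0}, {2.0, 4.0}};
        System.out.println(Arrays.toString(exampleScores(x, reconstruction))); // [2.0, 0.5]
        System.out.println(minibatchScore(x, reconstruction, false));          // 2.5
        System.out.println(minibatchScore(x, reconstruction, true));           // 1.25
    }
}
```

In this sketch the `average` flag only changes the final reduction, which mirrors how the summary above describes negLogProbability relative to exampleNegLogProbability.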
public LossFunctionWrapper(IActivation activationFn, ILossFunction lossFunction)
public LossFunctionWrapper(Activation activation, ILossFunction lossFunction)
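Both constructors pair an activation function with a loss function. A minimal plain-Java sketch of that pairing follows, assuming the activation is applied to the pre-output values before the loss scores the result against the data. Everything here is hypothetical and for illustration only: `ActivationLossPairSketch`, its nested `Loss` interface, and the sigmoid plus mean absolute error combination are not DL4J types or defaults.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical plain-Java stand-in; not the DL4J LossFunctionWrapper. */
public class ActivationLossPairSketch {

    /** Hypothetical scalar loss over one example (stands in for an ILossFunction). */
    public interface Loss {
        double score(double[] labels, double[] output);
    }

    private final DoubleUnaryOperator activation;
    private final Loss loss;

    public ActivationLossPairSketch(DoubleUnaryOperator activation, Loss loss) {
        this.activation = activation;
        this.loss = loss;
    }

    /** Example pairing: sigmoid activation with mean absolute error. */
    public static ActivationLossPairSketch sigmoidWithAbsoluteError() {
        DoubleUnaryOperator sigmoid = v -> 1.0 / (1.0 + Math.exp(-v));
        Loss meanAbsoluteError = (labels, output) -> {
            double sum = 0.0;
            for (int i = 0; i < labels.length; i++) {
                sum += Math.abs(labels[i] - output[i]);
            }
            return sum / labels.length;
        };
        return new ActivationLossPairSketch(sigmoid, meanAbsoluteError);
    }

    /** Apply the activation element-wise to the pre-output values. */
    public double[] activate(double[] preOut) {
        double[] out = new double[preOut.length];
        for (int i = 0; i < preOut.length; i++) {
            out[i] = activation.applyAsDouble(preOut[i]);
        }
        return out;
    }

    /** Score one example: activate the pre-output, then compare with the data. */
    public double score(double[] x, double[] preOut) {
        return loss.score(x, activate(preOut));
    }

    public static void main(String[] args) {
        ActivationLossPairSketch wrapper = sigmoidWithAbsoluteError();
        System.out.println(wrapper.activate(new double[] {0.0})[0]);               // 0.5
        System.out.println(wrapper.score(new double[] {1.0}, new double[] {0.0})); // 0.5
    }
}
```

Because the loss is deterministic, the activated pre-output is the only reconstruction the sketch can produce; there are no distribution parameters such as a variance to sample from.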
public boolean hasLossFunction()
Specified by: hasLossFunction in interface ReconstructionDistribution
public int distributionInputSize(int dataSize)
Specified by: distributionInputSize in interface ReconstructionDistribution
Parameters:
dataSize - Size of the data, i.e., the nIn value

public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average)
Specified by: negLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions)
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian)
average - Whether the log probability should be averaged over the minibatch, or simply summed.

public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams)
Specified by: exampleNegLogProbability in interface ReconstructionDistribution
Parameters:
x - Data to be modelled (reconstructions)
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public INDArray gradient(INDArray x, INDArray preOutDistributionParams)
Specified by: gradient in interface ReconstructionDistribution
Parameters:
x - Data
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public INDArray generateRandom(INDArray preOutDistributionParams)
Specified by: generateRandom in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

public INDArray generateAtMean(INDArray preOutDistributionParams)
Specified by: generateAtMean in interface ReconstructionDistribution
Parameters:
preOutDistributionParams - Distribution parameters used by this reconstruction distribution (for example, mean and log variance values for Gaussian), before applying the activation function

Copyright © 2019. All rights reserved.