public abstract class NeuralNetworkOptimizer extends Object implements cc.mallet.optimize.Optimizable.ByGradientValue, OptimizableByGradientValueMatrix, Serializable, NeuralNetEpochListener
Nested classes/interfaces inherited from interface cc.mallet.optimize.Optimizable:
cc.mallet.optimize.Optimizable.ByBatchGradient, cc.mallet.optimize.Optimizable.ByCombiningBatchGradient, cc.mallet.optimize.Optimizable.ByGISUpdate, cc.mallet.optimize.Optimizable.ByGradient, cc.mallet.optimize.Optimizable.ByGradientValue, cc.mallet.optimize.Optimizable.ByHessian, cc.mallet.optimize.Optimizable.ByValue, cc.mallet.optimize.Optimizable.ByVotedPerceptron
| Modifier and Type | Field and Description |
|---|---|
| `protected List<Double>` | `errors` |
| `protected Object[]` | `extraParams` |
| `protected static org.slf4j.Logger` | `log` |
| `protected NeuralNetwork.LossFunction` | `lossFunction` |
| `protected double` | `lr` |
| `protected double` | `minLearningRate` |
| `protected NeuralNetwork` | `network` |
| `protected OptimizerMatrix` | `opt` |
| `protected NeuralNetwork.OptimizationAlgorithm` | `optimizationAlgorithm` |
| `protected double` | `tolerance` |
| Constructor and Description |
|---|
| `NeuralNetworkOptimizer(NeuralNetwork network, double lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, NeuralNetwork.LossFunction lossFunction)` |
| Modifier and Type | Method and Description |
|---|---|
| `void` | `epochDone(int epoch)` Event listener for each iteration |
| `List<Double>` | `getErrors()` |
| `int` | `getNumParameters()` |
| `double` | `getParameter(int index)` |
| `org.jblas.DoubleMatrix` | `getParameters()` |
| `void` | `getParameters(double[] buffer)` |
| `double` | `getTolerance()` |
| `double` | `getValue()` |
| `org.jblas.DoubleMatrix` | `getValueGradient()` |
| `abstract void` | `getValueGradient(double[] buffer)` |
| `void` | `setParameter(int index, double value)` |
| `void` | `setParameters(double[] params)` |
| `void` | `setParameters(org.jblas.DoubleMatrix params)` |
| `void` | `setTolerance(double tolerance)` |
| `void` | `train(org.jblas.DoubleMatrix x)` |
protected NeuralNetwork network
protected double lr
protected Object[] extraParams
protected double tolerance
protected static org.slf4j.Logger log
protected double minLearningRate
protected transient OptimizerMatrix opt
protected NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm
protected NeuralNetwork.LossFunction lossFunction
public NeuralNetworkOptimizer(NeuralNetwork network, double lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, NeuralNetwork.LossFunction lossFunction)
Parameters:
network -
lr -
trainingParams -
public void train(org.jblas.DoubleMatrix x)
public void epochDone(int epoch)
Description copied from interface: NeuralNetEpochListener
Event listener for each iteration
Specified by: epochDone in interface NeuralNetEpochListener
public int getNumParameters()
Specified by: getNumParameters in interface cc.mallet.optimize.Optimizable
Specified by: getNumParameters in interface OptimizableByGradientValueMatrix
public void getParameters(double[] buffer)
Specified by: getParameters in interface cc.mallet.optimize.Optimizable
public double getParameter(int index)
Specified by: getParameter in interface cc.mallet.optimize.Optimizable
Specified by: getParameter in interface OptimizableByGradientValueMatrix
public void setParameters(double[] params)
Specified by: setParameters in interface cc.mallet.optimize.Optimizable
public void setParameter(int index, double value)
Specified by: setParameter in interface cc.mallet.optimize.Optimizable
Specified by: setParameter in interface OptimizableByGradientValueMatrix
public org.jblas.DoubleMatrix getParameters()
Specified by: getParameters in interface OptimizableByGradientValueMatrix
public void setParameters(org.jblas.DoubleMatrix params)
Specified by: setParameters in interface OptimizableByGradientValueMatrix
public org.jblas.DoubleMatrix getValueGradient()
Specified by: getValueGradient in interface OptimizableByGradientValueMatrix
public abstract void getValueGradient(double[] buffer)
Specified by: getValueGradient in interface cc.mallet.optimize.Optimizable.ByGradientValue
public double getValue()
Specified by: getValue in interface cc.mallet.optimize.Optimizable.ByGradientValue
Specified by: getValue in interface OptimizableByGradientValueMatrix
public double getTolerance()
public void setTolerance(double tolerance)
Copyright © 2014. All Rights Reserved.