public abstract class NeuralNetworkOptimizer extends Object implements cc.mallet.optimize.Optimizable.ByGradientValue, Serializable, NeuralNetEpochListener
Nested classes/interfaces inherited from interface `cc.mallet.optimize.Optimizable`: `cc.mallet.optimize.Optimizable.ByBatchGradient`, `cc.mallet.optimize.Optimizable.ByCombiningBatchGradient`, `cc.mallet.optimize.Optimizable.ByGISUpdate`, `cc.mallet.optimize.Optimizable.ByGradient`, `cc.mallet.optimize.Optimizable.ByGradientValue`, `cc.mallet.optimize.Optimizable.ByHessian`, `cc.mallet.optimize.Optimizable.ByValue`, `cc.mallet.optimize.Optimizable.ByVotedPerceptron`
Field Summary

| Modifier and Type | Field and Description |
|---|---|
| `protected List<Double>` | `errors` |
| `protected Object[]` | `extraParams` |
| `protected static org.slf4j.Logger` | `log` |
| `protected double` | `lr` |
| `protected double` | `minLearningRate` |
| `protected BaseNeuralNetwork` | `network` |
| `protected NonZeroStoppingConjugateGradient` | `opt` |
| `protected double` | `tolerance` |
Constructor Summary

| Constructor and Description |
|---|
| `NeuralNetworkOptimizer(BaseNeuralNetwork network, double lr, Object[] trainingParams)` |
Method Summary

| Modifier and Type | Method and Description |
|---|---|
| `void` | `epochDone(int epoch)` |
| `List<Double>` | `getErrors()` |
| `int` | `getNumParameters()` |
| `double` | `getParameter(int index)` |
| `void` | `getParameters(double[] buffer)` |
| `double` | `getValue()` |
| `abstract void` | `getValueGradient(double[] buffer)` |
| `void` | `setParameter(int index, double value)` |
| `void` | `setParameters(double[] params)` |
| `void` | `train(org.jblas.DoubleMatrix x)` |
Field Detail

- `protected BaseNeuralNetwork network`
- `protected double lr`
- `protected Object[] extraParams`
- `protected double tolerance`
- `protected static org.slf4j.Logger log`
- `protected double minLearningRate`
- `protected transient NonZeroStoppingConjugateGradient opt`
Constructor Detail

`public NeuralNetworkOptimizer(BaseNeuralNetwork network, double lr, Object[] trainingParams)`
Method Detail

`public void train(org.jblas.DoubleMatrix x)`
`public void epochDone(int epoch)`

Specified by: `epochDone` in interface `NeuralNetEpochListener`
`public int getNumParameters()`

Specified by: `getNumParameters` in interface `cc.mallet.optimize.Optimizable`
`public void getParameters(double[] buffer)`

Specified by: `getParameters` in interface `cc.mallet.optimize.Optimizable`
`public double getParameter(int index)`

Specified by: `getParameter` in interface `cc.mallet.optimize.Optimizable`
`public void setParameters(double[] params)`

Specified by: `setParameters` in interface `cc.mallet.optimize.Optimizable`
`public void setParameter(int index, double value)`

Specified by: `setParameter` in interface `cc.mallet.optimize.Optimizable`
`public abstract void getValueGradient(double[] buffer)`

Specified by: `getValueGradient` in interface `cc.mallet.optimize.Optimizable.ByGradientValue`
`public double getValue()`

Specified by: `getValue` in interface `cc.mallet.optimize.Optimizable.ByGradientValue`
Copyright © 2014. All Rights Reserved.