public class NeuralNetwork extends java.lang.Object implements OnlineRegression<double[]>, java.io.Serializable

Modifier and Type | Class | Description |
---|---|---|
static class | NeuralNetwork.ActivationFunction | |
static class | NeuralNetwork.Trainer | Trainer for neural networks. |

Constructor | Description |
---|---|
NeuralNetwork(int... numUnits) | Constructor. |
NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits) | Constructor. |
NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits) | Constructor. |
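
The `numUnits` varargs list the layer sizes from input to output. Read that way, `new NeuralNetwork(3, 10, 1)` describes 3 inputs, one hidden layer of 10 units and a single output. The self-contained sketch below is not Smile's implementation. It only shows one common way such a list maps to fully connected weight matrices and a forward pass. Several details are assumptions made for illustration:

- the class name `TinyMlp` and its seed parameter;
- storing the bias as the last column of each matrix;
- using a logistic sigmoid in the hidden layers and a linear output unit.

```java
import java.util.Random;

/** Hypothetical sketch of a fully connected network sized by numUnits; not Smile's code. */
final class TinyMlp {
    // weights[l] maps layer l to layer l + 1 with shape [numUnits[l + 1]][numUnits[l] + 1];
    // keeping the bias in the extra last column is an assumption of this sketch.
    final double[][][] weights;

    TinyMlp(long seed, int... numUnits) {
        if (numUnits.length < 2) {
            throw new IllegalArgumentException("need at least an input and an output layer");
        }
        Random rng = new Random(seed);
        weights = new double[numUnits.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            weights[l] = new double[numUnits[l + 1]][numUnits[l] + 1];
            for (double[] row : weights[l]) {
                for (int k = 0; k < row.length; k++) {
                    row[k] = 0.1 * rng.nextGaussian(); // small random initial weights
                }
            }
        }
    }

    /** Forward pass: sigmoid hidden layers, linear output layer; returns output unit 0. */
    double predict(double[] x) {
        if (x.length != weights[0][0].length - 1) {
            throw new IllegalArgumentException("input size does not match numUnits[0]");
        }
        double[] a = x;
        for (int l = 0; l < weights.length; l++) {
            double[][] w = weights[l];
            boolean isOutput = (l == weights.length - 1);
            double[] next = new double[w.length];
            for (int j = 0; j < w.length; j++) {
                double sum = w[j][a.length]; // bias column
                for (int k = 0; k < a.length; k++) {
                    sum += w[j][k] * a[k];
                }
                next[j] = isOutput ? sum : 1.0 / (1.0 + Math.exp(-sum));
            }
            a = next;
        }
        return a[0];
    }
}
```

Under these assumptions, `new TinyMlp(42, 3, 10, 1)` allocates a 10 x 4 matrix for the hidden layer and a 1 x 11 matrix for the output unit.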

Modifier and Type | Method | Description |
---|---|---|
NeuralNetwork | clone() | |
double | getLearningRate() | Returns the learning rate. |
double | getMomentum() | Returns the momentum factor. |
double[][] | getWeight(int layer) | Returns the weights of a layer. |
double | getWeightDecay() | Returns the weight decay factor. |
void | learn(double[][] x, double[] y) | Trains the neural network with the given dataset for one epoch by stochastic gradient descent. |
void | learn(double[] x, double y) | Updates the regression model online with a new training instance. |
double | learn(double[] x, double y, double weight) | Updates the neural network with the given instance and its associated target value. |
double | predict(double[] x) | Predicts the dependent variable of an instance. |
void | setLearningRate(double eta) | Sets the learning rate. |
void | setMomentum(double alpha) | Sets the momentum factor. |
void | setWeightDecay(double lambda) | Sets the weight decay factor. |
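
The `learn` overloads above follow an online pattern. You can update on one instance, update on one weighted instance, or sweep a whole dataset. The sketch below mirrors the first two on a deliberately simple model, a single linear unit trained by per-instance gradient steps, so that it compiles without Smile. Some choices are assumptions:

- the names `OnlineModel` and `OnlineLinear`;
- scaling the gradient step by the instance weight;
- returning the weighted squared error measured before the update.

The summary itself only says that the weight is associated with the training instance and that the weighted overload returns a `double`.

```java
/** Hypothetical mini-interface echoing the learn/predict calling pattern; not Smile's API. */
interface OnlineModel {
    double predict(double[] x);
    void learn(double[] x, double y);
}

/** A single linear unit updated by stochastic gradient descent on squared error. */
final class OnlineLinear implements OnlineModel {
    final double[] w;   // w[0..d-1] are feature weights, w[d] is the bias
    final double eta;   // learning rate

    OnlineLinear(int d, double eta) {
        this.w = new double[d + 1];
        this.eta = eta;
    }

    @Override
    public double predict(double[] x) {
        double s = w[x.length];
        for (int i = 0; i < x.length; i++) s += w[i] * x[i];
        return s;
    }

    @Override
    public void learn(double[] x, double y) {
        learn(x, y, 1.0); // an unweighted instance is treated as weight 1
    }

    /** Weighted update (assumed semantics); returns the weighted squared error before the step. */
    public double learn(double[] x, double y, double weight) {
        if (weight <= 0) throw new IllegalArgumentException("weight must be positive");
        double err = predict(x) - y;
        // Gradient of 0.5 * weight * err^2 is weight * err * x_i, and weight * err for the bias.
        for (int i = 0; i < x.length; i++) w[i] -= eta * weight * err * x[i];
        w[x.length] -= eta * weight * err;
        return weight * err * err;
    }
}
```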

Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Regression: predict

public NeuralNetwork(int... numUnits)
Constructor.
Parameters:
numUnits - the number of units in each layer.

public NeuralNetwork(NeuralNetwork.ActivationFunction activation, int... numUnits)
Constructor.
Parameters:
activation - the activation function of the output layer.
numUnits - the number of units in each layer.

public NeuralNetwork(NeuralNetwork.ActivationFunction activation, double alpha, double lambda, int... numUnits)
Constructor.
Parameters:
activation - the activation function of the output layer.
alpha - the momentum factor.
lambda - the weight decay for regularization.
numUnits - the number of units in each layer.

public NeuralNetwork clone()
Overrides:
clone in class java.lang.Object
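
`clone()` overrides `Object.clone()` and returns a `NeuralNetwork`. For a model whose state lives in `double[][]` weight matrices, an independent copy must duplicate every row. Calling `clone()` on the outer array alone copies only the row references. This page does not say how deep Smile's copy goes. The snippet below is therefore only a hypothetical illustration of the shallow/deep distinction on a plain weight matrix, and `WeightCopy` is an invented name.

```java
/** Hypothetical helper contrasting shallow and deep copies of a weight matrix. */
final class WeightCopy {
    /** Shallow: the new outer array shares its row arrays with the original. */
    static double[][] shallow(double[][] w) {
        return w.clone();
    }

    /** Deep: every row is cloned, so later updates to one copy leave the other untouched. */
    static double[][] deep(double[][] w) {
        double[][] copy = new double[w.length][];
        for (int i = 0; i < w.length; i++) {
            copy[i] = w[i].clone();
        }
        return copy;
    }
}
```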

public void setLearningRate(double eta)
Sets the learning rate.
Parameters:
eta - the learning rate.

public double getLearningRate()
Returns the learning rate.
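
The learning rate `eta` scales each gradient step. A toy example, independent of Smile, shows why its value matters. Plain gradient descent on f(w) = (w - 3)^2 uses the gradient 2(w - 3), so each step multiplies the distance to the minimum by (1 - 2 * eta). The iteration therefore converges for 0 < eta < 1 and diverges for eta > 1. The objective and the class name `LearningRateDemo` are invented for illustration.

```java
/** Gradient descent on the toy objective f(w) = (w - 3)^2, starting from w = 0. */
final class LearningRateDemo {
    static double minimize(double eta, int steps) {
        double w = 0.0;
        for (int t = 0; t < steps; t++) {
            double grad = 2.0 * (w - 3.0);
            w -= eta * grad; // one step scaled by the learning rate
        }
        return w;
    }
}
```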

public void setMomentum(double alpha)
Sets the momentum factor.
Parameters:
alpha - the momentum factor.

public double getMomentum()
Returns the momentum factor.
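
The momentum factor `alpha` carries a fraction of the previous update into the next one. A common textbook form is classical "heavy ball" momentum:

- the velocity is updated as v <- alpha * v - eta * g;
- the weight is then updated as w <- w + v.

This page does not give Smile's exact update rule, so treat the following as a generic sketch on the toy objective f(w) = (w - 3)^2. The class name `MomentumDemo` is hypothetical.

```java
/** Heavy-ball momentum on f(w) = (w - 3)^2; a generic sketch, not Smile's update rule. */
final class MomentumDemo {
    static double minimize(double eta, double alpha, int steps) {
        double w = 0.0;
        double v = 0.0; // velocity: exponentially weighted sum of past steps
        for (int t = 0; t < steps; t++) {
            double grad = 2.0 * (w - 3.0);
            v = alpha * v - eta * grad; // keep a fraction alpha of the previous step
            w += v;
        }
        return w;
    }
}
```

With `alpha = 0` the loop reduces to plain gradient descent. For example, two steps with `eta = 0.1` and `alpha = 0.5` move w from 0 to 0.6 and then to 1.38.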

public void setWeightDecay(double lambda)
Sets the weight decay factor.
Parameters:
lambda - the weight decay for regularization.

public double getWeightDecay()
Returns the weight decay factor.
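
Weight decay `lambda` regularizes the network by shrinking weights toward zero. In the common L2 form, the penalty (lambda / 2) * ||w||^2 adds lambda * w to the gradient, so a step becomes w <- w - eta * (g + lambda * w). When the data gradient g is zero, each step simply multiplies w by (1 - eta * lambda). This page does not state whether Smile scales the penalty exactly this way. The sketch below is a generic illustration, and `WeightDecayDemo` is a hypothetical name.

```java
/** One SGD step with an L2 weight-decay term; a generic sketch, not Smile's code. */
final class WeightDecayDemo {
    static void step(double[] w, double[] grad, double eta, double lambda) {
        for (int i = 0; i < w.length; i++) {
            // The L2 penalty (lambda / 2) * w_i^2 contributes lambda * w_i to the gradient.
            w[i] -= eta * (grad[i] + lambda * w[i]);
        }
    }
}
```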

public double[][] getWeight(int layer)
Returns the weights of a layer.
Parameters:
layer - the layer of the neural network, 0 for the input layer.

public double predict(double[] x)
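Description copied from interface: Regression
Predicts the dependent variable of an instance.
Specified by:
predict in interface Regression<double[]>
Parameters:
x - the instance.

public double learn(double[] x, double y, double weight)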
Updates the neural network with the given instance and its associated target value.
Parameters:
x - the training instance.
y - the target value.
weight - a positive weight value associated with the training instance.

public void learn(double[] x, double y)
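Description copied from interface: OnlineRegression
Updates the regression model online with a new training instance.
Specified by:
learn in interface OnlineRegression<double[]>
Parameters:
x - the training instance.
y - the response variable.

public void learn(double[][] x, double[] y)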
Trains the neural network with the given dataset for one epoch by stochastic gradient descent.
Parameters:
x - the training instances.
y - the response values of the training instances.
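
`learn(double[][] x, double[] y)` performs one epoch of stochastic gradient descent: one pass over the training set with an update after each instance. Multi-epoch training would therefore call it repeatedly. The sketch below shows that loop structure with a shuffled visiting order on a single linear unit. The class name `EpochTrainer`, the shuffling, the fixed seed and the linear model are all assumptions made to keep the example short and self-contained. They say nothing about Smile's internals.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** One-epoch SGD loop on a linear unit y ~ w . x + b; a sketch, not Smile's trainer. */
final class EpochTrainer {
    final double[] w;  // feature weights
    double b;          // bias
    final double eta;  // learning rate
    private final Random rng = new Random(7); // fixed seed for reproducible shuffles

    EpochTrainer(int d, double eta) {
        this.w = new double[d];
        this.eta = eta;
    }

    double predict(double[] x) {
        double s = b;
        for (int i = 0; i < w.length; i++) s += w[i] * x[i];
        return s;
    }

    /** Visits every instance once in random order, updating after each one. */
    void learnEpoch(double[][] x, double[] y) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < x.length; i++) order.add(i);
        Collections.shuffle(order, rng);
        for (int i : order) {
            double err = predict(x[i]) - y[i]; // gradient of 0.5 * err^2 is err * x
            for (int j = 0; j < w.length; j++) w[j] -= eta * err * x[i][j];
            b -= eta * err;
        }
    }
}
```

On noise-free data such as y = 2x + 1, repeated epochs drive `w[0]` toward 2 and `b` toward 1.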