public abstract class BaseNeuralNetwork extends Object implements NeuralNetwork, Persistable
DBN
Modifier and Type | Class and Description |
---|---|
static class |
BaseNeuralNetwork.Builder<E extends BaseNeuralNetwork> |
Modifier and Type | Field and Description |
---|---|
org.apache.commons.math3.distribution.RealDistribution |
dist |
double |
fanIn |
org.jblas.DoubleMatrix |
hBias |
org.jblas.DoubleMatrix |
input |
double |
l2 |
double |
momentum |
int |
nHidden
Number of hidden units
Tip: having more hidden units than inputs (here, input rows)
will typically cause severe overfitting.
|
int |
nVisible |
NeuralNetworkOptimizer |
optimizer |
int |
renderWeightsEveryNumEpochs |
org.apache.commons.math3.random.RandomGenerator |
rng |
double |
sparsity |
boolean |
useRegularization |
org.jblas.DoubleMatrix |
vBias |
org.jblas.DoubleMatrix |
W |
Constructor and Description |
---|
BaseNeuralNetwork() |
BaseNeuralNetwork(org.jblas.DoubleMatrix input,
int nVisible,
int nHidden,
org.jblas.DoubleMatrix W,
org.jblas.DoubleMatrix hbias,
org.jblas.DoubleMatrix vbias,
org.apache.commons.math3.random.RandomGenerator rng,
double fanIn,
org.apache.commons.math3.distribution.RealDistribution dist) |
BaseNeuralNetwork(int nVisible,
int nHidden,
org.jblas.DoubleMatrix W,
org.jblas.DoubleMatrix hbias,
org.jblas.DoubleMatrix vbias,
org.apache.commons.math3.random.RandomGenerator rng,
double fanIn,
org.apache.commons.math3.distribution.RealDistribution dist) |
Modifier and Type | Method and Description |
---|---|
NeuralNetwork |
clone() |
double |
fanIn() |
org.apache.commons.math3.distribution.RealDistribution |
getDist() |
org.jblas.DoubleMatrix |
gethBias() |
org.jblas.DoubleMatrix |
getInput() |
double |
getL2() |
double |
getMomentum() |
int |
getnHidden() |
int |
getnVisible() |
double |
getReConstructionCrossEntropy()
Reconstruction entropy.
|
int |
getRenderEpochs() |
org.apache.commons.math3.random.RandomGenerator |
getRng() |
double |
getSparsity() |
org.jblas.DoubleMatrix |
getvBias() |
org.jblas.DoubleMatrix |
getW() |
protected void |
initWeights()
Initialize weights.
|
void |
jostleWeighMatrix() |
double |
l2RegularizedCoefficient() |
void |
load(InputStream is)
Load (using
ObjectInputStream) |
double |
lossFunction() |
abstract double |
lossFunction(Object[] params)
The loss function (cross entropy, reconstruction error,...)
|
void |
merge(NeuralNetwork network,
int batchSize)
Performs a network merge of the form
a += (b - a) / n
where a is a matrix on this network,
b is the corresponding matrix on the incoming network,
and n is the batch size
|
abstract org.jblas.DoubleMatrix |
reconstruct(org.jblas.DoubleMatrix x)
All neural networks are based on this idea of
minimizing reconstruction error.
|
void |
setDist(org.apache.commons.math3.distribution.RealDistribution dist) |
void |
setFanIn(double fanIn) |
void |
sethBias(org.jblas.DoubleMatrix hBias) |
void |
setInput(org.jblas.DoubleMatrix input) |
void |
setL2(double l2) |
void |
setMomentum(double momentum) |
void |
setnHidden(int nHidden) |
void |
setnVisible(int nVisible) |
void |
setRenderEpochs(int renderEpochs) |
void |
setRng(org.apache.commons.math3.random.RandomGenerator rng) |
void |
setSparsity(double sparsity) |
void |
setvBias(org.jblas.DoubleMatrix vBias) |
void |
setW(org.jblas.DoubleMatrix w) |
double |
squaredLoss() |
abstract void |
train(org.jblas.DoubleMatrix input,
double lr,
Object[] params)
Train one iteration of the network
|
NeuralNetwork |
transpose() |
void |
update(BaseNeuralNetwork n)
Copies parameters from the passed-in network
to this one
|
void |
write(OutputStream os)
Write this to an object output stream
|
equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
getGradient, trainTillConvergence
public int nVisible
public int nHidden
public org.jblas.DoubleMatrix W
public org.jblas.DoubleMatrix hBias
public org.jblas.DoubleMatrix vBias
public org.apache.commons.math3.random.RandomGenerator rng
public org.jblas.DoubleMatrix input
public double sparsity
public double momentum
public transient org.apache.commons.math3.distribution.RealDistribution dist
public double l2
public transient NeuralNetworkOptimizer optimizer
public int renderWeightsEveryNumEpochs
public double fanIn
public boolean useRegularization
public BaseNeuralNetwork()
public BaseNeuralNetwork(int nVisible, int nHidden, org.jblas.DoubleMatrix W, org.jblas.DoubleMatrix hbias, org.jblas.DoubleMatrix vbias, org.apache.commons.math3.random.RandomGenerator rng, double fanIn, org.apache.commons.math3.distribution.RealDistribution dist)
nVisible
- the number of outbound nodes
nHidden
- the number of nodes in the hidden layer
W
- the weights for this vector; may be null, in which case a matrix
with nHidden x nVisible dimensions will be created.
hbias
- the hidden bias
vbias
- the visible bias (usually b for the output layer)
rng
- the rng; if none is provided, a seed of 1234 is used.
public BaseNeuralNetwork(org.jblas.DoubleMatrix input, int nVisible, int nHidden, org.jblas.DoubleMatrix W, org.jblas.DoubleMatrix hbias, org.jblas.DoubleMatrix vbias, org.apache.commons.math3.random.RandomGenerator rng, double fanIn, org.apache.commons.math3.distribution.RealDistribution dist)
input
- the input examples
nVisible
- the number of outbound nodes
nHidden
- the number of nodes in the hidden layer
W
- the weights for this vector; may be null, in which case a matrix
with nHidden x nVisible dimensions will be created.
hbias
- the hidden bias
vbias
- the visible bias (usually b for the output layer)
rng
- the rng; if none is provided, a seed of 1234 is used.
public double l2RegularizedCoefficient()
l2RegularizedCoefficient
in interface NeuralNetwork
protected void initWeights()
public void setRenderEpochs(int renderEpochs)
setRenderEpochs
in interface NeuralNetwork
public int getRenderEpochs()
getRenderEpochs
in interface NeuralNetwork
public double fanIn()
fanIn
in interface NeuralNetwork
public void setFanIn(double fanIn)
setFanIn
in interface NeuralNetwork
public void jostleWeighMatrix()
public NeuralNetwork transpose()
transpose
in interface NeuralNetwork
public NeuralNetwork clone()
clone
in interface NeuralNetwork
clone
in class Object
public org.apache.commons.math3.distribution.RealDistribution getDist()
getDist
in interface NeuralNetwork
public void setDist(org.apache.commons.math3.distribution.RealDistribution dist)
setDist
in interface NeuralNetwork
public void merge(NeuralNetwork network, int batchSize)
NeuralNetwork
merge
in interface NeuralNetwork
network
- the network to merge with
batchSize
- the batch size (number of training examples)
to average by
public void update(BaseNeuralNetwork n)
n
- the network to copy
public void load(InputStream is)
ObjectInputStream
load
in interface Persistable
is
- the input stream to load from (usually a file)
public double getReConstructionCrossEntropy()
getReConstructionCrossEntropy
in interface NeuralNetwork
public int getnVisible()
getnVisible
in interface NeuralNetwork
public void setnVisible(int nVisible)
setnVisible
in interface NeuralNetwork
public int getnHidden()
getnHidden
in interface NeuralNetwork
public void setnHidden(int nHidden)
setnHidden
in interface NeuralNetwork
public org.jblas.DoubleMatrix getW()
getW
in interface NeuralNetwork
public void setW(org.jblas.DoubleMatrix w)
setW
in interface NeuralNetwork
public org.jblas.DoubleMatrix gethBias()
gethBias
in interface NeuralNetwork
public void sethBias(org.jblas.DoubleMatrix hBias)
sethBias
in interface NeuralNetwork
public org.jblas.DoubleMatrix getvBias()
getvBias
in interface NeuralNetwork
public void setvBias(org.jblas.DoubleMatrix vBias)
setvBias
in interface NeuralNetwork
public org.apache.commons.math3.random.RandomGenerator getRng()
getRng
in interface NeuralNetwork
public void setRng(org.apache.commons.math3.random.RandomGenerator rng)
setRng
in interface NeuralNetwork
public org.jblas.DoubleMatrix getInput()
getInput
in interface NeuralNetwork
public void setInput(org.jblas.DoubleMatrix input)
setInput
in interface NeuralNetwork
public double getSparsity()
getSparsity
in interface NeuralNetwork
public void setSparsity(double sparsity)
setSparsity
in interface NeuralNetwork
public double getMomentum()
getMomentum
in interface NeuralNetwork
public void setMomentum(double momentum)
setMomentum
in interface NeuralNetwork
public double getL2()
getL2
in interface NeuralNetwork
public void setL2(double l2)
setL2
in interface NeuralNetwork
public void write(OutputStream os)
write
in interface Persistable
os
- the output stream to write to
public abstract org.jblas.DoubleMatrix reconstruct(org.jblas.DoubleMatrix x)
x
- the input to reconstruct
public abstract double lossFunction(Object[] params)
public double lossFunction()
public abstract void train(org.jblas.DoubleMatrix input, double lr, Object[] params)
train
in interface NeuralNetwork
input
- the input to train on
lr
- the learning rate to train at
params
- the extra params (k, corruption level, ...)
public double squaredLoss()
squaredLoss
in interface NeuralNetwork
Copyright © 2014. All Rights Reserved.