public class ComputationGraph extends Object implements Serializable, Model
Modifier and Type | Field and Description |
---|---|
protected ComputationGraphConfiguration | configuration |
protected org.nd4j.linalg.api.ndarray.INDArray | flattenedGradients |
protected org.nd4j.linalg.api.ndarray.INDArray | flattenedParams |
protected Gradient | gradient |
protected boolean | initCalled |
protected Layer[] | layers: A list of layers. |
protected double | score |
protected Solver | solver |
protected int[] | topologicalOrder: Indexes of graph vertices, in topological order. |
protected GraphVertex[] | vertices: All GraphVertex objects in the network. |
protected Map<String,GraphVertex> | verticesMap: Map of vertices by name. |
Constructor and Description |
---|
ComputationGraph(ComputationGraphConfiguration configuration) |
Modifier and Type | Method and Description |
---|---|
void | accumulateScore(double accum): Sets a rolling tally for the score. |
void | applyLearningRateScoreDecay(): Update the learning rate for this model using score-based decay. |
Gradient | backpropGradient(org.nd4j.linalg.api.ndarray.INDArray... epsilons): Calculate the gradient of the network with respect to some external errors. |
int | batchSize(): The batch size of the current input(s). |
protected void | calcBackpropGradients(boolean truncatedBPTT, org.nd4j.linalg.api.ndarray.INDArray... externalEpsilons): Do backprop (gradient calculation). |
double | calcL1(): Calculate the L1 regularization term for all layers in the entire network. |
double | calcL2(): Calculate the L2 regularization term for all layers in the entire network. |
void | clear(): Clear input. |
void | clearLayerMaskArrays(): Remove the mask arrays from all layers. See setLayerMaskArrays(INDArray[], INDArray[]) for details on mask arrays. |
ComputationGraph | clone() |
void | computeGradientAndScore(): Compute the gradient and update the score. |
NeuralNetConfiguration | conf(): The configuration for the neural network. |
protected void | doTruncatedBPTT(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] featureMasks, org.nd4j.linalg.api.ndarray.INDArray[] labelMasks): Fit the network using truncated BPTT. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | feedForward(): Conduct forward pass using the stored inputs, at test time. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | feedForward(boolean train): Conduct forward pass using the stored inputs. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | feedForward(org.nd4j.linalg.api.ndarray.INDArray[] input, boolean train): Conduct forward pass using an array of inputs. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | feedForward(org.nd4j.linalg.api.ndarray.INDArray input, boolean train): Conduct forward pass using a single input array. |
void | fit(): All models have a fit method. |
void | fit(org.nd4j.linalg.dataset.api.DataSet dataSet): Fit the ComputationGraph using a DataSet. |
void | fit(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator): Fit the ComputationGraph using a DataSetIterator. |
void | fit(org.nd4j.linalg.api.ndarray.INDArray data): Fit the model to the given data. |
void | fit(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels): Fit the ComputationGraph given arrays of inputs and labels. |
void | fit(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] featureMaskArrays, org.nd4j.linalg.api.ndarray.INDArray[] labelMaskArrays): Fit the ComputationGraph using the specified inputs and labels (and mask arrays). |
void | fit(org.nd4j.linalg.dataset.api.MultiDataSet multiDataSet): Fit the ComputationGraph using a MultiDataSet. |
void | fit(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator multi): Fit the ComputationGraph using a MultiDataSetIterator. |
ComputationGraphConfiguration | getConfiguration() |
org.nd4j.linalg.api.ndarray.INDArray | getInput(int inputNum): Get the previously set input for the ComputationGraph. |
org.nd4j.linalg.api.ndarray.INDArray[] | getInputMaskArrays(): Get the previously set feature/input mask arrays for the ComputationGraph. |
org.nd4j.linalg.api.ndarray.INDArray[] | getInputs(): Get the previously set inputs for the ComputationGraph. |
org.nd4j.linalg.api.ndarray.INDArray[] | getLabelMaskArrays(): Get the previously set label/output mask arrays for the ComputationGraph. |
Layer | getLayer(int idx): Get the layer by the number of that layer, in range 0 to getNumLayers()-1. NOTE: This is different from the internal GraphVertex index for the layer. |
Layer | getLayer(String name): Get a given layer by name. |
Layer[] | getLayers(): Get all layers in the ComputationGraph. |
Collection<IterationListener> | getListeners(): Get the IterationListeners for the ComputationGraph. |
int | getNumInputArrays(): The number of inputs to this network. |
int | getNumLayers(): Returns the number of layers in the ComputationGraph. |
int | getNumOutputArrays(): The number of output arrays for this network. |
ConvexOptimizer | getOptimizer(): Returns this model's optimizer. |
Layer | getOutputLayer(int outputLayerIdx): Get the specified output layer, by index. |
org.nd4j.linalg.api.ndarray.INDArray | getParam(String param): Get the parameter. |
ComputationGraphUpdater | getUpdater(): Get the ComputationGraphUpdater for the network. |
GraphVertex | getVertex(String name): Return a given GraphVertex by name, or null if no vertex with that name exists. |
GraphVertex[] | getVertices(): Returns an array of all GraphVertex objects. |
Gradient | gradient(): Calculate a gradient. |
Pair<Gradient,Double> | gradientAndScore(): Get the gradient and score. |
void | init(): Initialize the ComputationGraph network. |
void | init(org.nd4j.linalg.api.ndarray.INDArray parameters, boolean cloneParametersArray): Initialize the ComputationGraph, optionally with an existing parameters array. |
void | initGradientsView(): Initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers. |
void | initParams(): Initialize the parameters. |
org.nd4j.linalg.api.ndarray.INDArray | input(): The input/feature matrix for the model. |
void | iterate(org.nd4j.linalg.api.ndarray.INDArray input): Run one iteration. |
int | numParams(): The number of parameters for the model. |
int | numParams(boolean backwards): The number of parameters for the model. |
org.nd4j.linalg.api.ndarray.INDArray[] | output(boolean train, org.nd4j.linalg.api.ndarray.INDArray... input): Return an array of network outputs (predictions), given the specified network inputs. Network outputs are for output layers only. |
org.nd4j.linalg.api.ndarray.INDArray[] | output(org.nd4j.linalg.api.ndarray.INDArray... input): Return an array of network outputs (predictions) at test time, given the specified network inputs. Network outputs are for output layers only. |
org.nd4j.linalg.api.ndarray.INDArray | outputSingle(boolean train, org.nd4j.linalg.api.ndarray.INDArray... input): A convenience method that returns a single INDArray, instead of an INDArray[]. |
org.nd4j.linalg.api.ndarray.INDArray | outputSingle(org.nd4j.linalg.api.ndarray.INDArray... input): A convenience method that returns a single INDArray, instead of an INDArray[]. |
org.nd4j.linalg.api.ndarray.INDArray | params(): Parameters of the model (if any). |
org.nd4j.linalg.api.ndarray.INDArray | params(boolean backwardOnly): Get the parameters for the ComputationGraph. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | paramTable(): The param table. |
void | pretrain(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter): Pretrain network with a single input and single output. |
void | pretrain(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iter): Pretrain network with multiple inputs and/or outputs. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnActivateUsingStoredState(org.nd4j.linalg.api.ndarray.INDArray[] inputs, boolean training, boolean storeLastForTBPTT): Similar to rnnTimeStep and feedForward() methods. |
void | rnnClearPreviousState(): Clear the previous state of the RNN layers (if any), used in rnnTimeStep(INDArray...). |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(int layer): Get the state of the RNN layer, as used in rnnTimeStep(INDArray...). |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(String layerName): Get the state of the RNN layer, as used in rnnTimeStep(INDArray...). |
Map<String,Map<String,org.nd4j.linalg.api.ndarray.INDArray>> | rnnGetPreviousStates(): Get a map of states for ALL RNN layers, as used in rnnTimeStep(INDArray...). |
void | rnnSetPreviousState(int layer, Map<String,org.nd4j.linalg.api.ndarray.INDArray> state): Set the state of the RNN layer, for use in rnnTimeStep(INDArray...). |
void | rnnSetPreviousState(String layerName, Map<String,org.nd4j.linalg.api.ndarray.INDArray> state): Set the state of the RNN layer, for use in rnnTimeStep(INDArray...). |
void | rnnSetPreviousStates(Map<String,Map<String,org.nd4j.linalg.api.ndarray.INDArray>> previousStates): Set the states for all RNN layers, for use in rnnTimeStep(INDArray...). |
org.nd4j.linalg.api.ndarray.INDArray[] | rnnTimeStep(org.nd4j.linalg.api.ndarray.INDArray... inputs): If this ComputationGraph contains one or more RNN layers: conduct forward pass (prediction) but using previous stored state for any RNN layers. |
protected void | rnnUpdateStateWithTBPTTState(): Update the internal state of RNN layers after a truncated BPTT fit call. |
double | score(): The score for the model. |
double | score(org.nd4j.linalg.dataset.api.DataSet dataSet): Sets the input and labels and returns a score for the prediction with respect to the true labels. This is equivalent to score(DataSet, boolean) with training==true. NOTE: this version of the score function can only be used with ComputationGraph networks that have a single input and a single output. |
double | score(org.nd4j.linalg.dataset.api.DataSet dataSet, boolean training): Sets the input and labels and returns a score for the prediction with respect to the true labels. NOTE: this version of the score function can only be used with ComputationGraph networks that have a single input and a single output. |
double | score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet): Score the network given the MultiDataSet, at test time. |
double | score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, boolean training): Sets the input and labels and returns a score for the prediction with respect to the true labels. |
org.nd4j.linalg.api.ndarray.INDArray | scoreExamples(org.nd4j.linalg.dataset.api.DataSet data, boolean addRegularizationTerms): Calculate the score for each example in a DataSet individually. |
org.nd4j.linalg.api.ndarray.INDArray | scoreExamples(org.nd4j.linalg.dataset.api.MultiDataSet data, boolean addRegularizationTerms): Calculate the score for each example in a MultiDataSet individually. |
void | setBackpropGradientsViewArray(org.nd4j.linalg.api.ndarray.INDArray gradient): Set the gradients array as a view of the full (backprop) network gradients array. NOTE: this is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. |
void | setConf(NeuralNetConfiguration conf): Setter for the configuration. |
void | setInput(int inputNum, org.nd4j.linalg.api.ndarray.INDArray input): Set the specified input for the ComputationGraph. |
void | setInputs(org.nd4j.linalg.api.ndarray.INDArray... inputs): Set all inputs for the ComputationGraph network. |
void | setLabel(int labelNum, org.nd4j.linalg.api.ndarray.INDArray label): Set the specified label for the ComputationGraph. |
void | setLabels(org.nd4j.linalg.api.ndarray.INDArray... labels): Set all labels for the ComputationGraph network. |
void | setLayerMaskArrays(org.nd4j.linalg.api.ndarray.INDArray[] featureMaskArrays, org.nd4j.linalg.api.ndarray.INDArray[] labelMaskArrays): Set the mask arrays for features and labels. |
void | setListeners(Collection<IterationListener> listeners): Set the IterationListeners for the ComputationGraph (and all layers in the network). |
void | setListeners(IterationListener... listeners): Set the IterationListeners for the ComputationGraph (and all layers in the network). |
void | setParam(String key, org.nd4j.linalg.api.ndarray.INDArray val): Set the parameter with a new ndarray. |
void | setParams(org.nd4j.linalg.api.ndarray.INDArray params): Set the parameters for this model. |
void | setParamsViewArray(org.nd4j.linalg.api.ndarray.INDArray gradient): Set the initial parameters array as a view of the full (backprop) network parameters. NOTE: this is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. |
void | setParamTable(Map<String,org.nd4j.linalg.api.ndarray.INDArray> paramTable): Setter for the param table. |
void | setScore(double score) |
void | setUpdater(ComputationGraphUpdater updater): Set the ComputationGraphUpdater for the network. |
int[] | topologicalSortOrder(): Calculate a topological sort order for the vertices in the graph. |
void | update(Gradient gradient): Update layer weights and biases with gradient change. |
void | update(org.nd4j.linalg.api.ndarray.INDArray gradient, String paramType): Perform one update applying the gradient. |
void | validateInput(): Validate the input. |
protected ComputationGraphConfiguration configuration
protected boolean initCalled
protected transient Solver solver
protected org.nd4j.linalg.api.ndarray.INDArray flattenedParams
protected transient org.nd4j.linalg.api.ndarray.INDArray flattenedGradients
protected Gradient gradient
protected double score
protected GraphVertex[] vertices
protected Map<String,GraphVertex> verticesMap
protected int[] topologicalOrder
protected Layer[] layers
public ComputationGraph(ComputationGraphConfiguration configuration)
public ComputationGraphConfiguration getConfiguration()
public int getNumLayers()
public Layer getLayer(int idx)
public Layer[] getLayers()
public GraphVertex[] getVertices()
public GraphVertex getVertex(String name)
public int getNumInputArrays()
public int getNumOutputArrays()
public void setInput(int inputNum, org.nd4j.linalg.api.ndarray.INDArray input)
public void setInputs(org.nd4j.linalg.api.ndarray.INDArray... inputs)
public org.nd4j.linalg.api.ndarray.INDArray getInput(int inputNum)
public org.nd4j.linalg.api.ndarray.INDArray[] getInputs()
public org.nd4j.linalg.api.ndarray.INDArray[] getInputMaskArrays()
public org.nd4j.linalg.api.ndarray.INDArray[] getLabelMaskArrays()
public void setLabel(int labelNum, org.nd4j.linalg.api.ndarray.INDArray label)
public void setLabels(org.nd4j.linalg.api.ndarray.INDArray... labels)
public void init()
public void init(org.nd4j.linalg.api.ndarray.INDArray parameters, boolean cloneParametersArray)
parameters
- Network parameters. May be null; if null, the parameters are randomly initialized.
cloneParametersArray
- Whether the parameter array (if any) should be cloned, or used directly
public void initGradientsView()
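init(INDArray, boolean) and initGradientsView() keep all parameters (and, separately, all gradients) in one flattened array. Each layer works on its own subset of that array, as the field list and the setParamsViewArray/setBackpropGradientsViewArray descriptions indicate. The plain-Java sketch below illustrates this "view" layout. It uses java.nio.DoubleBuffer slices in place of ND4J INDArray views. The FlatParams class and its name-to-size layout are hypothetical and are not DL4J's implementation.

```java
import java.nio.DoubleBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical sketch: all parameters live in one flat array, and each layer
 * gets a view (a slice sharing the same storage) of its own region.
 */
final class FlatParams {
    final double[] flat;                                           // backing storage for every parameter
    final Map<String, DoubleBuffer> views = new LinkedHashMap<>(); // layer name -> view into 'flat'

    /** @param layerSizes layer name -> number of parameters, iterated in layer order */
    FlatParams(Map<String, Integer> layerSizes) {
        int total = 0;
        for (int n : layerSizes.values()) total += n;
        flat = new double[total];

        int offset = 0;
        for (Map.Entry<String, Integer> e : layerSizes.entrySet()) {
            int size = e.getValue();
            // wrap(array, offset, length) positions the buffer at 'offset';
            // slice() then yields a view of exactly 'size' elements, with no copy
            views.put(e.getKey(), DoubleBuffer.wrap(flat, offset, size).slice());
            offset += size;
        }
    }

    int numParams() { return flat.length; }
}
```

Each slice shares storage with the flat array. Writing through a layer's view therefore updates the flattened array directly, and vice versa. Pass an ordered map such as LinkedHashMap so that the offsets follow the layer order.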
public void pretrain(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter)
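Pretrain network with a single input and single output. For networks with multiple inputs and/or outputs, use pretrain(MultiDataSetIterator).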
public void pretrain(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator iter)
public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
Fit the ComputationGraph using a DataSet. For networks with more than one input or output, use fit(MultiDataSetIterator).
public void fit(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iterator)
public void fit(org.nd4j.linalg.dataset.api.MultiDataSet multiDataSet)
public void fit(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator multi)
public void fit(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels)
inputs
- The network inputs
labels
- The labels
public void fit(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] featureMaskArrays, org.nd4j.linalg.api.ndarray.INDArray[] labelMaskArrays)
inputs
- The network inputs (features)
labels
- The network labels
featureMaskArrays
- Mask arrays for inputs/features. Typically used for RNN training. May be null.
labelMaskArrays
- Mask arrays for the labels/outputs. Typically used for RNN training. May be null.
public int[] topologicalSortOrder()
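topologicalSortOrder() computes an order in which every vertex comes after all of the vertices that feed into it. Kahn's algorithm is one standard way to compute such an order. The sketch below is a minimal version over integer vertex indices, written for illustration. The TopoSort class is hypothetical and is not claimed to be DL4J's implementation.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Hypothetical sketch of a topological sort (Kahn's algorithm) over vertex indices. */
final class TopoSort {
    /**
     * @param numVertices number of vertices, indexed 0..numVertices-1
     * @param edges       edges[k] = {from, to}: the output of 'from' feeds into 'to'
     * @return vertex indices such that every vertex appears after all of its inputs
     */
    static int[] order(int numVertices, int[][] edges) {
        List<List<Integer>> out = new ArrayList<>();
        for (int i = 0; i < numVertices; i++) out.add(new ArrayList<>());
        int[] inDegree = new int[numVertices];
        for (int[] e : edges) {
            out.get(e[0]).add(e[1]);
            inDegree[e[1]]++;
        }
        // Start from vertices with no incoming edges (e.g. the network's input vertices)
        Deque<Integer> ready = new ArrayDeque<>();
        for (int i = 0; i < numVertices; i++) if (inDegree[i] == 0) ready.add(i);

        int[] result = new int[numVertices];
        int count = 0;
        while (!ready.isEmpty()) {
            int v = ready.poll();
            result[count++] = v;
            // Removing v may make some of its successors ready
            for (int next : out.get(v)) {
                if (--inDegree[next] == 0) ready.add(next);
            }
        }
        if (count != numVertices) throw new IllegalStateException("Graph has a cycle");
        return result;
    }
}
```

With such an order, a forward pass can visit each vertex exactly once, knowing that its inputs have already been computed. If the graph contains a cycle, no such order exists and the sketch throws an exception.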
public void computeGradientAndScore()
Description copied from interface: Model
Specified by: computeGradientAndScore in interface Model
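computeGradientAndScore() (and gradientAndScore()) produce the loss value and its gradient with respect to the parameters in one pass. The sketch below stands in for a full network with a one-feature linear model y = w*x + b under mean squared error. It checks the analytic gradient against a central finite difference, a common sanity check for backprop code. All names are hypothetical.

```java
/** Hypothetical sketch: score and analytic gradient for y = w*x + b under mean squared error. */
final class GradientAndScore {
    /** Mean squared error over all examples. */
    static double score(double w, double b, double[] x, double[] y) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            double err = w * x[i] + b - y[i];
            sum += err * err;
        }
        return sum / x.length;
    }

    /** Analytic gradient {dScore/dw, dScore/db}. */
    static double[] gradient(double w, double b, double[] x, double[] y) {
        double gw = 0, gb = 0;
        for (int i = 0; i < x.length; i++) {
            double err = w * x[i] + b - y[i];
            gw += 2 * err * x[i];
            gb += 2 * err;
        }
        return new double[] {gw / x.length, gb / x.length};
    }

    /** Central finite-difference estimate of dScore/dw, used to check the analytic gradient. */
    static double numericGradW(double w, double b, double[] x, double[] y, double eps) {
        return (score(w + eps, b, x, y) - score(w - eps, b, x, y)) / (2 * eps);
    }
}
```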
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> feedForward(org.nd4j.linalg.api.ndarray.INDArray input, boolean train)
input
- The input array
train
- If true: do forward pass at training time; false: do forward pass at test time
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> feedForward(org.nd4j.linalg.api.ndarray.INDArray[] input, boolean train)
input
- An array of ComputationGraph inputs
train
- If true: do forward pass at training time; false: do forward pass at test time
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> feedForward()
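The feedForward(...) methods return the network's activations as a Map<String, INDArray> keyed by vertex name. The minimal sketch below illustrates that idea, with each vertex evaluated in topological order and its output recorded under its name. It assumes scalar activations instead of INDArrays and uses a hypothetical Vertex class; DL4J's GraphVertex carries considerably more information.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical sketch: a forward pass over named vertices, visited in topological order. */
final class ForwardPass {
    static final class Vertex {
        final String name;
        final List<String> inputs;              // names of the vertices feeding this one
        final Function<double[], Double> fn;    // combines the input activations

        Vertex(String name, List<String> inputs, Function<double[], Double> fn) {
            this.name = name;
            this.inputs = inputs;
            this.fn = fn;
        }
    }

    /**
     * @param networkInputs activations of the graph's input vertices, by name
     * @param topoOrdered   the remaining vertices, already in topological order
     * @return activations of every vertex (inputs included), keyed by vertex name
     */
    static Map<String, Double> feedForward(Map<String, Double> networkInputs, List<Vertex> topoOrdered) {
        Map<String, Double> activations = new LinkedHashMap<>(networkInputs);
        for (Vertex v : topoOrdered) {
            double[] in = new double[v.inputs.size()];
            for (int i = 0; i < in.length; i++) {
                // Topological order guarantees each input was computed earlier
                in[i] = activations.get(v.inputs.get(i));
            }
            activations.put(v.name, v.fn.apply(in));
        }
        return activations;
    }
}
```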
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> feedForward(boolean train)
train
- If true: do forward pass at training time; false: do forward pass at test time
public org.nd4j.linalg.api.ndarray.INDArray[] output(org.nd4j.linalg.api.ndarray.INDArray... input)
input
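- Inputs to the network
public org.nd4j.linalg.api.ndarray.INDArray outputSingle(org.nd4j.linalg.api.ndarray.INDArray... input)
A convenience method that returns a single INDArray instead of an INDArray[]; otherwise identical to output(INDArray...).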
input
- Inputs to the network
public org.nd4j.linalg.api.ndarray.INDArray[] output(boolean train, org.nd4j.linalg.api.ndarray.INDArray... input)
train
- If true: do forward pass at training time; false: do forward pass at test time
input
- Inputs to the network
public org.nd4j.linalg.api.ndarray.INDArray outputSingle(boolean train, org.nd4j.linalg.api.ndarray.INDArray... input)
A convenience method that returns a single INDArray instead of an INDArray[]; otherwise identical to output(boolean, INDArray...).
train
- If true: do forward pass at training time; false: do forward pass at test time
input
- Inputs to the network
public Gradient backpropGradient(org.nd4j.linalg.api.ndarray.INDArray... epsilons)
epsilons
- Epsilons (errors) at the outputs, in the same order in which the output layers are defined in the configuration via setOutputs(String...)
protected void calcBackpropGradients(boolean truncatedBPTT, org.nd4j.linalg.api.ndarray.INDArray... externalEpsilons)
truncatedBPTT
- false: normal backprop; true: calculate gradients using truncated BPTT for RNN layers
externalEpsilons
- Usually null (for typical supervised learning). If not null (and length > 0), the errors are assumed to have been provided externally, as they would be, for example, in reinforcement learning.
public ComputationGraph clone()
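backpropGradient(INDArray...) and the externalEpsilons argument of calcBackpropGradients take the error signal at the outputs from the caller, instead of computing it from an output layer's loss. Consider a single dense layer y = W x, with no bias or activation (a simplification assumed here to keep the example short). Given epsilon = dL/dy, the weight gradient is the outer product epsilon x^T, and the error passed back to the layer's input is W^T epsilon. The class below is a hypothetical illustration, not DL4J code.

```java
/**
 * Hypothetical sketch: backprop through a dense layer y = W x (no bias, no activation),
 * given an externally supplied error signal epsilon = dL/dy.
 */
final class DenseBackprop {
    /** Weight gradient: dL/dW[i][j] = epsilon[i] * x[j] (outer product). */
    static double[][] weightGradient(double[] epsilon, double[] x) {
        double[][] g = new double[epsilon.length][x.length];
        for (int i = 0; i < epsilon.length; i++)
            for (int j = 0; j < x.length; j++)
                g[i][j] = epsilon[i] * x[j];
        return g;
    }

    /** Error passed to the previous layer: dL/dx = W^T epsilon. */
    static double[] inputEpsilon(double[][] w, double[] epsilon) {
        double[] out = new double[w[0].length];
        for (int i = 0; i < w.length; i++)
            for (int j = 0; j < out.length; j++)
                out[j] += w[i][j] * epsilon[i];
        return out;
    }
}
```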
public double calcL2()
public double calcL1()
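calcL1() and calcL2() sum the regularization terms over all layers. The sketch below assumes the common conventions l1 * sum(|w|) and 0.5 * l2 * sum(w^2) per layer, with one coefficient per layer. Check the source of the DL4J version in use to confirm whether it applies the 0.5 factor or excludes biases. The Regularization class is hypothetical.

```java
/** Hypothetical sketch: L1 and L2 regularization terms summed over layers. */
final class Regularization {
    /** Sum over layers k of l1[k] * sum(|w|). */
    static double calcL1(double[][] layerWeights, double[] l1) {
        double total = 0;
        for (int k = 0; k < layerWeights.length; k++) {
            double s = 0;
            for (double w : layerWeights[k]) s += Math.abs(w);
            total += l1[k] * s;
        }
        return total;
    }

    /** Sum over layers k of 0.5 * l2[k] * sum(w^2). */
    static double calcL2(double[][] layerWeights, double[] l2) {
        double total = 0;
        for (int k = 0; k < layerWeights.length; k++) {
            double s = 0;
            for (double w : layerWeights[k]) s += w * w;
            total += 0.5 * l2[k] * s;
        }
        return total;
    }
}
```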
public void setListeners(Collection<IterationListener> listeners)
public void setListeners(IterationListener... listeners)
public Collection<IterationListener> getListeners()
public ComputationGraphUpdater getUpdater()
public void setUpdater(ComputationGraphUpdater updater)
public Layer getOutputLayer(int outputLayerIdx)
Get the specified output layer, by index. The index of the output layer may be 0 to getNumOutputArrays()-1.
public org.nd4j.linalg.api.ndarray.INDArray params(boolean backwardOnly)
backwardOnly
- If true: backprop parameters only (i.e., no visible layer biases used in layerwise pretraining layers)
public double score(org.nd4j.linalg.dataset.api.DataSet dataSet)
Equivalent to score(DataSet, boolean) with training==true.
dataSet
- the data to score
See Also: score(DataSet, boolean)
public double score(org.nd4j.linalg.dataset.api.DataSet dataSet, boolean training)
Use score(MultiDataSet, boolean) for networks with multiple inputs and/or outputs.
dataSet
- the data to score
training
- whether score is being calculated at training time (true) or test time (false)
public double score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
public double score(org.nd4j.linalg.dataset.api.MultiDataSet dataSet, boolean training)
dataSet
- the data to score
training
- whether score is being calculated at training time (true) or test time (false)
public org.nd4j.linalg.api.ndarray.INDArray scoreExamples(org.nd4j.linalg.dataset.api.DataSet data, boolean addRegularizationTerms)
Calculate the score for each example in a DataSet individually. Unlike score(DataSet) and score(DataSet, boolean), this method does not average or sum over examples. This allows examples to be scored individually (at test time only), which may be useful, for example, for autoencoder architectures.
data
- The data to score
addRegularizationTerms
- If true: add L1/L2 regularization terms (if any) to the score. If false: don't add regularization terms
public org.nd4j.linalg.api.ndarray.INDArray scoreExamples(org.nd4j.linalg.dataset.api.MultiDataSet data, boolean addRegularizationTerms)
Calculate the score for each example in a MultiDataSet individually. Unlike score(MultiDataSet) and score(MultiDataSet, boolean), this method does not average or sum over examples. This allows examples to be scored individually (at test time only), which may be useful, for example, for autoencoder architectures.
data
- The data to score
addRegularizationTerms
- If true: add L1/L2 regularization terms (if any) to the score. If false: don't add regularization terms
public void fit()
Description copied from interface: Model
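The scoreExamples(...) methods described above differ from score(...) in that they return one value per example instead of reducing the minibatch to a single number. The minimal sketch below makes several illustrative assumptions. It uses squared-error loss and takes the mean over examples for the aggregate score. It also treats the regularization term as a precomputed value, added once to the aggregate and, when addRegularizationTerms is true, to every per-example score. The real loss and aggregation depend on the network's output layers. The Scoring class is hypothetical.

```java
/** Hypothetical sketch: aggregate score vs. per-example scores, using squared error. */
final class Scoring {
    /** One loss value per example (row); optionally add the regularization term to each. */
    static double[] scoreExamples(double[][] predictions, double[][] labels,
                                  double regularization, boolean addRegularizationTerms) {
        double[] out = new double[predictions.length];
        for (int i = 0; i < predictions.length; i++) {
            double loss = 0;
            for (int j = 0; j < predictions[i].length; j++) {
                double d = predictions[i][j] - labels[i][j];
                loss += d * d;
            }
            out[i] = addRegularizationTerms ? loss + regularization : loss;
        }
        return out;
    }

    /** Single aggregate score: mean per-example loss plus the regularization term. */
    static double score(double[][] predictions, double[][] labels, double regularization) {
        double[] perExample = scoreExamples(predictions, labels, 0.0, false);
        double sum = 0;
        for (double s : perExample) sum += s;
        return sum / perExample.length + regularization;
    }
}
```

Per-example scores make it possible to rank or flag individual examples, for instance by reconstruction error in an autoencoder. The aggregate score alone cannot do this.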
public void update(org.nd4j.linalg.api.ndarray.INDArray gradient, String paramType)
Description copied from interface: Model
public void update(Gradient gradient)
Description copied from interface: Model
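update(Gradient) applies a gradient to the layer weights and biases. The simplest form of such an update is a plain stochastic gradient descent step over a name-keyed parameter table, sketched below. In DL4J, the learning rule itself is handled by the network's updater (see getUpdater()), so the sketch is only an illustration of the idea. The key name "dense_W" and the SgdUpdate class are hypothetical.

```java
import java.util.Map;

/** Hypothetical sketch: plain SGD step applied to a name-keyed parameter table. */
final class SgdUpdate {
    /**
     * For each entry in the gradient table, subtract learningRate * gradient from the
     * parameter array stored under the same key (in place).
     */
    static void update(Map<String, double[]> paramTable, Map<String, double[]> gradients, double learningRate) {
        for (Map.Entry<String, double[]> e : gradients.entrySet()) {
            double[] p = paramTable.get(e.getKey());
            if (p == null) throw new IllegalArgumentException("Unknown parameter: " + e.getKey());
            double[] g = e.getValue();
            for (int i = 0; i < p.length; i++) p[i] -= learningRate * g[i];
        }
    }
}
```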
public double score()
Description copied from interface: Model
public void setScore(double score)
public void accumulateScore(double accum)
Description copied from interface: Model
Specified by: accumulateScore in interface Model
accum
- the amount to accumulate
public org.nd4j.linalg.api.ndarray.INDArray params()
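Description copied from interface: Model
public int numParams()
Description copied from interface: Model
public int numParams(boolean backwards)
Description copied from interface: Model
public void setParams(org.nd4j.linalg.api.ndarray.INDArray params)
Description copied from interface: Model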
public void setParamsViewArray(org.nd4j.linalg.api.ndarray.INDArray gradient)
Description copied from interface: Model
Specified by: setParamsViewArray in interface Model
gradient
- a 1 x nParams row vector that is a view of the larger (MultiLayerNetwork/ComputationGraph) parameters array
public void setBackpropGradientsViewArray(org.nd4j.linalg.api.ndarray.INDArray gradient)
Description copied from interface: Model
Specified by: setBackpropGradientsViewArray in interface Model
gradient
- a 1 x nParams row vector that is a view of the larger (MultiLayerNetwork/ComputationGraph) gradients array
public void applyLearningRateScoreDecay()
Description copied from interface: Model
Specified by: applyLearningRateScoreDecay in interface Model
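applyLearningRateScoreDecay() adapts the learning rate based on the score. This document does not state the exact rule, so the sketch below assumes a simple one. Whenever the latest score fails to improve on the best score seen so far, the learning rate is multiplied by a fixed decay factor. The class and its parameters are hypothetical.

```java
/**
 * Hypothetical sketch of score-based learning-rate decay: if the latest score does not
 * improve on the best score seen so far, multiply the learning rate by a decay factor.
 */
final class ScoreBasedLrDecay {
    private double learningRate;
    private final double decayFactor;                   // e.g. 0.5 halves the rate on each stall
    private double bestScore = Double.POSITIVE_INFINITY;

    ScoreBasedLrDecay(double initialLearningRate, double decayFactor) {
        this.learningRate = initialLearningRate;
        this.decayFactor = decayFactor;
    }

    /** Call once per iteration with the current score (lower is better); returns the new rate. */
    double onScore(double score) {
        if (score < bestScore) {
            bestScore = score;                          // improvement: keep the current rate
        } else {
            learningRate *= decayFactor;                // no improvement: decay the rate
        }
        return learningRate;
    }
}
```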
public void fit(org.nd4j.linalg.api.ndarray.INDArray data)
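Description copied from interface: Model
public void iterate(org.nd4j.linalg.api.ndarray.INDArray input)
Description copied from interface: Model
public Pair<Gradient,Double> gradientAndScore()
Description copied from interface: Model
Specified by: gradientAndScore in interface Model
public int batchSize()
Description copied from interface: Model
public NeuralNetConfiguration conf()
Description copied from interface: Model
public void setConf(NeuralNetConfiguration conf)
Description copied from interface: Model
public org.nd4j.linalg.api.ndarray.INDArray input()
Description copied from interface: Model
public void validateInput()
Description copied from interface: Model
Specified by: validateInput in interface Model
public ConvexOptimizer getOptimizer()
Description copied from interface: Model
Specified by: getOptimizer in interface Model
public org.nd4j.linalg.api.ndarray.INDArray getParam(String param)
Description copied from interface: Model
public void initParams()
Description copied from interface: Model
Specified by: initParams in interface Model
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> paramTable()
Description copied from interface: Model
Specified by: paramTable in interface Model
public void setParamTable(Map<String,org.nd4j.linalg.api.ndarray.INDArray> paramTable)
Description copied from interface: Model
Specified by: setParamTable in interface Model
public void setParam(String key, org.nd4j.linalg.api.ndarray.INDArray val)
Description copied from interface: Model
public void clear()
Description copied from interface: Model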
public org.nd4j.linalg.api.ndarray.INDArray[] rnnTimeStep(org.nd4j.linalg.api.ndarray.INDArray... inputs)
inputs
- Input to the network. May be for one or multiple time steps. For a single time step, input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]; use miniBatchSize=1 for a single example. For multiple time steps, input has shape [miniBatchSize,inputSize,timeSeriesLength].
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState(int layer)
Get the state of the RNN layer, as used in rnnTimeStep(INDArray...).
layer
- Number/index of the layer.
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState(String layerName)
Get the state of the RNN layer, as used in rnnTimeStep(INDArray...).
layerName
- name of the layer
public Map<String,Map<String,org.nd4j.linalg.api.ndarray.INDArray>> rnnGetPreviousStates()
Get a map of states for ALL RNN layers, as used in rnnTimeStep(INDArray...). Layers that are not RNN layers will not have an entry in the returned map.
See Also: rnnSetPreviousStates(Map)
public void rnnSetPreviousState(int layer, Map<String,org.nd4j.linalg.api.ndarray.INDArray> state)
Set the state of the RNN layer, for use in rnnTimeStep(INDArray...).
layer
- The number/index of the layer.
state
- The state to set the specified layer to
public void rnnSetPreviousState(String layerName, Map<String,org.nd4j.linalg.api.ndarray.INDArray> state)
Set the state of the RNN layer, for use in rnnTimeStep(INDArray...).
layerName
- The name of the layer.
state
- The state to set the specified layer to
public void rnnSetPreviousStates(Map<String,Map<String,org.nd4j.linalg.api.ndarray.INDArray>> previousStates)
Set the states for all RNN layers, for use in rnnTimeStep(INDArray...).
previousStates
- The previous time step states for all layers (key: layer name; value: layer states)
See Also: rnnGetPreviousStates()
public void rnnClearPreviousState()
Clear the previous state of the RNN layers (if any), used in rnnTimeStep(INDArray...).
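rnnTimeStep(INDArray...) keeps the last activations of each RNN layer between calls. A sequence can therefore be fed one (or a few) steps at a time and still produce the same outputs as a single pass over the whole sequence. rnnClearPreviousState() resets this stored state, and the rnnGetPreviousState/rnnSetPreviousState methods read and write it. The sketch below shows the idea for a hypothetical scalar RNN cell h_t = tanh(w*x_t + u*h_{t-1}). Real recurrent layers store more than one value (an LSTM, for example, also keeps its memory cell state), which is why the methods above expose each layer's state as a Map<String, INDArray>.

```java
/**
 * Hypothetical sketch of stateful stepping for a scalar RNN cell
 * h_t = tanh(w * x_t + u * h_{t-1}), with the initial state h_0 = 0.
 */
final class StatefulRnn {
    private final double w, u;
    private double state = 0.0;                         // stored h_{t-1}; 0 initially and after clearing

    StatefulRnn(double w, double u) {
        this.w = w;
        this.u = u;
    }

    /** Process one or more time steps, continuing from the stored state. */
    double[] timeStep(double... inputs) {
        double[] out = new double[inputs.length];
        for (int t = 0; t < inputs.length; t++) {
            state = Math.tanh(w * inputs[t] + u * state);
            out[t] = state;
        }
        return out;                                     // the last activation stays stored for the next call
    }

    void clearPreviousState() { state = 0.0; }

    double getPreviousState() { return state; }

    void setPreviousState(double s) { state = s; }
}
```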
protected void doTruncatedBPTT(org.nd4j.linalg.api.ndarray.INDArray[] inputs, org.nd4j.linalg.api.ndarray.INDArray[] labels, org.nd4j.linalg.api.ndarray.INDArray[] featureMasks, org.nd4j.linalg.api.ndarray.INDArray[] labelMasks)
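doTruncatedBPTT(...) trains on long sequences by splitting them into fixed-length segments. The RNN state is carried from one segment into the next (compare storeLastForTBPTT in rnnActivateUsingStoredState, and rnnUpdateStateWithTBPTTState), while gradients are computed only within each segment. The segmentation step on its own might look like the sketch below. The segment-length parameter and the shorter final segment are assumptions of this sketch, and the TbpttSegments class is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: split a time series into truncated-BPTT segments. */
final class TbpttSegments {
    /**
     * Returns {start, endExclusive} pairs covering [0, timeSeriesLength) in chunks of
     * tbpttLength time steps; the final chunk may be shorter.
     */
    static List<int[]> segments(int timeSeriesLength, int tbpttLength) {
        if (tbpttLength <= 0) throw new IllegalArgumentException("tbpttLength must be > 0");
        List<int[]> result = new ArrayList<>();
        for (int start = 0; start < timeSeriesLength; start += tbpttLength) {
            result.add(new int[] {start, Math.min(start + tbpttLength, timeSeriesLength)});
        }
        return result;
    }
}
```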
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnActivateUsingStoredState(org.nd4j.linalg.api.ndarray.INDArray[] inputs, boolean training, boolean storeLastForTBPTT)
inputs
- Input to the network
training
- Whether training or not
storeLastForTBPTT
- Set to true if used as part of truncated BPTT training
public void setLayerMaskArrays(org.nd4j.linalg.api.ndarray.INDArray[] featureMaskArrays, org.nd4j.linalg.api.ndarray.INDArray[] labelMaskArrays)
featureMaskArrays
- Mask array for features (input)
labelMaskArrays
- Mask array for labels (output)
See Also: clearLayerMaskArrays()
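Mask arrays let sequences of different lengths share one padded minibatch. A mask value of 0 marks a padded time step that should not contribute to the loss. The sketch below assumes a per-example, per-time-step 0/1 mask and averages the loss over unmasked steps only. DL4J's exact normalization may differ, and the MaskedLoss class is hypothetical.

```java
/** Hypothetical sketch: mean loss over time steps, ignoring steps whose mask value is 0. */
final class MaskedLoss {
    /**
     * @param perStepLoss [example][timeStep] loss values
     * @param mask        [example][timeStep] 1.0 = real data, 0.0 = padding
     * @return mean loss over unmasked steps (0 if every step is masked)
     */
    static double maskedMean(double[][] perStepLoss, double[][] mask) {
        double sum = 0, count = 0;
        for (int i = 0; i < perStepLoss.length; i++) {
            for (int t = 0; t < perStepLoss[i].length; t++) {
                sum += perStepLoss[i][t] * mask[i][t];  // padded steps contribute 0
                count += mask[i][t];
            }
        }
        return count == 0 ? 0.0 : sum / count;
    }
}
```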
public void clearLayerMaskArrays()
Remove the mask arrays from all layers. See setLayerMaskArrays(INDArray[], INDArray[]) for details on mask arrays.
protected void rnnUpdateStateWithTBPTTState()
Copyright © 2016. All Rights Reserved.