public abstract class BaseRecurrentLayer<LayerConfT extends Layer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type
Modifier and Type | Field and Description |
---|---|
protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | stateMap: stores the INDArrays needed for the rnnTimeStep() forward pass. |
protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | tBpttStateMap: state map for use specifically in truncated BPTT training. |
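The two fields above are separate stores: one carries state between rnnTimeStep() calls, the other between truncated BPTT segments. The following is a minimal sketch of that separation, not the DL4J implementation. It assumes `double[]` as a stand-in for `org.nd4j.linalg.api.ndarray.INDArray` so it runs without ND4J, assumes the TBPTT getter/setter copy and replace map contents, and uses a hypothetical class name (`TwoStateMaps`) and an illustrative key (`"prevActivation"`).

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical stand-in for BaseRecurrentLayer's two state fields.
 * double[] replaces INDArray so the sketch runs without ND4J.
 */
class TwoStateMaps {
    // State carried between successive rnnTimeStep() calls.
    protected final Map<String, double[]> stateMap = new HashMap<>();
    // State carried between truncated BPTT segments during training.
    protected final Map<String, double[]> tBpttStateMap = new HashMap<>();

    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap); // shallow copy of the rnnTimeStep() state
    }

    Map<String, double[]> rnnGetTBPTTState() {
        return new HashMap<>(tBpttStateMap); // assumed: shallow copy of the TBPTT state
    }

    void rnnSetTBPTTState(Map<String, double[]> state) {
        // Assumed behaviour: replace the TBPTT entries, leave stateMap alone.
        tBpttStateMap.clear();
        tBpttStateMap.putAll(state);
    }

    public static void main(String[] args) {
        TwoStateMaps layer = new TwoStateMaps();
        Map<String, double[]> tbptt = new HashMap<>();
        tbptt.put("prevActivation", new double[] {0.5, -0.25});
        layer.rnnSetTBPTTState(tbptt);

        // Setting the TBPTT state does not touch the rnnTimeStep() state.
        System.out.println(layer.rnnGetPreviousState().isEmpty()); // true
        System.out.println(layer.rnnGetTBPTTState().size());      // 1
    }
}
```

Keeping the maps apart means that inference-time stepping and TBPTT training do not overwrite each other's carried state.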
Fields inherited from class BaseLayer: conf, dropoutApplied, dropoutMask, gradient, gradientsFlattened, gradientViews, index, input, iterationListeners, maskArray, optimizer, params, paramsFlattened, score, solver
Constructor and Description |
---|
BaseRecurrentLayer(NeuralNetConfiguration conf) |
BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input) |
Modifier and Type | Method and Description |
---|---|
void | rnnClearPreviousState(): Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState(). |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(): Returns a shallow copy of the stateMap. |
Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState(): Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
void | rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap): Set the state map used by rnnTimeStep(). |
void | rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state): Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
Methods inherited from class BaseLayer: accumulateScore, activate, activate, activate, activate, activate, activate, activationMean, applyDropOutIfNecessary, applyLearningRateScoreDecay, backpropGradient, batchSize, calcGradient, calcL1, calcL2, clear, clone, computeGradientAndScore, conf, createGradient, derivativeActivation, error, fit, fit, getIndex, getInput, getInputMiniBatchSize, getListeners, getOptimizer, getParam, gradient, gradientAndScore, initParams, input, iterate, layerConf, merge, numParams, numParams, params, paramTable, preOutput, preOutput, preOutput, preOutput, score, setBackpropGradientsViewArray, setConf, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, transpose, type, update, update, validateInput
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient
Methods inherited from interface Layer: activate, activate, activate, activate, activate, activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, getIndex, getInputMiniBatchSize, getListeners, merge, preOutput, preOutput, preOutput, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, transpose, type
Methods inherited from interface Model: accumulateScore, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, fit, getOptimizer, getParam, gradient, gradientAndScore, initParams, input, iterate, numParams, numParams, params, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput
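protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap
stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> tBpttStateMap
State map for use specifically in truncated BPTT training.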
public BaseRecurrentLayer(NeuralNetConfiguration conf)
public BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input)
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
Specified by: rnnGetPreviousState in interface RecurrentLayer
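"Shallow copy" means the returned map is a new Map object, but its values are the same array objects the layer holds. A minimal sketch of that distinction, assuming `double[]` as a stand-in for INDArray; the class name `ShallowCopyDemo` and the key `"prevActivation"` are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

class ShallowCopyDemo {
    // Hypothetical stand-in for the layer's stateMap (double[] in place of INDArray).
    final Map<String, double[]> stateMap = new HashMap<>();

    // Mirrors the documented contract: a new map holding the same array objects.
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    public static void main(String[] args) {
        ShallowCopyDemo layer = new ShallowCopyDemo();
        layer.stateMap.put("prevActivation", new double[] {1.0, 2.0});

        Map<String, double[]> copy = layer.rnnGetPreviousState();

        // Structural changes to the copy do not reach the layer's own map...
        copy.remove("prevActivation");
        System.out.println(layer.stateMap.containsKey("prevActivation")); // true

        // ...but the arrays are shared, so in-place writes are visible to the layer.
        layer.rnnGetPreviousState().get("prevActivation")[0] = 42.0;
        System.out.println(layer.stateMap.get("prevActivation")[0]); // 42.0
    }
}
```

If an independent snapshot is needed, each array has to be copied as well; for an INDArray, `dup()` returns a copy.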
public void rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Set the state map used by rnnTimeStep().
Specified by: rnnSetPreviousState in interface RecurrentLayer
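Together with rnnGetPreviousState(), this setter allows a save/restore pattern: snapshot the state, step ahead, then roll back. The sketch below is illustrative only. It assumes that setting the state replaces the map contents, uses `double[]` in place of INDArray, and relies on a hypothetical `timeStep(double)` method (a toy tanh update, not DL4J's rnnTimeStep()) that stores a new array rather than writing in place.

```java
import java.util.HashMap;
import java.util.Map;

class SnapshotRestoreDemo {
    // Hypothetical stand-in for the layer's stateMap.
    final Map<String, double[]> stateMap = new HashMap<>();

    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap); // shallow copy
    }

    // Assumed behaviour: replace the whole state with the supplied entries.
    void rnnSetPreviousState(Map<String, double[]> newState) {
        stateMap.clear();
        stateMap.putAll(newState);
    }

    // Hypothetical one-step update. It stores a *new* array instead of
    // overwriting the old one, which is what keeps a shallow snapshot valid.
    void timeStep(double input) {
        double[] prev = stateMap.getOrDefault("prevActivation", new double[] {0.0});
        stateMap.put("prevActivation", new double[] {Math.tanh(prev[0] + input)});
    }

    double current() {
        return stateMap.get("prevActivation")[0];
    }

    public static void main(String[] args) {
        SnapshotRestoreDemo layer = new SnapshotRestoreDemo();
        layer.timeStep(0.5);
        Map<String, double[]> saved = layer.rnnGetPreviousState();
        double before = layer.current();

        layer.timeStep(1.0);              // look one step ahead
        layer.rnnSetPreviousState(saved); // roll back
        System.out.println(layer.current() == before); // true
    }
}
```

If a stepping routine instead modified the stored arrays in place, the saved map would change along with the layer, and each array would need to be duplicated before saving.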
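public void rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
Specified by: rnnClearPreviousState in interface RecurrentLayer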
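Clearing resets both stores at once, for example before feeding an unrelated sequence so that it does not inherit state from the previous one. A minimal sketch, assuming the method simply empties both maps; `ClearStateDemo` and the `"prevActivation"` key are hypothetical, and `double[]` stands in for INDArray.

```java
import java.util.HashMap;
import java.util.Map;

class ClearStateDemo {
    // Hypothetical stand-ins for the two state fields.
    final Map<String, double[]> stateMap = new HashMap<>();
    final Map<String, double[]> tBpttStateMap = new HashMap<>();

    // Resets both stores, as the method summary describes.
    void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    public static void main(String[] args) {
        ClearStateDemo layer = new ClearStateDemo();
        layer.stateMap.put("prevActivation", new double[] {0.3});
        layer.tBpttStateMap.put("prevActivation", new double[] {0.7});

        // e.g. between two independent input sequences
        layer.rnnClearPreviousState();
        System.out.println(layer.stateMap.isEmpty() && layer.tBpttStateMap.isEmpty()); // true
    }
}
```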
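public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer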
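public void rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state)
Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer
Parameters:
state - TBPTT state to set

Copyright © 2016. All Rights Reserved.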