public abstract class BaseRecurrentLayer<LayerConfT extends BaseLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type

| Modifier and Type | Field and Description |
|---|---|
| protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | stateMap: stores the INDArrays needed to do the rnnTimeStep() forward pass. |
| protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | tBpttStateMap: state map for use specifically in truncated BPTT training. |
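The two fields serve different callers: stateMap lets inference resume where the previous rnnTimeStep() call stopped, while tBpttStateMap carries state from one truncated-BPTT training segment to the next. As a rough illustration (not DL4J code), the sketch below keeps the two maps separate in a hypothetical `ToyRecurrentLayer`. The single-unit recurrence h_t = tanh(w·x_t + u·h_{t-1}), the state key `prevActivation`, the method `trainingSegmentForward`, and the use of `double[]` in place of `INDArray` are all assumptions made for the example.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer (not DL4J): one recurrent unit, h_t = tanh(w * x_t + u * h_{t-1}),
// with double[] standing in for INDArray.
class ToyRecurrentLayer {
    // Hypothetical state key chosen for this example.
    static final String PREV_ACTIVATION = "prevActivation";

    // Plays the role of stateMap: state carried between rnnTimeStep() calls.
    protected Map<String, double[]> stateMap = new HashMap<>();
    // Plays the role of tBpttStateMap: state carried between TBPTT training segments.
    protected Map<String, double[]> tBpttStateMap = new HashMap<>();

    private final double w; // input weight
    private final double u; // recurrent weight

    ToyRecurrentLayer(double w, double u) {
        this.w = w;
        this.u = u;
    }

    // Runs the recurrence over 'sequence', starting from the activation stored in 'state'
    // (zero if none), and stores the final activation back into 'state'.
    private double[] forward(double[] sequence, Map<String, double[]> state) {
        double h = state.containsKey(PREV_ACTIVATION) ? state.get(PREV_ACTIVATION)[0] : 0.0;
        double[] out = new double[sequence.length];
        for (int t = 0; t < sequence.length; t++) {
            h = Math.tanh(w * sequence[t] + u * h);
            out[t] = h;
        }
        state.put(PREV_ACTIVATION, new double[] {h});
        return out;
    }

    // Inference: continues from wherever the previous rnnTimeStep() call stopped.
    double[] rnnTimeStep(double[] input) {
        return forward(input, stateMap);
    }

    // Hypothetical training hook: continues only from the previous TBPTT segment.
    double[] trainingSegmentForward(double[] segment) {
        return forward(segment, tBpttStateMap);
    }

    public static void main(String[] args) {
        double[] seq = {1.0, -0.5, 0.25, 2.0};
        // Reference: the whole sequence in one pass from a zero state.
        double[] full = new ToyRecurrentLayer(0.5, 0.9).forward(seq, new HashMap<>());

        // Feeding one step at a time gives the same outputs, because stateMap
        // carries the last activation from one call to the next.
        ToyRecurrentLayer layer = new ToyRecurrentLayer(0.5, 0.9);
        for (int t = 0; t < seq.length; t++) {
            double step = layer.rnnTimeStep(new double[] {seq[t]})[0];
            System.out.println(step == full[t]); // true
        }
        // Inference never touches the TBPTT map.
        System.out.println(layer.tBpttStateMap.isEmpty()); // true
    }
}
```

In this sketch each map is written by only one path, so running rnnTimeStep() between training segments cannot disturb the state that the next TBPTT segment resumes from.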
Fields inherited from class BaseLayer: gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver

Fields inherited from class AbstractLayer: cacheMode, conf, dropoutApplied, dropoutMask, index, input, iterationListeners, maskArray, maskState, preOutput

| Constructor and Description |
|---|
| BaseRecurrentLayer(NeuralNetConfiguration conf) |
| BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input) |
| Modifier and Type | Method and Description |
|---|---|
| void | rnnClearPreviousState(): Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState(). |
| Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(): Returns a shallow copy of the stateMap. |
| Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState(): Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
| void | rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap): Set the state map used by rnnTimeStep(). |
| void | rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state): Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
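rnnGetPreviousState() is documented as returning a shallow copy. A shallow copy of a map is a new map object whose values are the same array references: removing entries from the copy leaves the layer alone, but writing into one of the arrays changes the layer's state. The sketch below demonstrates that distinction with a hypothetical `ToyStatefulLayer` that uses `double[]` in place of `INDArray`. It assumes the setters replace the layer's entries with those of the given map and that the TBPTT getter/setter mirror the previous-state ones; the real implementation may differ. The `deepCopy` helper is also hypothetical; with real INDArray state, duplicating each array (for example with `INDArray.dup()`) would be the analogous way to take an independent snapshot.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in (not DL4J) for the state-handling methods of a recurrent layer;
// double[] plays the role of INDArray.
class ToyStatefulLayer {
    protected Map<String, double[]> stateMap = new HashMap<>();
    protected Map<String, double[]> tBpttStateMap = new HashMap<>();

    // Shallow copy: a new map, but the arrays inside are shared with the layer.
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    // Assumption: the given entries replace the layer's current entries.
    void rnnSetPreviousState(Map<String, double[]> newState) {
        stateMap.clear();
        stateMap.putAll(newState);
    }

    // Assumption: mirrors rnnGetPreviousState() for the TBPTT map.
    Map<String, double[]> rnnGetTBPTTState() {
        return new HashMap<>(tBpttStateMap);
    }

    // Assumption: mirrors rnnSetPreviousState() for the TBPTT map.
    void rnnSetTBPTTState(Map<String, double[]> state) {
        tBpttStateMap.clear();
        tBpttStateMap.putAll(state);
    }

    // Clears both maps, matching the summary of rnnClearPreviousState().
    void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    // Hypothetical helper: clones every array so the snapshot shares nothing with the layer.
    static Map<String, double[]> deepCopy(Map<String, double[]> src) {
        Map<String, double[]> out = new HashMap<>();
        src.forEach((key, arr) -> out.put(key, arr.clone()));
        return out;
    }

    public static void main(String[] args) {
        ToyStatefulLayer layer = new ToyStatefulLayer();
        layer.stateMap.put("prevActivation", new double[] {0.3});

        Map<String, double[]> copy = layer.rnnGetPreviousState();
        copy.clear(); // changes the copy's structure only
        System.out.println(layer.stateMap.size()); // 1

        Map<String, double[]> snapshot = deepCopy(layer.rnnGetPreviousState());
        layer.rnnGetPreviousState().get("prevActivation")[0] = 9.0; // shared array: layer changes
        System.out.println(layer.stateMap.get("prevActivation")[0]); // 9.0

        layer.rnnSetPreviousState(snapshot); // restore the independent snapshot
        System.out.println(layer.stateMap.get("prevActivation")[0]); // 0.3

        layer.rnnClearPreviousState();
        System.out.println(layer.stateMap.isEmpty()); // true
    }
}
```

The save/restore pattern in `main` is one reason a caller might prefer a deep snapshot: with only the shallow copy, any later in-place update to the layer's arrays would also change the "saved" state.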
Methods inherited from class BaseLayer: accumulateScore, activate, activate, activate, activationMean, applyLearningRateScoreDecay, backpropGradient, calcGradient, calcL1, calcL2, clone, computeGradientAndScore, error, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, initParams, iterate, layerConf, merge, numParams, params, paramTable, paramTable, preOutput, preOutput, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, transpose, update, update

Methods inherited from class AbstractLayer: activate, activate, activate, addListeners, applyDropOutIfNecessary, applyMask, batchSize, clear, conf, derivativeActivation, feedForwardMaskArray, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, preOutput, preOutput, setCacheMode, setConf, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, validateInput

Methods inherited from class Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient

Methods inherited from interface Layer: activate, activate, activate, activate, activate, activate, activationMean, backpropGradient, calcGradient, calcL1, calcL2, clone, derivativeActivation, error, feedForwardMaskArray, getIndex, getInputMiniBatchSize, getListeners, getMaskArray, isPretrainLayer, merge, preOutput, preOutput, preOutput, setCacheMode, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, transpose, type

Methods inherited from interface Model: accumulateScore, addListeners, applyLearningRateScoreDecay, batchSize, clear, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, iterate, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput

protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap
public BaseRecurrentLayer(NeuralNetConfiguration conf)
public BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input)
public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
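Returns a shallow copy of the stateMap.
Specified by: rnnGetPreviousState in interface RecurrentLayer

public void rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Set the state map used by rnnTimeStep().
Specified by: rnnSetPreviousState in interface RecurrentLayer

public void rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
Specified by: rnnClearPreviousState in interface RecurrentLayer

public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer

public void rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state)
Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer
Parameters: state - TBPTT state to set

Copyright © 2017. All rights reserved.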