public abstract class BaseRecurrentLayer<LayerConfT extends BaseLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type

| Modifier and Type | Field and Description |
|---|---|
| protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | stateMap: stores the INDArrays needed for the rnnTimeStep() forward pass. |
| protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> | tBpttStateMap: state map used specifically in truncated BPTT training. |
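Keeping two maps lets single-step inference and truncated BPTT training each carry their own recurrent state. The sketch below illustrates that separation under stated assumptions: ToyRecurrentLayer, the state key "h" and the scalar recurrence h_t = tanh(w * h_{t-1} + x_t) are all hypothetical, and double[] stands in for INDArray so the code runs without ND4J. It is not the DL4J implementation.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical toy layer (not DL4J code) that keeps two independent state maps,
 * mirroring the stateMap / tBpttStateMap split. double[] stands in for INDArray;
 * the key "h" and the recurrence h_t = tanh(w * h_{t-1} + x_t) are assumptions.
 */
class ToyRecurrentLayer {
    static final String KEY = "h"; // hypothetical state key

    // State carried between successive single-step inference calls (cf. rnnTimeStep()).
    protected final Map<String, double[]> stateMap = new HashMap<>();
    // State carried between segments of truncated BPTT training.
    protected final Map<String, double[]> tBpttStateMap = new HashMap<>();

    private final double w;

    ToyRecurrentLayer(double w) {
        this.w = w;
    }

    // Advance the recurrence by one step, reading and updating the given state map.
    private double step(Map<String, double[]> state, double x) {
        double[] prev = state.get(KEY);
        double hPrev = (prev == null) ? 0.0 : prev[0]; // empty map: start from zero state
        double h = Math.tanh(w * hPrev + x);
        state.put(KEY, new double[] {h});
        return h;
    }

    // Inference-style step: only stateMap changes.
    double timeStep(double x) {
        return step(stateMap, x);
    }

    // Training-style step inside a TBPTT segment: only tBpttStateMap changes.
    double tbpttStep(double x) {
        return step(tBpttStateMap, x);
    }

    // Mirrors the documented rnnClearPreviousState(): both maps are reset.
    void clearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }
}
```

In this sketch, inference steps never read or write tBpttStateMap, so the two kinds of state cannot leak into each other, and clearing resets both, matching the rnnClearPreviousState() description.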
Fields inherited from class BaseLayer: gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver, weightNoiseParams

Fields inherited from class AbstractLayer: cacheMode, conf, dropoutApplied, dropoutMask, epochCount, index, input, iterationCount, maskArray, maskState, preOutput, trainingListeners

| Constructor and Description |
|---|
| BaseRecurrentLayer(NeuralNetConfiguration conf) |
| BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input) |
| Modifier and Type | Method and Description |
|---|---|
| void | rnnClearPreviousState(): Resets (clears) the stateMap used by rnnTimeStep() and the tBpttStateMap used by rnnActivateUsingStoredState(). |
| Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetPreviousState(): Returns a shallow copy of the stateMap. |
| Map<String,org.nd4j.linalg.api.ndarray.INDArray> | rnnGetTBPTTState(): Gets the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
| void | rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap): Sets the state map. |
| void | rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state): Sets the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. |
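The get/set/clear accessors above are enough to save and restore a layer's streaming state, for example to take a look-ahead step and then roll back. The sketch below mirrors those semantics under stated assumptions: ToyStatefulLayer, its running-sum "recurrence" and the key "sum" are hypothetical, double[] stands in for INDArray, and the method names only echo the DL4J ones. It is not the library's implementation.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical toy (not DL4J code) mirroring the get/set/clear semantics in the
 * method summary. double[] stands in for INDArray; the running-sum state and the
 * key "sum" are illustrative assumptions.
 */
class ToyStatefulLayer {
    private final Map<String, double[]> stateMap = new HashMap<>();

    // Like rnnGetPreviousState(): a shallow copy, i.e. a new map holding the same arrays.
    Map<String, double[]> getPreviousState() {
        return new HashMap<>(stateMap);
    }

    // Like rnnSetPreviousState(): replace the current entries with the given ones.
    void setPreviousState(Map<String, double[]> state) {
        stateMap.clear();
        stateMap.putAll(state);
    }

    // Like rnnClearPreviousState() (this toy has no TBPTT map).
    void clearPreviousState() {
        stateMap.clear();
    }

    // One step: stores a fresh array instead of writing into the previous one.
    double timeStep(double x) {
        double[] prev = stateMap.get("sum");
        double next = (prev == null ? 0.0 : prev[0]) + x;
        stateMap.put("sum", new double[] {next});
        return next;
    }
}
```

Because the copy is shallow, adding or removing entries in the returned map does not affect the layer, but writing into one of its arrays changes the layer's live state. The snapshot-and-restore pattern is safe here only because each step stores a new array; if an implementation updated arrays in place, a deep copy (duplicating each array) would be needed.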
Methods inherited from class BaseLayer: accumulateScore, activate, backpropGradient, calcL1, calcL2, clear, clearNoiseWeightParams, clone, computeGradientAndScore, fit, fit, getGradientsViewArray, getOptimizer, getParam, getParamWithNoise, gradient, hasBias, initParams, layerConf, numParams, params, paramTable, paramTable, preOutput, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, transpose, update, update

Methods inherited from class AbstractLayer: activate, addListeners, applyConstraints, applyDropOutIfNecessary, applyMask, assertInputSet, batchSize, conf, feedForwardMaskArray, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, setCacheMode, setConf, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, validateInput

Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient

Methods inherited from interface Layer: activate, activate, backpropGradient, calcL1, calcL2, clearNoiseWeightParams, clone, feedForwardMaskArray, getEpochCount, getIndex, getInputMiniBatchSize, getIterationCount, getListeners, getMaskArray, isPretrainLayer, setCacheMode, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setIterationCount, setListeners, setListeners, setMaskArray, transpose, type

Methods inherited from interface Model: accumulateScore, addListeners, applyConstraints, batchSize, clear, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, initParams, input, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update, validateInput

protected Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap
public BaseRecurrentLayer(NeuralNetConfiguration conf)
public BaseRecurrentLayer(NeuralNetConfiguration conf, org.nd4j.linalg.api.ndarray.INDArray input)
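public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
Specified by: rnnGetPreviousState in interface RecurrentLayer

public void rnnSetPreviousState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> stateMap)
Sets the state map.
Specified by: rnnSetPreviousState in interface RecurrentLayer

public void rnnClearPreviousState()
Resets (clears) the stateMap used by rnnTimeStep() and the tBpttStateMap used by rnnActivateUsingStoredState().
Specified by: rnnClearPreviousState in interface RecurrentLayer

public Map<String,org.nd4j.linalg.api.ndarray.INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Gets the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer

public void rnnSetTBPTTState(Map<String,org.nd4j.linalg.api.ndarray.INDArray> state)
Description copied from interface: RecurrentLayer
Sets the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer
Parameters: state - TBPTT state to set

Copyright © 2018. All rights reserved.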