Class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer>
java.lang.Object
  org.deeplearning4j.nn.layers.AbstractLayer<LayerConfT>
    org.deeplearning4j.nn.layers.BaseLayer<LayerConfT>
      org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer<LayerConfT>
- All Implemented Interfaces:
Serializable, Cloneable, Layer, RecurrentLayer, Model, Trainable
- Direct Known Subclasses:
GravesBidirectionalLSTM, GravesLSTM, LSTM, SimpleRnn
public abstract class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
- See Also:
- Serialized Form
Nested Class Summary
-
Nested classes/interfaces inherited from interface org.deeplearning4j.nn.api.Layer
Layer.TrainingMode, Layer.Type
Field Summary
- protected int helperCountFail
- protected Map<String,INDArray> stateMap
  stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
- protected Map<String,INDArray> tBpttStateMap
  State map for use specifically in truncated BPTT training.
-
Fields inherited from class org.deeplearning4j.nn.layers.BaseLayer
gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver, weightNoiseParams
-
Fields inherited from class org.deeplearning4j.nn.layers.AbstractLayer
cacheMode, conf, dataType, dropoutApplied, epochCount, index, input, inputModificationAllowed, iterationCount, maskArray, maskState, preOutput, trainingListeners
Constructor Summary
- BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
-
Method Summary
- RNNFormat getDataFormat()
- protected INDArray permuteIfNWC(INDArray arr)
- void rnnClearPreviousState()
  Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
- Map<String,INDArray> rnnGetPreviousState()
  Returns a shallow copy of the stateMap.
- Map<String,INDArray> rnnGetTBPTTState()
  Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
- void rnnSetPreviousState(Map<String,INDArray> stateMap)
  Set the state map.
- void rnnSetTBPTTState(Map<String,INDArray> state)
  Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
-
Methods inherited from class org.deeplearning4j.nn.layers.BaseLayer
activate, backpropGradient, calcRegularizationScore, clear, clearNoiseWeightParams, clone, computeGradientAndScore, fit, fit, getGradientsViewArray, getOptimizer, getParam, getParamWithNoise, gradient, hasBias, hasLayerNorm, layerConf, numParams, params, paramTable, paramTable, preOutput, preOutputWithPreNorm, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, update, update
-
Methods inherited from class org.deeplearning4j.nn.layers.AbstractLayer
activate, addListeners, allowInputModification, applyConstraints, applyDropOutIfNecessary, applyMask, assertInputSet, backpropDropOutIfPresent, batchSize, close, conf, feedForwardMaskArray, getConfig, getEpochCount, getHelper, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, setCacheMode, setConf, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, updaterDivideByMinibatch
-
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.nn.api.Layer
activate, activate, allowInputModification, backpropGradient, calcRegularizationScore, clearNoiseWeightParams, feedForwardMaskArray, getEpochCount, getHelper, getIndex, getInputMiniBatchSize, getIterationCount, getListeners, getMaskArray, isPretrainLayer, setCacheMode, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setIterationCount, setListeners, setListeners, setMaskArray, type
-
Methods inherited from interface org.deeplearning4j.nn.api.Model
addListeners, applyConstraints, batchSize, clear, close, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, input, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update
-
Methods inherited from interface org.deeplearning4j.nn.api.layers.RecurrentLayer
rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient
-
Methods inherited from interface org.deeplearning4j.nn.api.Trainable
getConfig, getGradientsViewArray, numParams, params, paramTable, updaterDivideByMinibatch
Field Detail
-
stateMap
protected Map<String,INDArray> stateMap
stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
-
tBpttStateMap
protected Map<String,INDArray> tBpttStateMap
State map for use specifically in truncated BPTT training. Whereas stateMap contains the state from which the forward pass is initialized, tBpttStateMap contains the state at the end of the last truncated BPTT segment.
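The division of labor between the two maps can be illustrated with a minimal plain-Java sketch. Everything in it is an assumption for illustration, not DL4J code: `double[]` stands in for `INDArray`, `"prevAct"` is a made-up key, a scalar recurrence h_t = tanh(0.5 x_t + 0.9 h_{t-1}) replaces a real RNN cell, and the hypothetical `activateSegment` method is a simplified stand-in for a TBPTT forward pass rather than the signature of `rnnActivateUsingStoredState()`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer, not DL4J code: double[] stands in for INDArray and the
// scalar recurrence h_t = tanh(0.5 * x_t + 0.9 * h_{t-1}) stands in for a real RNN cell.
class TwoStateMapsDemo {
    static final String PREV_ACT = "prevAct"; // illustrative key name

    final Map<String, double[]> stateMap = new HashMap<>();      // read/written by rnnTimeStep()
    final Map<String, double[]> tBpttStateMap = new HashMap<>(); // read/written per TBPTT segment

    static double step(double x, double hPrev) {
        return Math.tanh(0.5 * x + 0.9 * hPrev);
    }

    // Single-step inference: initialize from stateMap, then store the new activation there.
    double rnnTimeStep(double x) {
        double hPrev = stateMap.getOrDefault(PREV_ACT, new double[] {0.0})[0];
        double h = step(x, hPrev);
        stateMap.put(PREV_ACT, new double[] {h});
        return h;
    }

    // One truncated-BPTT segment: initialize from tBpttStateMap and, if requested,
    // store the activation at the end of the segment for the next segment.
    double[] activateSegment(double[] xs, boolean storeLastForTbptt) {
        double h = tBpttStateMap.getOrDefault(PREV_ACT, new double[] {0.0})[0];
        double[] out = new double[xs.length];
        for (int t = 0; t < xs.length; t++) {
            h = step(xs[t], h);
            out[t] = h;
        }
        if (storeLastForTbptt) {
            tBpttStateMap.put(PREV_ACT, new double[] {h});
        }
        return out;
    }

    public static void main(String[] args) {
        TwoStateMapsDemo layer = new TwoStateMapsDemo();
        double[] seg1 = layer.activateSegment(new double[] {1.0, 2.0}, true);
        double[] seg2 = layer.activateSegment(new double[] {3.0}, true);

        // The second segment started from the state stored at the end of the first one.
        System.out.println(seg2[0] == step(3.0, seg1[1])); // true
        // rnnTimeStep() starts from its own (empty, hence zero) state, not the TBPTT state...
        System.out.println(layer.rnnTimeStep(5.0) == step(5.0, 0.0)); // true
        // ...and it leaves the stored TBPTT state alone.
        System.out.println(layer.tBpttStateMap.get(PREV_ACT)[0] == seg2[0]); // true
    }
}
```

Keeping the maps separate means single-step inference and segment-by-segment training do not overwrite each other's starting state.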
-
helperCountFail
protected int helperCountFail
Constructor Detail
-
BaseRecurrentLayer
public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
Method Detail
-
rnnGetPreviousState
public Map<String,INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
- Specified by:
rnnGetPreviousState in interface RecurrentLayer
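Because the copy is shallow, the returned map is independent but the arrays inside it are not. The sketch below models this with hypothetical stand-ins (`double[]` for `INDArray`, an illustrative `"prevAct"` key) and assumes a `HashMap` copy constructor, which is one way to produce the documented shallow copy:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in, not DL4J code: double[] plays the role of INDArray.
class ShallowCopyDemo {
    private final Map<String, double[]> stateMap = new HashMap<>();

    ShallowCopyDemo() {
        stateMap.put("prevAct", new double[] {0.1, 0.2}); // illustrative key and values
    }

    // A new map holding the same array references: a shallow copy.
    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    double[] current(String key) {
        return stateMap.get(key);
    }

    public static void main(String[] args) {
        ShallowCopyDemo layer = new ShallowCopyDemo();
        Map<String, double[]> copy = layer.rnnGetPreviousState();

        copy.put("extra", new double[] {9.0}); // changes only the copy's entries
        copy.get("prevAct")[0] = 42.0;         // writes through to the layer's array

        System.out.println(layer.current("extra"));      // null
        System.out.println(layer.current("prevAct")[0]); // 42.0
    }
}
```

If a snapshot must be protected from later in-place writes, each array has to be copied as well; with real ND4J arrays that would mean calling `dup()` on each value.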
-
rnnSetPreviousState
public void rnnSetPreviousState(Map<String,INDArray> stateMap)
Set the state map. Values set using this method will be used in the next call to rnnTimeStep().
- Specified by:
rnnSetPreviousState in interface RecurrentLayer
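A common pattern is to checkpoint the state, run ahead, and then rewind. The following hypothetical sketch uses a toy scalar recurrence and `double[]` in place of `INDArray`; the method names mirror this page, but the bodies are illustrative assumptions, including that the toy rnnTimeStep() stores a fresh array rather than writing into the old one:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer, not DL4J code: double[] stands in for INDArray and
// h_t = tanh(0.5 * x_t + 0.9 * h_{t-1}) stands in for a real RNN cell.
class RestoreStateDemo {
    private Map<String, double[]> stateMap = new HashMap<>();

    double rnnTimeStep(double x) {
        double hPrev = stateMap.getOrDefault("prevAct", new double[] {0.0})[0];
        double h = Math.tanh(0.5 * x + 0.9 * hPrev);
        stateMap.put("prevAct", new double[] {h}); // stores a new array; saved copies are not mutated
        return h;
    }

    Map<String, double[]> rnnGetPreviousState() {
        return new HashMap<>(stateMap);
    }

    // The map set here is the starting state for the next rnnTimeStep() call.
    void rnnSetPreviousState(Map<String, double[]> state) {
        stateMap = new HashMap<>(state);
    }

    public static void main(String[] args) {
        RestoreStateDemo layer = new RestoreStateDemo();
        layer.rnnTimeStep(1.0);
        Map<String, double[]> saved = layer.rnnGetPreviousState(); // checkpoint after step 1

        double a = layer.rnnTimeStep(2.0);
        layer.rnnSetPreviousState(saved);  // rewind to the checkpoint
        double b = layer.rnnTimeStep(2.0); // same input from the same starting state

        System.out.println(a == b); // true
    }
}
```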
-
rnnClearPreviousState
public void rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and tBpttStateMap for rnnActivateUsingStoredState().
- Specified by:
rnnClearPreviousState in interface RecurrentLayer
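Clearing is typically done between unrelated sequences, so that the next sequence does not inherit activations from the previous one. In this hypothetical sketch (toy scalar recurrence, `double[]` for `INDArray`, an empty map treated as a zero state), the first step after a clear reproduces the output of a freshly constructed layer:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer, not DL4J code: double[] stands in for INDArray.
class ClearStateDemo {
    final Map<String, double[]> stateMap = new HashMap<>();
    final Map<String, double[]> tBpttStateMap = new HashMap<>();

    double rnnTimeStep(double x) {
        double hPrev = stateMap.getOrDefault("prevAct", new double[] {0.0})[0];
        double h = Math.tanh(0.5 * x + 0.9 * hPrev);
        stateMap.put("prevAct", new double[] {h});
        return h;
    }

    // Mirrors the documented behavior: both maps are emptied, so the next
    // forward pass starts from the default (in this toy: zero) state.
    void rnnClearPreviousState() {
        stateMap.clear();
        tBpttStateMap.clear();
    }

    public static void main(String[] args) {
        ClearStateDemo layer = new ClearStateDemo();
        double first = layer.rnnTimeStep(1.0);
        layer.rnnTimeStep(1.0);                                 // state now carries history
        layer.tBpttStateMap.put("prevAct", new double[] {0.3}); // pretend a TBPTT segment stored state

        layer.rnnClearPreviousState();                          // e.g. before an unrelated sequence
        System.out.println(layer.rnnTimeStep(1.0) == first);    // true
        System.out.println(layer.tBpttStateMap.isEmpty());      // true
    }
}
```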
-
rnnGetTBPTTState
public Map<String,INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. The TBPTT state is used to store intermediate activations/state between parameter updates during TBPTT learning.
- Specified by:
rnnGetTBPTTState in interface RecurrentLayer
- Returns:
- State for the RNN layer
-
rnnSetTBPTTState
public void rnnSetTBPTTState(Map<String,INDArray> state)
Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. The TBPTT state is used to store intermediate activations/state between parameter updates during TBPTT learning.
- Specified by:
rnnSetTBPTTState in interface RecurrentLayer
- Parameters:
state - TBPTT state to set
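In truncated BPTT, gradients stop at segment boundaries but the forward state is carried over, so a forward pass split into segments should reproduce a single pass over the whole series. The hypothetical sketch below (toy scalar recurrence, `double[]` for `INDArray`, illustrative `"prevAct"` key, and a made-up `activateSegment` helper) uses the getter and setter to hand the state from one layer object to another between segments:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer, not DL4J code: double[] stands in for INDArray and
// h_t = tanh(0.5 * x_t + 0.9 * h_{t-1}) stands in for a real RNN cell.
class TbpttHandoffDemo {
    private Map<String, double[]> tBpttStateMap = new HashMap<>();

    Map<String, double[]> rnnGetTBPTTState() {
        return tBpttStateMap;
    }

    void rnnSetTBPTTState(Map<String, double[]> state) {
        tBpttStateMap = new HashMap<>(state); // copy the map so two layers do not share it
    }

    // Forward pass over one segment: start from the stored TBPTT state and store the
    // final activation so the next segment continues from it.
    double[] activateSegment(double[] xs) {
        double h = tBpttStateMap.getOrDefault("prevAct", new double[] {0.0})[0];
        double[] out = new double[xs.length];
        for (int t = 0; t < xs.length; t++) {
            h = Math.tanh(0.5 * xs[t] + 0.9 * h);
            out[t] = h;
        }
        tBpttStateMap.put("prevAct", new double[] {h});
        return out;
    }

    public static void main(String[] args) {
        double[] series = {1.0, -1.0, 0.5, 2.0, -0.5, 1.5};

        // Reference: the whole series in a single pass.
        double[] whole = new TbpttHandoffDemo().activateSegment(series);

        // Segment 1 on layer a, then hand its TBPTT state to layer b for segment 2.
        TbpttHandoffDemo a = new TbpttHandoffDemo();
        a.activateSegment(Arrays.copyOfRange(series, 0, 3));
        TbpttHandoffDemo b = new TbpttHandoffDemo();
        b.rnnSetTBPTTState(a.rnnGetTBPTTState());
        double[] tail = b.activateSegment(Arrays.copyOfRange(series, 3, 6));

        // The forward activations match the unsegmented pass.
        System.out.println(Arrays.equals(tail, Arrays.copyOfRange(whole, 3, 6))); // true
    }
}
```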
-
getDataFormat
public RNNFormat getDataFormat()