Class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer>
- java.lang.Object
  - org.deeplearning4j.nn.layers.AbstractLayer<LayerConfT>
    - org.deeplearning4j.nn.layers.BaseLayer<LayerConfT>
      - org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer<LayerConfT>
- All Implemented Interfaces:
  Serializable, Cloneable, Layer, RecurrentLayer, Model, Trainable
- Direct Known Subclasses:
  GravesBidirectionalLSTM, GravesLSTM, LSTM, SimpleRnn
public abstract class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
- See Also:
- Serialized Form
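BaseRecurrentLayer is abstract and is not constructed directly in user code; it is instantiated through one of the subclasses listed above when a network containing a recurrent layer is initialized. The following is a minimal sketch of that path, assuming the standard DL4J configuration API; the class name RecurrentLayerExample and the layer sizes are illustrative, not part of this API.

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class RecurrentLayerExample {
        public static void main(String[] args) {
            // Illustrative sizes: 10 input features, 20 recurrent units, 3 output classes.
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .list()
                    // The LSTM configuration below is backed at runtime by
                    // org.deeplearning4j.nn.layers.recurrent.LSTM, a BaseRecurrentLayer subclass.
                    .layer(0, new LSTM.Builder().nIn(10).nOut(20).activation(Activation.TANH).build())
                    .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(20).nOut(3).build())
                    .build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();   // instantiates the recurrent layer, including its stateMap
        }
    }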
-
Nested Class Summary
-
Nested classes/interfaces inherited from interface org.deeplearning4j.nn.api.Layer
Layer.TrainingMode, Layer.Type
-
Field Summary
Fields
protected int helperCountFail
protected Map<String,INDArray> stateMap - stateMap stores the INDArrays needed to do the rnnTimeStep() forward pass.
protected Map<String,INDArray> tBpttStateMap - State map for use specifically in truncated BPTT training.
-
Fields inherited from class org.deeplearning4j.nn.layers.BaseLayer
gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver, weightNoiseParams
-
Fields inherited from class org.deeplearning4j.nn.layers.AbstractLayer
cacheMode, conf, dataType, dropoutApplied, epochCount, index, input, inputModificationAllowed, iterationCount, maskArray, maskState, preOutput, trainingListeners
-
Constructor Summary
Constructors
BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
-
Method Summary
RNNFormat getDataFormat()
protected INDArray permuteIfNWC(INDArray arr)
void rnnClearPreviousState() - Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
Map<String,INDArray> rnnGetPreviousState() - Returns a shallow copy of the stateMap.
Map<String,INDArray> rnnGetTBPTTState() - Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
void rnnSetPreviousState(Map<String,INDArray> stateMap) - Set the state map.
void rnnSetTBPTTState(Map<String,INDArray> state) - Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
-
Methods inherited from class org.deeplearning4j.nn.layers.BaseLayer
activate, backpropGradient, calcRegularizationScore, clear, clearNoiseWeightParams, clone, computeGradientAndScore, fit, fit, getGradientsViewArray, getOptimizer, getParam, getParamWithNoise, gradient, hasBias, hasLayerNorm, layerConf, numParams, params, paramTable, paramTable, preOutput, preOutputWithPreNorm, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, update, update
-
Methods inherited from class org.deeplearning4j.nn.layers.AbstractLayer
activate, addListeners, allowInputModification, applyConstraints, applyDropOutIfNecessary, applyMask, assertInputSet, backpropDropOutIfPresent, batchSize, close, conf, feedForwardMaskArray, getConfig, getEpochCount, getHelper, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, setCacheMode, setConf, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, updaterDivideByMinibatch
-
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.nn.api.Layer
activate, activate, allowInputModification, backpropGradient, calcRegularizationScore, clearNoiseWeightParams, feedForwardMaskArray, getEpochCount, getHelper, getIndex, getInputMiniBatchSize, getIterationCount, getListeners, getMaskArray, isPretrainLayer, setCacheMode, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setIterationCount, setListeners, setListeners, setMaskArray, type
-
Methods inherited from interface org.deeplearning4j.nn.api.Model
addListeners, applyConstraints, batchSize, clear, close, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, input, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update
-
Methods inherited from interface org.deeplearning4j.nn.api.layers.RecurrentLayer
rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient
-
Methods inherited from interface org.deeplearning4j.nn.api.Trainable
getConfig, getGradientsViewArray, numParams, params, paramTable, updaterDivideByMinibatch
-
Field Detail
-
stateMap
protected Map<String,INDArray> stateMap
stateMap stores the INDArrays needed to do rnnTimeStep() forward pass.
-
tBpttStateMap
protected Map<String,INDArray> tBpttStateMap
State map for use specifically in truncated BPTT training. Whereas stateMap contains the state from which the forward pass is initialized, the tBpttStateMap contains the state at the end of the last truncated BPTT forward pass.
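The stateMap is what makes single-step inference possible: each rnnTimeStep() call reads the stored state and writes the updated state back, so a long sequence can be fed one step at a time. Below is a minimal sketch of that usage via the MultiLayerNetwork-level wrapper, assuming the default NCW input format of [minibatch, size, timeSteps]; the class and method names are illustrative.

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class TimeStepExample {
        /** Feeds a sequence one step at a time; the recurrent layer's stateMap
         *  carries the hidden state between rnnTimeStep() calls. */
        static INDArray streamSequence(MultiLayerNetwork net, int nIn, int timeSteps) {
            INDArray lastOutput = null;
            for (int t = 0; t < timeSteps; t++) {
                INDArray step = Nd4j.rand(new int[]{1, nIn, 1});  // [minibatch, nIn, 1]
                lastOutput = net.rnnTimeStep(step);               // reads and updates stateMap
            }
            return lastOutput;
        }
    }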
-
helperCountFail
protected int helperCountFail
-
Constructor Detail
-
BaseRecurrentLayer
public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
-
Method Detail
-
rnnGetPreviousState
public Map<String,INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
- Specified by:
  rnnGetPreviousState in interface RecurrentLayer
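In application code this is usually reached via MultiLayerNetwork.rnnGetPreviousState(int layerIndex), which returns the state of the layer at that index. A small sketch for inspecting the stored state follows; the map keys depend on the concrete subclass (an LSTM, for example, stores its previous activations and memory cell state), so they are only printed here, not assumed. Class and method names are illustrative.

    import java.util.Arrays;
    import java.util.Map;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class InspectStateExample {
        /** Prints the rnnTimeStep() state of the recurrent layer at the given index. */
        static void printRnnState(MultiLayerNetwork net, int layerIndex) {
            Map<String, INDArray> state = net.rnnGetPreviousState(layerIndex); // shallow copy of stateMap
            for (Map.Entry<String, INDArray> e : state.entrySet()) {
                System.out.println(e.getKey() + " -> shape " + Arrays.toString(e.getValue().shape()));
            }
        }
    }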
-
rnnSetPreviousState
public void rnnSetPreviousState(Map<String,INDArray> stateMap)
Set the state map. Values set using this method will be used in the next call to rnnTimeStep().
- Specified by:
  rnnSetPreviousState in interface RecurrentLayer
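Combined with rnnGetPreviousState(), this allows time stepping to be resumed from a saved point. A sketch of that pattern, again through the network-level wrappers; note the saved map is a shallow copy, so this assumes the layer replaces its state arrays on each step rather than mutating them in place (the usual behaviour). Names are illustrative.

    import java.util.Map;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class SaveRestoreStateExample {
        static void speculateAndRewind(MultiLayerNetwork net, INDArray nextStep, int layerIndex) {
            // Save the current hidden state of the recurrent layer.
            Map<String, INDArray> saved = net.rnnGetPreviousState(layerIndex);

            // Run a speculative step; this updates the layer's stateMap.
            net.rnnTimeStep(nextStep);

            // Restore the saved state: the next rnnTimeStep() call continues
            // from the checkpoint rather than from the speculative step.
            net.rnnSetPreviousState(layerIndex, saved);
        }
    }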
-
rnnClearPreviousState
public void rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
- Specified by:
  rnnClearPreviousState in interface RecurrentLayer
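The state must be cleared whenever time stepping moves on to a new, unrelated sequence; otherwise the new sequence would start from the previous sequence's hidden state. A short sketch; names are illustrative.

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class ClearStateExample {
        /** Streams several independent sequences, resetting the RNN state between them. */
        static void streamSequences(MultiLayerNetwork net, INDArray[][] sequences) {
            for (INDArray[] sequence : sequences) {
                net.rnnClearPreviousState();         // forget the previous sequence entirely
                for (INDArray step : sequence) {     // each step: [minibatch, nIn, 1]
                    net.rnnTimeStep(step);
                }
            }
        }
    }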
-
rnnGetTBPTTState
public Map<String,INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. The TBPTT state is used to store intermediate activations/state between parameter updates when doing TBPTT learning.
- Specified by:
  rnnGetTBPTTState in interface RecurrentLayer
- Returns:
  State for the RNN layer
-
rnnSetTBPTTState
public void rnnSetTBPTTState(Map<String,INDArray> state)
Description copied from interface: RecurrentLayer
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer. The TBPTT state is used to store intermediate activations/state between parameter updates when doing TBPTT learning.
- Specified by:
  rnnSetTBPTTState in interface RecurrentLayer
- Parameters:
  state - TBPTT state to set
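rnnGetTBPTTState() and rnnSetTBPTTState() are driven internally by MultiLayerNetwork during truncated-BPTT fitting; user code normally just enables TBPTT in the network configuration and lets fit() manage the tBpttStateMap. A sketch of that configuration, assuming the standard builder options; the segment lengths and class name are illustrative.

    import org.deeplearning4j.nn.conf.BackpropType;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class TbpttConfigExample {
        static MultiLayerConfiguration tbpttConf() {
            return new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(0, new LSTM.Builder().nIn(10).nOut(20).activation(Activation.TANH).build())
                    .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(20).nOut(3).build())
                    .backpropType(BackpropType.TruncatedBPTT) // fit() will then use tBpttStateMap
                    .tBPTTForwardLength(50)                   // forward segment length (illustrative)
                    .tBPTTBackwardLength(50)                  // backward segment length (illustrative)
                    .build();
        }
    }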
-
getDataFormat
public RNNFormat getDataFormat()
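getDataFormat() reports whether the layer expects its 3d activations as [minibatch, size, timeSteps] (RNNFormat.NCW, the default) or [minibatch, timeSteps, size] (RNNFormat.NWC); permuteIfNWC(...) is the internal helper that permutes arrays when NWC is configured. Below is a sketch of selecting the format at configuration time, assuming a DL4J version whose recurrent layer builders expose dataFormat(...) (1.0.0-beta7 or later); sizes and class name are illustrative.

    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.layers.LSTM;

    public class DataFormatExample {
        static LSTM nwcLstm() {
            // Configure the layer for NWC input, i.e. [minibatch, timeSteps, size];
            // the runtime layer's getDataFormat() should then report RNNFormat.NWC.
            return new LSTM.Builder()
                    .nIn(10).nOut(20)
                    .dataFormat(RNNFormat.NWC)   // assumption: dataFormat(...) exists in this DL4J version
                    .build();
        }
    }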