public abstract class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Layer.TrainingMode, Layer.Type
Modifier and Type | Field and Description |
---|---|
protected int |
helperCountFail |
protected Map<String,INDArray> |
stateMap
stateMap stores the INDArrays needed for the rnnTimeStep() forward pass.
|
protected Map<String,INDArray> |
tBpttStateMap
State map for use specifically in truncated BPTT training.
|
gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver, weightNoiseParams
cacheMode, conf, dataType, dropoutApplied, epochCount, index, input, inputModificationAllowed, iterationCount, maskArray, maskState, preOutput, trainingListeners
Constructor and Description |
---|
BaseRecurrentLayer(NeuralNetConfiguration conf,
DataType dataType) |
Modifier and Type | Method and Description |
---|---|
RNNFormat |
getDataFormat() |
protected INDArray |
permuteIfNWC(INDArray arr) |
void |
rnnClearPreviousState()
Reset/clear the stateMap for rnnTimeStep() and the tBpttStateMap for rnnActivateUsingStoredState().
|
Map<String,INDArray> |
rnnGetPreviousState()
Returns a shallow copy of the stateMap
|
Map<String,INDArray> |
rnnGetTBPTTState()
Get the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
|
void |
rnnSetPreviousState(Map<String,INDArray> stateMap)
Set the state map.
|
void |
rnnSetTBPTTState(Map<String,INDArray> state)
Set the RNN truncated backpropagation through time (TBPTT) state for the recurrent layer.
|
activate, backpropGradient, calcRegularizationScore, clear, clearNoiseWeightParams, clone, computeGradientAndScore, fit, fit, getGradientsViewArray, getOptimizer, getParam, getParamWithNoise, gradient, hasBias, hasLayerNorm, layerConf, numParams, params, paramTable, paramTable, preOutput, preOutputWithPreNorm, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, update, update
activate, addListeners, allowInputModification, applyConstraints, applyDropOutIfNecessary, applyMask, assertInputSet, backpropDropOutIfPresent, batchSize, close, conf, feedForwardMaskArray, getConfig, getEpochCount, getHelper, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, setCacheMode, setConf, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, updaterDivideByMinibatch
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient
activate, activate, allowInputModification, backpropGradient, calcRegularizationScore, clearNoiseWeightParams, feedForwardMaskArray, getEpochCount, getHelper, getIndex, getInputMiniBatchSize, getIterationCount, getListeners, getMaskArray, isPretrainLayer, setCacheMode, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setIterationCount, setListeners, setListeners, setMaskArray, type
addListeners, applyConstraints, batchSize, clear, close, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, input, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update
getConfig, getGradientsViewArray, numParams, params, paramTable, updaterDivideByMinibatch
protected Map<String,INDArray> stateMap
protected Map<String,INDArray> tBpttStateMap
protected int helperCountFail
public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
public Map<String,INDArray> rnnGetPreviousState()
rnnGetPreviousState
in interface RecurrentLayer
public void rnnSetPreviousState(Map<String,INDArray> stateMap)
rnnSetPreviousState
in interface RecurrentLayer
public void rnnClearPreviousState()
rnnClearPreviousState
in interface RecurrentLayer
public Map<String,INDArray> rnnGetTBPTTState()
RecurrentLayer
rnnGetTBPTTState
in interface RecurrentLayer
public void rnnSetTBPTTState(Map<String,INDArray> state)
RecurrentLayer
rnnSetTBPTTState
in interface RecurrentLayer
state
- TBPTT state to set
public RNNFormat getDataFormat()
Copyright © 2021. All rights reserved.