public class UpdaterBlock extends Object
Used in BaseMultiLayerUpdater, this class implements updating (e.g., Adam, RMSProp, Momentum,
etc.) across multiple contiguous layers/parameters, as described in the BaseMultiLayerUpdater
javadoc.

Modifier and Type | Class and Description |
---|---|
static class | UpdaterBlock.ParamState |
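To illustrate what "contiguous layers/parameters" sharing one updater means, the sketch below walks parameters in flattened order and starts a new block whenever the updater configuration changes. It is a simplified illustration, not DL4J's code: `ParamInfo`, `Block` and `BlockGrouping` are hypothetical names, updater configurations are reduced to string labels, and block ranges are assumed to be half-open `[start, end)`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for one parameter variable (e.g. a layer's "W" or "b"):
// its name, its length in the flattened network parameter array, and a label
// for its updater configuration. Equal labels stand for identical configurations.
final class ParamInfo {
    final String name;
    final int length;
    final String updaterConfig;

    ParamInfo(String name, int length, String updaterConfig) {
        this.name = name;
        this.length = length;
        this.updaterConfig = updaterConfig;
    }
}

// Hypothetical block descriptor covering the half-open range [start, end)
// of the flattened parameter array.
final class Block {
    final int start;
    int end;
    final String updaterConfig;
    final List<String> paramNames = new ArrayList<>();

    Block(int start, String updaterConfig) {
        this.start = start;
        this.end = start;
        this.updaterConfig = updaterConfig;
    }
}

final class BlockGrouping {
    // Walks the parameters in flattened order and opens a new block whenever the
    // updater configuration differs from the previous parameter's, so every block
    // is contiguous and all of its members share one configuration.
    static List<Block> group(List<ParamInfo> params) {
        List<Block> blocks = new ArrayList<>();
        Block current = null;
        int offset = 0;
        for (ParamInfo p : params) {
            if (current == null || !Objects.equals(current.updaterConfig, p.updaterConfig)) {
                current = new Block(offset, p.updaterConfig);
                blocks.add(current);
            }
            offset += p.length;
            current.end = offset;
            current.paramNames.add(p.name);
        }
        return blocks;
    }
}
```

For parameters `0_W` (6, Adam), `0_b` (2, Adam), `1_W` (4, Sgd), `1_b` (2, Adam), this yields three blocks, `[0, 8)`, `[8, 12)` and `[12, 14)`: the last Adam parameter gets its own block because it is not adjacent to the first Adam block.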
Constructor and Description |
---|
UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart, int updaterViewOffsetEnd, List<UpdaterBlock.ParamState> layersAndVariablesInBlock) |
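The constructor takes two pairs of offsets: one into the overall parameter view array and one into the overall updater state view array. The sketch below shows one way such offsets could be derived from per-block parameter counts. It assumes end offsets are exclusive and that each updater keeps a fixed number of state values per parameter (assumed here: 2 for Adam, 1 for momentum or RMSProp, 0 for plain SGD); `BlockOffsets` and `OffsetCalculator` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical holder for the four offsets passed to the UpdaterBlock constructor.
final class BlockOffsets {
    final int paramOffsetStart;
    final int paramOffsetEnd;
    final int updaterViewOffsetStart;
    final int updaterViewOffsetEnd;

    BlockOffsets(int paramOffsetStart, int paramOffsetEnd,
                 int updaterViewOffsetStart, int updaterViewOffsetEnd) {
        this.paramOffsetStart = paramOffsetStart;
        this.paramOffsetEnd = paramOffsetEnd;
        this.updaterViewOffsetStart = updaterViewOffsetStart;
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
    }
}

final class OffsetCalculator {
    // paramCounts[i]   : number of parameters in block i
    // statePerParam[i] : updater state values per parameter in block i
    //                    (assumed: 2 for Adam, 1 for momentum or RMSProp, 0 for plain SGD)
    static List<BlockOffsets> compute(int[] paramCounts, int[] statePerParam) {
        if (paramCounts.length != statePerParam.length) {
            throw new IllegalArgumentException("arrays must have the same length");
        }
        List<BlockOffsets> result = new ArrayList<>();
        int paramOffset = 0;
        int stateOffset = 0;
        for (int i = 0; i < paramCounts.length; i++) {
            int nParams = paramCounts[i];
            int nState = nParams * statePerParam[i];
            // End offsets are treated as exclusive (half-open ranges); this is an assumption.
            result.add(new BlockOffsets(paramOffset, paramOffset + nParams,
                                        stateOffset, stateOffset + nState));
            paramOffset += nParams;
            stateOffset += nState;
        }
        return result;
    }
}
```

With block sizes `{8, 4, 2}` and state multipliers `{2, 0, 2}`, the parameter ranges are `[0, 8)`, `[8, 12)`, `[12, 14)` and the updater state ranges are `[0, 16)`, `[16, 16)` (empty, since SGD keeps no state) and `[16, 20)`.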
Modifier and Type | Method and Description |
---|---|
org.nd4j.linalg.learning.GradientUpdater | getGradientUpdater() |
void | init() |
boolean | isPretrainUpdaterBlock() |
void | postApply(Layer layer, String paramName, org.nd4j.linalg.api.ndarray.INDArray gradientView, org.nd4j.linalg.api.ndarray.INDArray paramsView) Apply L1 and L2 regularization, if necessary. |
boolean | skipDueToPretrainConfig() |
void | update(int iteration, int epoch) Update the gradient for this block. |
void | updateExternalGradient(int iteration, int epoch, org.nd4j.linalg.api.ndarray.INDArray fullNetworkGradientView, org.nd4j.linalg.api.ndarray.INDArray fullNetworkParamsArray) |
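postApply is summarized as applying L1 and L2 regularization "if necessary". As an illustration of the underlying arithmetic only, the sketch below adds the gradient of the penalty l1·Σ|w| + (l2/2)·Σw² to a slice of a flat gradient array. The exact formula and scaling DL4J uses may differ; `Regularization.applyL1L2` is a hypothetical helper, and the `(offset, length)` pair stands in for the `gradientView` and `paramsView` arrays.

```java
// Hypothetical helper: adds the gradient of the penalty
//   l1 * sum(|w|) + (l2 / 2) * sum(w^2)
// to a slice of a flat gradient array, in place. The (offset, length) pair
// stands in for the gradient/parameter "views" that postApply receives.
final class Regularization {
    static void applyL1L2(double[] gradient, double[] params,
                          int offset, int length, double l1, double l2) {
        if (l1 == 0.0 && l2 == 0.0) {
            return; // no regularization configured: nothing to do
        }
        for (int i = offset; i < offset + length; i++) {
            double w = params[i];
            // d/dw [l1*|w|] = l1*sign(w) (taken as 0 at w == 0); d/dw [(l2/2)*w^2] = l2*w
            gradient[i] += l1 * Math.signum(w) + l2 * w;
        }
    }
}
```

For example, with `l1 = 0.1` and `l2 = 0.01`, a parameter `w = 2.0` with gradient `0.5` ends up with gradient `0.5 + 0.1 + 0.02 = 0.62`.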
public UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart, int updaterViewOffsetEnd, List<UpdaterBlock.ParamState> layersAndVariablesInBlock)
paramOffsetStart
- Start offset of the parameters in this block (relative to the overall net params view array)
paramOffsetEnd
- End offset of the parameters in this block (relative to the overall net params view array)
updaterViewOffsetStart
- Start offset of the updater state array in this block (relative to the overall net updater state view array)
updaterViewOffsetEnd
- End offset of the updater state array in this block (relative to the overall net updater state view array)
layersAndVariablesInBlock
- List of layers and variables in this updater block. By definition, all layers and variables in this list must have an identical updater configuration.

public void init()
public boolean isPretrainUpdaterBlock()
public boolean skipDueToPretrainConfig()
public org.nd4j.linalg.learning.GradientUpdater getGradientUpdater()
public void update(int iteration, int epoch)
iteration
- The current iteration (i.e., total number of parameter updates so far)

public void updateExternalGradient(int iteration, int epoch, org.nd4j.linalg.api.ndarray.INDArray fullNetworkGradientView, org.nd4j.linalg.api.ndarray.INDArray fullNetworkParamsArray)
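update and updateExternalGradient both apply this block's updater to the block's own slice of full-network arrays. The sketch below shows that idea with classic momentum on plain `double[]` arrays. It assumes the convention that the updater turns the gradient into the update in place (the caller subtracts it from the parameters afterwards) and that momentum keeps one state value per parameter. `MomentumBlock` is a hypothetical class, not DL4J's API, and the fixed learning rate leaves `iteration` and `epoch` unused.

```java
// Hypothetical sketch of one block applying classic momentum to its own slice of
// full-network arrays. Assumes the gradient slot is overwritten with the update,
// which the caller later subtracts from the parameters.
final class MomentumBlock {
    final int paramStart;   // like paramOffsetStart (inclusive)
    final int paramEnd;     // like paramOffsetEnd (assumed exclusive)
    final int stateStart;   // like updaterViewOffsetStart; one state value per parameter
    final double learningRate;
    final double momentum;

    MomentumBlock(int paramStart, int paramEnd, int stateStart,
                  double learningRate, double momentum) {
        this.paramStart = paramStart;
        this.paramEnd = paramEnd;
        this.stateStart = stateStart;
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    // iteration and epoch are unused because the learning rate is fixed here;
    // a real updater could use them to drive a learning-rate schedule.
    void update(int iteration, int epoch, double[] fullGradient, double[] fullState) {
        for (int i = paramStart; i < paramEnd; i++) {
            int s = stateStart + (i - paramStart);
            // v = momentum * v + learningRate * g
            fullState[s] = momentum * fullState[s] + learningRate * fullGradient[i];
            fullGradient[i] = fullState[s]; // gradient slot now holds the update
        }
    }
}
```

Only indices `paramStart` to `paramEnd - 1` of the full gradient are touched, so several blocks with different updaters can share the same full-network arrays without interfering.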
public void postApply(Layer layer, String paramName, org.nd4j.linalg.api.ndarray.INDArray gradientView, org.nd4j.linalg.api.ndarray.INDArray paramsView)
layer
- The layer to apply L1/L2 to
paramName
- Parameter name in the given layer
gradientView
- Gradient view array for the layer and parameter
paramsView
- Parameter view array for the layer and parameter

Copyright © 2018. All rights reserved.