public abstract class BaseMultiLayerUpdater<T extends Model> extends Object implements Updater
This implements updater combining: for any layers (and variables) that
(a) have contiguous parameters/gradients in the view arrays, and
(b) have identical updater configuration (including updater, LR, LR/momentum schedules, etc.; different L1/L2 values are OK, however),
the updates are combined into a single GradientUpdater operation instead of a set of smaller operations. A smaller number of larger operations improves performance, especially on GPUs.
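The combining rule described above can be sketched in plain Java. This is only an illustration under stated assumptions: `ParamSpan`, `Block`, and `combine` are hypothetical names, not DL4J classes, and a single string key stands in for the full updater configuration (updater type, LR, schedules).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative sketch only: ParamSpan, Block, and combine are hypothetical
// names and are not part of the DL4J API.
final class UpdaterCombiningSketch {

    // One parameter variable: its [start, end) range in the flattened view
    // array, plus a key standing in for the full updater configuration.
    static final class ParamSpan {
        final String name;
        final int start;
        final int end; // exclusive
        final String configKey;

        ParamSpan(String name, int start, int end, String configKey) {
            this.name = name;
            this.start = start;
            this.end = end;
            this.configKey = configKey;
        }
    }

    // A merged range that a single updater operation would cover.
    static final class Block {
        final int start;
        int end; // exclusive; grows as spans are merged in
        final String configKey;
        final List<String> members = new ArrayList<>();

        Block(int start, int end, String configKey) {
            this.start = start;
            this.end = end;
            this.configKey = configKey;
        }
    }

    // Walk the spans in view-array order, extending the current block while
    // (a) the next span starts where the block ends and (b) the config matches.
    static List<Block> combine(List<ParamSpan> spans) {
        List<Block> blocks = new ArrayList<>();
        Block current = null;
        for (ParamSpan s : spans) {
            boolean mergeable = current != null
                    && current.end == s.start
                    && current.configKey.equals(s.configKey);
            if (mergeable) {
                current.end = s.end;
            } else {
                current = new Block(s.start, s.end, s.configKey);
                blocks.add(current);
            }
            current.members.add(s.name);
        }
        return blocks;
    }

    public static void main(String[] args) {
        List<Block> blocks = combine(Arrays.asList(
                new ParamSpan("layer0_W", 0, 20, "adam,lr=1e-3"),
                new ParamSpan("layer0_b", 20, 24, "adam,lr=1e-3"),
                new ParamSpan("layer1_W", 24, 44, "adam,lr=1e-3"),
                new ParamSpan("layer1_b", 44, 48, "sgd,lr=0.1")));
        for (Block b : blocks) {
            System.out.println("[" + b.start + ", " + b.end + ") "
                    + b.configKey + " " + b.members);
        }
    }
}
```

In this example the first three spans are contiguous and share a configuration, so they collapse into one block covering `[0, 44)`; the last span uses a different configuration and forms its own block. Four small updater operations become two.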
Modifier and Type | Field and Description |
---|---|
protected Map<String,Layer> | layersByName |
protected T | network |
protected List<UpdaterBlock> | updaterBlocks |
protected org.nd4j.linalg.api.ndarray.INDArray | updaterStateViewArray |
Constructor and Description |
---|
BaseMultiLayerUpdater(T network) |
BaseMultiLayerUpdater(T network, org.nd4j.linalg.api.ndarray.INDArray updaterState) |
Modifier and Type | Method and Description |
---|---|
boolean | equals(Object o) |
protected abstract org.nd4j.linalg.api.ndarray.INDArray | getFlattenedGradientsView() |
protected abstract Layer[] | getOrderedLayers() |
protected abstract org.nd4j.linalg.api.ndarray.INDArray | getParams() |
org.nd4j.linalg.api.ndarray.INDArray | getStateViewArray() |
int | hashCode() |
protected abstract boolean | isMiniBatch() |
protected boolean | isSingleLayerUpdater() |
void | preApply(Layer layer, Gradient gradient, int iteration)<br>Pre-apply: Apply gradient normalization/clipping |
void | setStateViewArray(org.nd4j.linalg.api.ndarray.INDArray viewArray)<br>Set the view array. |
void | setStateViewArray(Layer layer, org.nd4j.linalg.api.ndarray.INDArray viewArray, boolean initialize)<br>Set the internal (historical) state view array for this updater |
void | update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)<br>Update the gradient for the model. |
void | update(Layer layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)<br>Updater: updates the model |
protected final List<UpdaterBlock> updaterBlocks
protected org.nd4j.linalg.api.ndarray.INDArray updaterStateViewArray
public BaseMultiLayerUpdater(T network)
public BaseMultiLayerUpdater(T network, org.nd4j.linalg.api.ndarray.INDArray updaterState)
network
- Network to create the updater for
updaterState
- The updater state to use. Note: This array is used *directly* and isn't copied/cloned
protected abstract Layer[] getOrderedLayers()
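The constructor note that the updater state array is used directly, rather than copied, has a practical consequence: later changes to the updater's state are visible through the array the caller passed in. The sketch below shows the same sharing behaviour with a plain `double[]` standing in for the INDArray; `StateHolder` is a hypothetical class for illustration and not part of DL4J.

```java
import java.util.Arrays;

// Illustrative sketch only: StateHolder is a hypothetical stand-in, and a
// plain double[] stands in for the INDArray updater state.
final class SharedStateSketch {

    static final class StateHolder {
        private final double[] state;

        // Keeps the caller's array reference as-is; no defensive copy is made.
        StateHolder(double[] state) {
            this.state = state;
        }

        // Simulates an updater writing into its state during training.
        void accumulate(double value) {
            for (int i = 0; i < state.length; i++) {
                state[i] += value;
            }
        }
    }

    public static void main(String[] args) {
        double[] callerArray = new double[3];
        StateHolder holder = new StateHolder(callerArray);
        holder.accumulate(0.5);
        // The caller's array changed, because both sides share one buffer.
        System.out.println(Arrays.toString(callerArray));

        // Passing a copy instead keeps the caller's array isolated.
        double[] snapshot = callerArray.clone();
        new StateHolder(callerArray.clone()).accumulate(1.0);
        System.out.println(Arrays.equals(snapshot, callerArray));
    }
}
```

A caller that needs an independent copy of the state would presumably duplicate the array before passing it in (for an INDArray, via `dup()`).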
protected abstract org.nd4j.linalg.api.ndarray.INDArray getFlattenedGradientsView()
protected abstract org.nd4j.linalg.api.ndarray.INDArray getParams()
protected abstract boolean isMiniBatch()
public void setStateViewArray(org.nd4j.linalg.api.ndarray.INDArray viewArray)
viewArray
- The new updater state
public void setStateViewArray(Layer layer, org.nd4j.linalg.api.ndarray.INDArray viewArray, boolean initialize)
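Description copied from interface: Updater
Set the internal (historical) state view array for this updater
Specified by: setStateViewArray in interface Updater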
layer
- Layer that this updater belongs to
viewArray
- View array
initialize
- Whether to initialize the array or not
public org.nd4j.linalg.api.ndarray.INDArray getStateViewArray()
Specified by: getStateViewArray in interface Updater
public void update(Layer layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: Updater
Updater: updates the model
public void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
gradient
- Gradient to update
iteration
- The current iteration (i.e., number of parameter updates so far)
batchSize
- The current minibatch size (number of examples)
protected boolean isSingleLayerUpdater()
public void preApply(Layer layer, Gradient gradient, int iteration)
layer
- Layer to apply gradient normalization/clipping for
gradient
- Gradient to update
iteration
- The current iteration (i.e., number of parameter updates so far)

Copyright © 2018. All rights reserved.