public abstract class BaseMultiLayerUpdater<T extends Model> extends Object implements Updater
This class implements updater combining: any layers (and variables) that:
(a) have contiguous parameters/gradients in the view arrays, and
(b) have identical updater configurations (same updater type, learning rate, and LR/momentum schedules, etc.; different L1/L2 values are OK)
are combined into a single GradientUpdater operation, instead of a set of smaller operations. A smaller number of larger operations improves performance, especially on GPUs.
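For illustration, here is a minimal sketch using DL4J's standard builder API (layer sizes and hyperparameters are arbitrary, chosen only for the example): both dense layers inherit the same global Adam configuration and have contiguous parameters in the flattened view array, so they are eligible to be coalesced into a single UpdaterBlock.

```java
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class UpdaterCombiningSketch {
    public static void main(String[] args) {
        // All layers inherit the same Adam updater from the global config,
        // and their parameters are contiguous in the flattened view array,
        // so they can be combined into a single UpdaterBlock.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))
                .list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(20)
                        .activation(Activation.RELU).build())
                .layer(1, new DenseLayer.Builder().nIn(20).nOut(20)
                        .activation(Activation.RELU).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(20).nOut(1).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // net.getUpdater() returns a BaseMultiLayerUpdater subclass that
        // manages the combined updater blocks and their shared state view.
        System.out.println(net.getUpdater().getStateViewArray().length());
    }
}
```

A layer with a different updater configuration (for example, a per-layer learning rate override) would break the contiguous run and start a new block.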
Modifier and Type | Field and Description |
---|---|
protected List&lt;INDArray&gt; | gradientsForMinibatchDivision |
protected boolean | initializedMinibatchDivision |
protected Map&lt;String,Trainable&gt; | layersByName |
protected T | network |
protected List&lt;UpdaterBlock&gt; | updaterBlocks |
protected INDArray | updaterStateViewArray |
Constructor and Description |
---|
BaseMultiLayerUpdater(T network) |
BaseMultiLayerUpdater(T network, INDArray updaterState) |
Modifier and Type | Method and Description |
---|---|
protected void | divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize) |
boolean | equals(Object o) |
protected abstract INDArray | getFlattenedGradientsView() |
protected List&lt;INDArray&gt; | getMinibatchDivisionSubsets(INDArray from) |
protected abstract Trainable[] | getOrderedLayers() |
protected abstract INDArray | getParams() |
INDArray | getStateViewArray() |
INDArray | getStateViewArrayCopy(). A synchronized version of getStateViewArray() that duplicates the view array internally. |
int | hashCode() |
protected abstract boolean | isMiniBatch() |
protected boolean | isSingleLayerUpdater() |
void | preApply(Trainable layer, Gradient gradient, int iteration). Pre-apply: apply gradient normalization/clipping. |
void | setStateViewArray(INDArray viewArray). Set the view array. |
void | setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize). Set the internal (historical) state view array for this updater. |
void | update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr). Update the gradient for the model. |
void | update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr). Updater: updates the model. |
protected final List<UpdaterBlock> updaterBlocks
protected INDArray updaterStateViewArray
protected boolean initializedMinibatchDivision
public BaseMultiLayerUpdater(T network)

public BaseMultiLayerUpdater(T network, INDArray updaterState)
protected abstract Trainable[] getOrderedLayers()
protected abstract INDArray getFlattenedGradientsView()
protected abstract INDArray getParams()
protected abstract boolean isMiniBatch()
public void setStateViewArray(INDArray viewArray)
Set the view array.
Parameters:
viewArray - The new updater state

public void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize)
Set the internal (historical) state view array for this updater.
Specified by: setStateViewArray in interface Updater
Parameters:
layer - Layer that this updater belongs to
viewArray - View array
initialize - Whether to initialize the array or not
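As a hedged sketch of how this might be used to restore previously saved updater state (the helper below is hypothetical, and assumes savedState was copied from an identically configured network, so the view array layout matches):

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RestoreUpdaterStateSketch {
    // Hypothetical helper: restore saved updater state into a network.
    // Only valid when the state layout matches the network's parameters.
    static void restoreUpdaterState(MultiLayerNetwork net, INDArray savedState) {
        // initialize = false: use the saved values as-is rather than
        // re-initializing the state
        net.getUpdater().setStateViewArray(net, savedState, false);
    }
}
```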
public INDArray getStateViewArray()
Specified by: getStateViewArray in interface Updater
public INDArray getStateViewArrayCopy()
A synchronized version of getStateViewArray() that duplicates the view array internally. This should be used in preference to getStateViewArray() when the updater state is accessed in one thread while another thread is using the updater for training.
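For example, a checkpointing thread might snapshot the updater state while training runs elsewhere. A minimal sketch, assuming getUpdater() returns a BaseMultiLayerUpdater subclass (as it does for MultiLayerNetwork):

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;

public class UpdaterSnapshotSketch {
    // Take a thread-safe snapshot of the updater state (e.g. Adam's m/v
    // accumulators) while fit() may be running in another thread.
    static INDArray snapshotUpdaterState(MultiLayerNetwork net) {
        BaseMultiLayerUpdater<?> updater = (BaseMultiLayerUpdater<?>) net.getUpdater();
        // Returns a duplicate, not a view into the array that the
        // training thread is concurrently mutating
        return updater.getStateViewArrayCopy();
    }
}
```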
public void update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
Updater: updates the model.
Specified by: update in interface Updater
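Besides being called internally during fitting, this method can be invoked directly, as in DL4J's external-errors workflow, to apply the updater to gradients computed from an externally supplied error signal. A sketch along those lines (method and parameter names here are illustrative):

```java
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ExternalGradientUpdateSketch {
    // Backprop an external error signal, then use the network's updater
    // to turn the raw gradients into actual parameter updates.
    static void applyExternalError(MultiLayerNetwork net, INDArray input,
                                   INDArray externalError, int iteration, int epoch) {
        net.setInput(input);
        // Forward pass, keeping layer input activations so backprop can run
        net.feedForward(true, false);
        Gradient gradient = net.backpropGradient(externalError,
                LayerWorkspaceMgr.noWorkspaces()).getFirst();

        // Apply learning rate, momentum, etc. - modifies the Gradient in-place
        int batchSize = (int) input.size(0);
        net.getUpdater().update(net, gradient, iteration, epoch, batchSize,
                LayerWorkspaceMgr.noWorkspaces());

        // Subtract the processed gradient from the parameters
        net.params().subi(gradient.gradient());
    }
}
```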
public void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
Update the gradient for the model.
Parameters:
gradient - Gradient to update
iteration - The current iteration (i.e., number of parameter updates so far)
batchSize - The current minibatch size (number of examples)
protected void divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize)

protected boolean isSingleLayerUpdater()
public void preApply(Trainable layer, Gradient gradient, int iteration)
Pre-apply: apply gradient normalization/clipping.
Parameters:
layer - Layer to apply gradient normalization/clipping for
gradient - Gradient to update
iteration - The current iteration (i.e., number of parameter updates so far)
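Gradient normalization/clipping is configured per layer; preApply(...) then enforces that configuration on the gradient before the updater itself runs. A configuration sketch (values are arbitrary):

```java
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;

public class GradientClippingConfigSketch {
    public static void main(String[] args) {
        // Element-wise clipping to [-1, 1]; preApply(layer, gradient, iteration)
        // applies this to the gradient before the updater (SGD, Adam, ...) runs.
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(10).nOut(10)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .build();
        System.out.println(layer);
    }
}
```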