public abstract class BaseMultiLayerUpdater<T extends Model> extends Object implements Updater
Modifier and Type | Field and Description |
---|---|
protected List<INDArray> | gradientsForMinibatchDivision |
protected boolean | initializedMinibatchDivision |
protected Map<String,Trainable> | layersByName |
protected T | network |
protected List<UpdaterBlock> | updaterBlocks |
protected INDArray | updaterStateViewArray |
Constructor and Description |
---|
BaseMultiLayerUpdater(T network) |
BaseMultiLayerUpdater(T network, INDArray updaterState) |
Modifier and Type | Method and Description |
---|---|
protected void | divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize) |
boolean | equals(Object o) |
protected abstract INDArray | getFlattenedGradientsView() |
protected List<INDArray> | getMinibatchDivisionSubsets(INDArray from) |
protected abstract Trainable[] | getOrderedLayers() |
protected abstract INDArray | getParams() |
INDArray | getStateViewArray() |
INDArray | getStateViewArrayCopy()<br>A synchronized version of getStateViewArray() that duplicates the view array internally. |
int | hashCode() |
protected abstract boolean | isMiniBatch() |
protected boolean | isSingleLayerUpdater() |
void | preApply(Trainable layer, Gradient gradient, int iteration)<br>Pre-apply: Apply gradient normalization/clipping |
void | setStateViewArray(INDArray viewArray)<br>Set the view array. |
void | setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize)<br>Set the internal (historical) state view array for this updater |
void | update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)<br>Update the gradient for the model. |
void | update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)<br>Updater: updates the model |
protected final List<UpdaterBlock> updaterBlocks
protected INDArray updaterStateViewArray
protected boolean initializedMinibatchDivision
public BaseMultiLayerUpdater(T network)
protected abstract Trainable[] getOrderedLayers()
protected abstract INDArray getFlattenedGradientsView()
protected abstract INDArray getParams()
protected abstract boolean isMiniBatch()
public void setStateViewArray(INDArray viewArray)
Set the view array.
Parameters:
viewArray - The new updater state

public void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize)
Description copied from interface: Updater
Set the internal (historical) state view array for this updater
Specified by: setStateViewArray in interface Updater
Parameters:
layer - Layer that this updater belongs to
viewArray - View array
initialize - Whether to initialize the array or not

public INDArray getStateViewArray()
Specified by: getStateViewArray in interface Updater

public INDArray getStateViewArrayCopy()
A synchronized version of getStateViewArray() that duplicates the view array internally. Use this in preference to getStateViewArray() when the updater state is accessed in one thread while another thread is using the updater for training.

public void update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: Updater
Updater: updates the model

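
The difference between getStateViewArray() and getStateViewArrayCopy() described above can be illustrated with a minimal sketch. It is not the DL4J implementation: StateHolderSketch and its methods are hypothetical stand-ins, a plain double[] replaces INDArray, and it assumes that the training step and the copy share one lock.

```java
import java.util.Arrays;

// Hypothetical stand-in for an updater; NOT the DL4J class.
final class StateHolderSketch {
    // Stands in for updaterStateViewArray (a double[] replaces INDArray here).
    private final double[] stateView;

    StateHolderSketch(int length) {
        this.stateView = new double[length];
    }

    // Analogous to getStateViewArray(): returns the live array without copying.
    double[] getStateView() {
        return stateView;
    }

    // Analogous to getStateViewArrayCopy(): duplicates the state while holding
    // the same lock as trainingStep(), so the caller gets a consistent snapshot.
    synchronized double[] getStateViewCopy() {
        return Arrays.copyOf(stateView, stateView.length);
    }

    // Assumed training update that mutates the state in place under the lock.
    synchronized void trainingStep() {
        for (int i = 0; i < stateView.length; i++) {
            stateView[i] += 1.0;
        }
    }
}
```

A snapshot taken with getStateViewCopy() stays fixed while later calls to trainingStep() keep changing the array returned by getStateView().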
public void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr)
Update the gradient for the model.
Parameters:
gradient - Gradient to update
iteration - The current iteration (i.e., number of parameter updates so far)
batchSize - The current minibatch size (number of examples)

protected void divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize)
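
This page does not document what divideByMinibatch does. Its name and the batchSize parameter suggest that gradients are scaled by 1 / batchSize, turning a sum over examples into an average. The sketch below shows only that assumed arithmetic: MinibatchDivisionSketch is a hypothetical helper that works on a plain double[] rather than a Gradient, and it ignores the isExternal flag.

```java
// Hypothetical illustration; NOT the DL4J implementation.
final class MinibatchDivisionSketch {
    // Divides every gradient value in place by the minibatch size.
    static void divideByMinibatch(double[] flattenedGradients, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        for (int i = 0; i < flattenedGradients.length; i++) {
            flattenedGradients[i] /= batchSize;
        }
    }
}
```

For example, gradients {2.0, 4.0, -6.0} with a batch size of 2 become {1.0, 2.0, -3.0}.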
protected boolean isSingleLayerUpdater()

public void preApply(Trainable layer, Gradient gradient, int iteration)
Pre-apply: Apply gradient normalization/clipping
Parameters:
layer - Layer to apply gradient normalization/clipping for
gradient - Gradient to update
iteration - The current iteration (i.e., number of parameter updates so far)
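
preApply is described only as applying gradient normalization/clipping. As an illustration of one common clipping strategy, the sketch below rescales a gradient whose L2 norm exceeds a threshold. Which strategies preApply actually supports is not stated here, so treat this as an assumption: GradientClippingSketch and clipByL2Norm are hypothetical names, and a double[] stands in for the layer's gradient.

```java
// Hypothetical illustration of L2-norm clipping; NOT the DL4J implementation.
final class GradientClippingSketch {
    // Rescales the gradient in place so that its L2 norm is at most threshold.
    // Gradients whose norm is already within the threshold are left untouched.
    static void clipByL2Norm(double[] gradient, double threshold) {
        double sumOfSquares = 0.0;
        for (double g : gradient) {
            sumOfSquares += g * g;
        }
        double norm = Math.sqrt(sumOfSquares);
        if (norm > threshold) {
            double scale = threshold / norm;
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] *= scale;
            }
        }
    }
}
```

With a threshold of 1.0, the gradient {3.0, 4.0} (norm 5) is scaled to approximately {0.6, 0.8}, while {0.3, 0.4} (norm 0.5) is unchanged.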