public class BasicGradientsAccumulator extends Object implements GradientsAccumulator
| Modifier and Type | Field and Description |
|---|---|
| protected CyclicBarrier | barrier |
| protected List<INDArray> | candidates |
| protected AtomicLong | extCounter |
| protected AtomicLong | firstOne |
| protected Queue<INDArray> | gradients |
| protected MessageHandler | handler |
| protected AtomicBoolean | hasSomething |
| protected char | ordering |
| protected AtomicLong | ownCounter |
| protected int | parties |
| protected long[] | shape |
| protected INDArray | storage |
| protected INDArray | updates |
| protected ReentrantReadWriteLock | updatesLock |
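The fields suggest a design in which parties workers share one buffer (storage/updates), guarded by updatesLock and synchronized by a CyclicBarrier. The sketch below is a minimal, hypothetical illustration of that pattern, not this class's actual implementation: BarrierAccumulatorSketch, storeAndSync and runDemo are invented names, and double[] stands in for INDArray so the example runs without ND4J.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Hypothetical sketch: each of `parties` workers adds its update to a shared buffer,
 * then waits on a CyclicBarrier until every party has contributed, so all workers
 * read the same aggregated total. double[] stands in for INDArray.
 * Single round only: clearing the buffer between rounds is omitted.
 */
class BarrierAccumulatorSketch {
    private final double[] storage;
    private final CyclicBarrier barrier;
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();

    BarrierAccumulatorSketch(int parties, int length) {
        this.storage = new double[length];
        this.barrier = new CyclicBarrier(parties);
    }

    /** Adds this worker's update, blocks until all parties have done the same, returns the total. */
    double[] storeAndSync(double[] update) throws InterruptedException, BrokenBarrierException {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < storage.length; i++) {
                storage[i] += update[i];
            }
        } finally {
            updatesLock.writeLock().unlock();
        }
        barrier.await(); // wait for every party's contribution
        updatesLock.readLock().lock();
        try {
            return storage.clone(); // every worker receives a copy of the same sum
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    /** Starts `parties` threads that each store the same update; returns what each thread saw. */
    static double[][] runDemo(int parties, double[] update) {
        BarrierAccumulatorSketch acc = new BarrierAccumulatorSketch(parties, update.length);
        double[][] results = new double[parties][];
        Thread[] workers = new Thread[parties];
        for (int t = 0; t < parties; t++) {
            final int id = t;
            workers[t] = new Thread(() -> {
                try {
                    results[id] = acc.storeAndSync(update);
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) {
            try {
                w.join(); // join() makes each worker's write to results visible here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return results;
    }
}
```

With three workers each storing {1.0, 2.0}, every worker returns {3.0, 6.0}: the barrier guarantees no one reads before all contributions are in.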
| Constructor | Description |
|---|---|
| BasicGradientsAccumulator(int parties) | Creates a new GradientsAccumulator with a starting threshold of 1e-3. |
| BasicGradientsAccumulator(int parties, MessageHandler handler) | Creates a new GradientsAccumulator with a custom starting threshold, using the given MessageHandler. |
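The summary does not say what the starting threshold controls. Assuming it is the cutoff used in threshold-based gradient quantization, where an entry is shared (as ±threshold) only once its accumulated magnitude reaches the threshold and the remainder is kept locally as a residual, a minimal sketch looks like this. ThresholdEncodingSketch and encode are hypothetical names, and double[] stands in for INDArray.

```java
/**
 * Hypothetical sketch of threshold-based update encoding.
 * Assumption: this is what the accumulator's "starting threshold" refers to;
 * the class documentation does not state it.
 */
final class ThresholdEncodingSketch {
    private ThresholdEncodingSketch() {}

    /**
     * Returns the quantized update to share (each entry is 0 or +/-threshold)
     * and subtracts it from residual in place, so small values keep
     * accumulating locally until they cross the threshold.
     */
    static double[] encode(double[] residual, double threshold) {
        double[] shared = new double[residual.length];
        for (int i = 0; i < residual.length; i++) {
            if (residual[i] >= threshold) {
                shared[i] = threshold;
            } else if (residual[i] <= -threshold) {
                shared[i] = -threshold;
            }
            residual[i] -= shared[i]; // keep whatever was not transmitted
        }
        return shared;
    }
}
```

With the default of 1e-3, a residual entry of 0.0025 emits 0.001 and keeps about 0.0015 for later steps, while an entry of 0.0004 stays entirely local until more accumulates.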
| Modifier and Type | Method | Description |
|---|---|---|
| void | applyUpdate(StepFunction function, INDArray params, INDArray grad) | This method applies accumulated updates via the given StepFunction. |
| void | applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) | This method applies accumulated updates via the given StepFunction. |
| void | receiveUpdate(INDArray array) | This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop. PLEASE NOTE: the array is expected to be ready for use and to match the dimensionality of params. |
| void | reset() | This method resets all accumulated updates (if any). |
| void | setExternalSource(Queue<INDArray> source) | This method allows passing external updates to the accumulator; they will be propagated to all workers using this GradientsAccumulator instance. |
| void | storeUpdate(INDArray array) | This method accepts updates suitable for StepFunction, and accumulates and propagates them across all workers. |
| void | touch() | This method initializes the given worker with respect to thread-device affinity. |
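receiveUpdate and applyUpdate describe a producer/consumer pair: updates are queued (the gradients field is a Queue<INDArray>) and later drained and applied through a StepFunction. The sketch below is a hypothetical, ND4J-free illustration of that flow; QueueApplySketch, Step and DESCENT are invented names, double[] stands in for INDArray, and the plain descent rule params -= alpha * update is an assumption rather than StepFunction's documented behavior.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

class QueueApplySketch {
    /** Hypothetical stand-in for StepFunction: applies one update to params. */
    interface Step {
        void step(double[] params, double[] update, double alpha);
    }

    /** Assumed rule for this sketch: params -= alpha * update. */
    static final Step DESCENT = (params, update, alpha) -> {
        for (int i = 0; i < params.length; i++) {
            params[i] -= alpha * update[i];
        }
    };

    // Thread-safe queue so updates can arrive from other workers.
    private final Queue<double[]> gradients = new ConcurrentLinkedQueue<>();

    /** Mirrors receiveUpdate: the array must already match the shape of params. */
    void receiveUpdate(double[] update) {
        gradients.offer(update);
    }

    /** Mirrors applyUpdate(function, params, grad, alpha): drains and applies every queued update. */
    int applyUpdate(Step function, double[] params, double alpha) {
        int applied = 0;
        double[] update;
        while ((update = gradients.poll()) != null) {
            function.step(params, update, alpha);
            applied++;
        }
        return applied;
    }
}
```

Queuing two updates {1.0, 2.0} and {3.0, 4.0} and applying them to params {10.0, 10.0} with alpha 0.5 leaves params at {8.0, 7.0}; a second call finds the queue empty and applies nothing.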
protected MessageHandler handler
protected transient INDArray storage
protected transient INDArray updates
protected transient AtomicLong ownCounter
protected transient AtomicLong extCounter
protected long[] shape
protected char ordering
protected int parties
protected CyclicBarrier barrier
protected AtomicLong firstOne
protected ReentrantReadWriteLock updatesLock
protected AtomicBoolean hasSomething
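Fields such as updatesLock and hasSomething hint at how reset() and applying updates might coordinate. The hypothetical sketch below shows one common way to combine them, as an assumption about intent rather than this class's source: writers add and reset under the write lock, applying takes the read lock, and an AtomicBoolean lets callers skip locking when nothing is pending. ResetSketch, add and applyTo are invented names, and double[] replaces INDArray.

```java
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;

class ResetSketch {
    private final double[] updates;
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();
    private final AtomicBoolean hasSomething = new AtomicBoolean(false);

    ResetSketch(int length) {
        this.updates = new double[length];
    }

    /** Accumulates a delta under the write lock and marks that updates are pending. */
    void add(double[] delta) {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < updates.length; i++) {
                updates[i] += delta[i];
            }
            hasSomething.set(true);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    /** Adds pending updates to params; returns false without locking if nothing is pending. */
    boolean applyTo(double[] params) {
        if (!hasSomething.get()) {
            return false; // cheap fast path, checked outside the lock
        }
        updatesLock.readLock().lock();
        try {
            for (int i = 0; i < params.length; i++) {
                params[i] += updates[i];
            }
            return true;
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    /** Mirrors reset(): discards all accumulated updates. */
    void reset() {
        updatesLock.writeLock().lock();
        try {
            Arrays.fill(updates, 0.0);
            hasSomething.set(false);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }
}
```

The unlocked hasSomething check is only an optimization: a caller may still see true just before a concurrent reset(), in which case it applies zeros under the read lock, so consistency comes from the lock, not the flag.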
public BasicGradientsAccumulator(int parties)
public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
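Parameters: handler - MessageHandler instance that will be used for communication purposes

public void applyUpdate(StepFunction function, INDArray params, INDArray grad)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function - StepFunction used to apply the accumulated updates; params - parameters to which the updates are applied

public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function - StepFunction used to apply the accumulated updates; params - parameters to which the updates are applied

public void storeUpdate(INDArray array)
Specified by: storeUpdate in interface GradientsAccumulator
Parameters: array - update to accumulate and propagate across all workers

public void receiveUpdate(INDArray array)
Specified by: receiveUpdate in interface GradientsAccumulator
Parameters: array - update to enqueue; it must already match the dimensionality of params

public void reset()
Specified by: reset in interface GradientsAccumulator

public void touch()
Specified by: touch in interface GradientsAccumulator

public void setExternalSource(Queue<INDArray> source)
Specified by: setExternalSource in interface GradientsAccumulator

Copyright © 2018. All rights reserved.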