public class BasicGradientsAccumulator extends Object implements GradientsAccumulator
Modifier and Type | Field and Description |
---|---|
protected CyclicBarrier | barrier |
protected List<INDArray> | candidates |
protected AtomicLong | extCounter |
protected AtomicLong | firstOne |
protected IndexedTail | gradients |
protected MessageHandler | handler |
protected AtomicBoolean | hasSomething |
protected char | ordering |
protected AtomicLong | ownCounter |
protected int | parties |
protected long[] | shape |
protected INDArray | storage |
protected INDArray | updates |
protected ReentrantReadWriteLock | updatesLock |
Constructor | Description |
---|---|
BasicGradientsAccumulator(int parties) | Creates a new GradientsAccumulator with a starting threshold of 1e-3 |
BasicGradientsAccumulator(int parties, MessageHandler handler) | Creates a new GradientsAccumulator with a custom starting threshold |
Modifier and Type | Method | Description |
---|---|---|
void | applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep) | Applies accumulated updates via the given StepFunction |
void | applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) | Applies accumulated updates via the given StepFunction |
IndexedTail | getExternalSource() | |
boolean | hasAnything() | Checks whether any (probably external) updates are available |
void | markExternalUpdates(boolean updatesAvailable) | Flags early availability of updates |
void | receiveUpdate(INDArray array) | Accepts updates suitable for StepFunction and puts them in the queue used by the backpropagation loop. PLEASE NOTE: array is expected to be ready for use and to match the dimensionality of params |
void | reset() | Resets all accumulated updates (if any) |
void | setExternalSource(IndexedTail source) | Passes external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance |
void | storeUpdate(INDArray array, int iterationNumber, int epochNumber) | Accepts updates suitable for StepFunction and accumulates/propagates them across all workers |
void | touch() | Initializes the given worker with respect to Thread-Device Affinity |
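The store-then-apply cycle described by storeUpdate and applyUpdate can be illustrated with a small standard-library sketch. This is not the DL4J implementation: ToyAccumulator, its double[] buffers, and the plain subtraction used in place of a StepFunction are all assumptions made for illustration. It only shows how a parties count, a CyclicBarrier, and a ReentrantReadWriteLock (field types listed above) might cooperate so that every worker sees the same accumulated update.

```java
import java.util.Arrays;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical stand-in for illustration only; NOT the DL4J class.
class ToyAccumulator {
    private final CyclicBarrier barrier;
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();
    private final double[] storage;

    ToyAccumulator(int parties, int length) {
        this.barrier = new CyclicBarrier(parties);
        this.storage = new double[length];
    }

    // Each worker adds its update to the shared buffer, then waits for all parties.
    void storeUpdate(double[] update) {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < storage.length; i++) {
                storage[i] += update[i];
            }
        } finally {
            updatesLock.writeLock().unlock();
        }
        try {
            barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (BrokenBarrierException e) {
            throw new IllegalStateException(e);
        }
    }

    // Assumed step rule: params -= alpha * accumulated update.
    void applyUpdate(double[] params, double alpha) {
        updatesLock.readLock().lock();
        try {
            for (int i = 0; i < params.length; i++) {
                params[i] -= alpha * storage[i];
            }
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    // Clears everything accumulated so far.
    void reset() {
        updatesLock.writeLock().lock();
        try {
            Arrays.fill(storage, 0.0);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    // Two workers each store one update, then apply the shared sum to their own params.
    static double[][] demo() {
        ToyAccumulator acc = new ToyAccumulator(2, 2);
        double[][] params = {{1.0, 1.0}, {1.0, 1.0}};
        Thread[] workers = new Thread[2];
        for (int w = 0; w < workers.length; w++) {
            final int id = w;
            workers[w] = new Thread(() -> {
                acc.storeUpdate(new double[] {0.25 * (id + 1), 0.5});
                acc.applyUpdate(params[id], 1.0);
            });
            workers[w].start();
        }
        try {
            for (Thread t : workers) {
                t.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return params;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.deepToString(demo()));
    }
}
```

In this sketch the barrier guarantees that no worker applies the buffer before every party has contributed to it, which is why both parameter copies end up identical.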
protected MessageHandler handler
protected transient IndexedTail gradients
protected transient INDArray storage
protected transient INDArray updates
protected transient AtomicLong ownCounter
protected transient AtomicLong extCounter
protected long[] shape
protected char ordering
protected int parties
protected CyclicBarrier barrier
protected AtomicLong firstOne
protected ReentrantReadWriteLock updatesLock
protected AtomicBoolean hasSomething
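Several of the fields above are declared transient. Assuming the accumulator is ever passed through standard Java serialization (an assumption; this page does not say so), such fields are skipped and come back as null, so they would need to be re-created before use. The sketch below demonstrates that language behavior with a hypothetical TransientDemo class; the lazy counter() accessor is one possible way to restore state, not something taken from the library.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical class for illustration only; NOT part of DL4J.
class TransientDemo implements Serializable {
    private static final long serialVersionUID = 1L;

    protected int parties;                                          // serialized
    protected transient AtomicLong ownCounter = new AtomicLong(0);  // skipped by serialization

    TransientDemo(int parties) {
        this.parties = parties;
    }

    // Lazily re-creates the transient counter if it was lost during deserialization.
    AtomicLong counter() {
        if (ownCounter == null) {
            ownCounter = new AtomicLong(0);
        }
        return ownCounter;
    }

    // Serializes the object to a byte array and reads it back.
    static TransientDemo roundTrip(TransientDemo in) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(in);
            }
            try (ObjectInputStream ois =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (TransientDemo) ois.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        TransientDemo original = new TransientDemo(4);
        original.counter().incrementAndGet();
        TransientDemo copy = roundTrip(original);
        System.out.println("parties=" + copy.parties + ", counter is null: " + (copy.ownCounter == null));
    }
}
```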
public BasicGradientsAccumulator(int parties)
public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
Parameters: handler - MessageHandler instance that will be used for communication purposes
public IndexedTail getExternalSource()
Specified by: getExternalSource in interface GradientsAccumulator
public void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function, params
public void markExternalUpdates(boolean updatesAvailable)
Description copied from interface: GradientsAccumulator
Flags early availability of updates
Specified by: markExternalUpdates in interface GradientsAccumulator
public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function, params
public void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
Specified by: storeUpdate in interface GradientsAccumulator
Parameters: array
public void receiveUpdate(INDArray array)
Specified by: receiveUpdate in interface GradientsAccumulator
Parameters: array
public void reset()
Specified by: reset in interface GradientsAccumulator
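public void touch()
Specified by: touch in interface GradientsAccumulator
public void setExternalSource(IndexedTail source)
Description copied from interface: GradientsAccumulator
Passes external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance
Specified by: setExternalSource in interface GradientsAccumulator
public boolean hasAnything()
Description copied from interface: GradientsAccumulator
Checks whether any (probably external) updates are available
Specified by: hasAnything in interface GradientsAccumulator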
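The interplay of receiveUpdate, markExternalUpdates, hasAnything, and reset can be pictured as a queue guarded by a flag. The following is a minimal sketch under stated assumptions: ToyUpdateQueue, applyQueued, the double[] payloads, and the subtraction rule are hypothetical and do not reflect how the library actually behaves internally. It only shows one plausible way an AtomicBoolean such as hasSomething could let a training loop check cheaply whether queued updates are waiting.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for illustration only; NOT the DL4J class.
class ToyUpdateQueue {
    private final Queue<double[]> candidates = new ConcurrentLinkedQueue<>();
    private final AtomicBoolean hasSomething = new AtomicBoolean(false);

    // Queues a ready-to-apply update; it must match the params length.
    void receiveUpdate(double[] update) {
        candidates.add(update.clone());
        hasSomething.set(true);
    }

    // Lets a caller raise or clear the flag before updates are actually queued.
    void markExternalUpdates(boolean updatesAvailable) {
        hasSomething.set(updatesAvailable);
    }

    // Cheap check used by the consumer loop.
    boolean hasAnything() {
        return hasSomething.get();
    }

    // Assumed step rule: subtract every queued update from params; returns how many were applied.
    int applyQueued(double[] params) {
        int applied = 0;
        double[] update;
        while ((update = candidates.poll()) != null) {
            for (int i = 0; i < params.length; i++) {
                params[i] -= update[i];
            }
            applied++;
        }
        hasSomething.set(false);
        return applied;
    }

    // Drops all pending updates and clears the flag.
    void reset() {
        candidates.clear();
        hasSomething.set(false);
    }

    public static void main(String[] args) {
        ToyUpdateQueue queue = new ToyUpdateQueue();
        queue.receiveUpdate(new double[] {0.5, 0.25});
        double[] params = {1.0, 1.0};
        if (queue.hasAnything()) {
            System.out.println("applied " + queue.applyQueued(params) + " update(s)");
        }
    }
}
```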
Copyright © 2019. All rights reserved.