public class BasicGradientsAccumulator extends Object implements GradientsAccumulator
Modifier and Type | Field and Description |
---|---|
protected CyclicBarrier | barrier |
protected List<org.nd4j.linalg.api.ndarray.INDArray> | candidates |
protected AtomicLong | extCounter |
protected AtomicLong | firstOne |
protected Queue<org.nd4j.linalg.api.ndarray.INDArray> | gradients |
protected MessageHandler | handler |
protected AtomicBoolean | hasSomething |
protected char | ordering |
protected AtomicLong | ownCounter |
protected int | parties |
protected int[] | shape |
protected org.nd4j.linalg.api.ndarray.INDArray | storage |
protected org.nd4j.linalg.api.ndarray.INDArray | updates |
protected ReentrantReadWriteLock | updatesLock |
Constructor | Description |
---|---|
BasicGradientsAccumulator(int parties) | Creates a new GradientsAccumulator with a starting threshold of 1e-3 |
BasicGradientsAccumulator(int parties, MessageHandler handler) | Creates a new GradientsAccumulator that uses the given MessageHandler for communication |
Modifier and Type | Method | Description |
---|---|---|
void | applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray grad) | This method applies accumulated updates via the given StepFunction. |
void | applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray grad, double alpha) | This method applies accumulated updates via the given StepFunction. |
void | receiveUpdate(org.nd4j.linalg.api.ndarray.INDArray array) | This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop. PLEASE NOTE: array is expected to be ready for use and to match params dimensionality. |
void | reset() | This method resets all accumulated updates (if any). |
void | setExternalSource(Queue<org.nd4j.linalg.api.ndarray.INDArray> source) | This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance. |
void | storeUpdate(org.nd4j.linalg.api.ndarray.INDArray array) | This method accepts updates suitable for StepFunction and accumulates/propagates them across all workers. |
void | touch() | This method initializes the current worker with respect to thread-device affinity. |
protected MessageHandler handler
protected transient Queue<org.nd4j.linalg.api.ndarray.INDArray> gradients
protected transient org.nd4j.linalg.api.ndarray.INDArray storage
protected transient org.nd4j.linalg.api.ndarray.INDArray updates
protected transient AtomicLong ownCounter
protected transient AtomicLong extCounter
protected int[] shape
protected char ordering
protected int parties
protected CyclicBarrier barrier
protected AtomicLong firstOne
protected List<org.nd4j.linalg.api.ndarray.INDArray> candidates
protected ReentrantReadWriteLock updatesLock
protected AtomicBoolean hasSomething
public BasicGradientsAccumulator(int parties)
public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
Parameters:
handler - MessageHandler instance that will be used for communication purposes
public void applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray grad)
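Specified by:
applyUpdate in interface GradientsAccumulator
Parameters:
function - StepFunction via which the accumulated updates are applied
params - parameters to which the accumulated updates are applied
public void applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray grad, double alpha)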
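Specified by:
applyUpdate in interface GradientsAccumulator
Parameters:
function - StepFunction via which the accumulated updates are applied
params - parameters to which the accumulated updates are applied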
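Both applyUpdate overloads apply accumulated updates via a StepFunction. The plain-Java sketch below illustrates that idea under explicit assumptions: double[] stands in for INDArray, StepFn is a hypothetical stand-in for StepFunction, the "subtract alpha times update" rule and the alpha of 1.0 used by the shorter overload are illustrative choices, and the grad argument is left out because its role is not documented here. It is not the DL4J implementation.

```java
import java.util.Arrays;

/** Hypothetical stand-in for StepFunction: applies an update to params in place. */
interface StepFn {
    void step(double[] params, double[] update, double alpha);
}

/** Single-threaded sketch of applyUpdate; names and behavior are assumptions. */
class ApplyUpdateSketch {
    private final double[] updates;       // accumulated updates, same length as params
    private boolean hasSomething = false; // true once at least one update was accumulated

    ApplyUpdateSketch(int length) {
        this.updates = new double[length];
    }

    /** Adds an update element-wise into the accumulated buffer. */
    void accumulate(double[] update) {
        for (int i = 0; i < updates.length; i++) {
            updates[i] += update[i];
        }
        hasSomething = true;
    }

    /** Assumption: the overload without alpha behaves like alpha = 1.0. */
    void applyUpdate(StepFn fn, double[] params) {
        applyUpdate(fn, params, 1.0);
    }

    /** Applies the accumulated buffer via fn, then clears it so it is not applied twice. */
    void applyUpdate(StepFn fn, double[] params, double alpha) {
        if (hasSomething) {
            fn.step(params, updates, alpha);
            Arrays.fill(updates, 0.0);
            hasSomething = false;
        }
    }
}
```

In this sketch a second applyUpdate call without new accumulation leaves params unchanged, because the buffer is cleared after use; whether BasicGradientsAccumulator resets its buffer at the same point is not stated on this page.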
public void storeUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
Specified by:
storeUpdate in interface GradientsAccumulator
Parameters:
array - update to accumulate and propagate across all workers
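storeUpdate is described as accumulating an update and propagating it across all workers, and the class keeps a CyclicBarrier field next to parties. A minimal sketch of that pattern follows, assuming each of parties threads stores one double[] per round and that the summed result is read once the barrier has tripped. BarrierSumSketch and runRound are hypothetical names, and double[] stands in for INDArray.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/** Hypothetical sketch: parties workers sum their updates and meet at a CyclicBarrier. */
class BarrierSumSketch {
    private final double[] storage;      // shared accumulation buffer
    private final CyclicBarrier barrier; // trips once every party has stored its update

    BarrierSumSketch(int parties, int length) {
        this.storage = new double[length];
        this.barrier = new CyclicBarrier(parties);
    }

    /** Adds this worker's update, then blocks until all parties have done the same. */
    void storeUpdate(double[] update) throws InterruptedException, BrokenBarrierException {
        synchronized (storage) {
            for (int i = 0; i < storage.length; i++) {
                storage[i] += update[i];
            }
        }
        barrier.await(); // after this returns, storage holds every party's contribution
    }

    double[] snapshot() {
        synchronized (storage) {
            return storage.clone();
        }
    }

    /** Starts one thread per update, each calling storeUpdate once, and returns the sum. */
    static double[] runRound(List<double[]> updates, int length) {
        BarrierSumSketch acc = new BarrierSumSketch(updates.size(), length);
        List<Thread> threads = new ArrayList<>();
        for (double[] u : updates) {
            Thread t = new Thread(() -> {
                try {
                    acc.storeUpdate(u);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } catch (BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
            });
            threads.add(t);
            t.start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return acc.snapshot();
    }
}
```

A CyclicBarrier resets automatically after it trips, so the same instance could synchronize the next round; the sketch runs a single round only.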
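public void receiveUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
Specified by:
receiveUpdate in interface GradientsAccumulator
Parameters:
array - update to enqueue; expected to be ready for use and to match params dimensionality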
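receiveUpdate puts updates into the queue consumed by the backpropagation loop and expects each array to match params dimensionality. The sketch below models this with a ConcurrentLinkedQueue of double[]; the explicit length check (the original only states the expectation) and the drain() method that sums everything queued so far are assumptions added for illustration, and ReceiveQueueSketch is a hypothetical name.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

/** Hypothetical sketch of a receiveUpdate-style queue; not the DL4J code. */
class ReceiveQueueSketch {
    private final int expectedLength; // stands in for params dimensionality
    private final Queue<double[]> gradients = new ConcurrentLinkedQueue<>();

    ReceiveQueueSketch(int expectedLength) {
        this.expectedLength = expectedLength;
    }

    /** Enqueues a ready-to-use update; assumption: mismatched lengths are rejected. */
    void receiveUpdate(double[] array) {
        if (array.length != expectedLength) {
            throw new IllegalArgumentException(
                "expected length " + expectedLength + " but got " + array.length);
        }
        gradients.add(array);
    }

    /** Called from the training loop: empties the queue and sums the updates it held. */
    double[] drain() {
        double[] sum = new double[expectedLength];
        double[] next;
        while ((next = gradients.poll()) != null) {
            for (int i = 0; i < expectedLength; i++) {
                sum[i] += next[i];
            }
        }
        return sum;
    }

    int pending() {
        return gradients.size();
    }
}
```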
public void reset()
Specified by:
reset in interface GradientsAccumulator
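reset discards all accumulated updates. Given the updatesLock (ReentrantReadWriteLock) and hasSomething (AtomicBoolean) fields, one plausible shape is sketched below: reset takes the write lock, zeroes the buffer, clears the pending queue, and clears the flag, so a reader holding the read lock never observes a half-reset state. The locking scheme and the store, receive, and readUpdates helpers are assumptions for illustration, not taken from DL4J.

```java
import java.util.Arrays;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/** Hypothetical sketch of reset() guarded by a read/write lock. */
class ResetSketch {
    private final double[] updates; // accumulated updates
    private final Queue<double[]> gradients = new ConcurrentLinkedQueue<>(); // received, not yet applied
    private final AtomicBoolean hasSomething = new AtomicBoolean(false);
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();

    ResetSketch(int length) {
        this.updates = new double[length];
    }

    /** Assumed writer path: accumulate under the write lock. */
    void store(double[] update) {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < updates.length; i++) {
                updates[i] += update[i];
            }
            hasSomething.set(true);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    void receive(double[] update) {
        gradients.add(update);
    }

    /** Drops everything accumulated or queued so far. */
    void reset() {
        updatesLock.writeLock().lock();
        try {
            Arrays.fill(updates, 0.0);
            gradients.clear();
            hasSomething.set(false);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    /** Assumed reader path: copy the buffer under the read lock. */
    double[] readUpdates() {
        updatesLock.readLock().lock();
        try {
            return updates.clone();
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    boolean hasPending() {
        return hasSomething.get();
    }

    int queued() {
        return gradients.size();
    }
}
```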
public void touch()
Specified by:
touch in interface GradientsAccumulator
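touch() initializes the current worker with respect to thread-device affinity. The sketch below shows only a generic, idempotent per-thread initialization using ThreadLocal: the first call on a thread binds it to a device id, later calls on the same thread return that id. The device count, the round-robin assignment, and all names are hypothetical, and no real device binding happens.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical sketch of per-thread, run-once initialization in the spirit of touch(). */
class TouchSketch {
    private final int devices;                                  // hypothetical device count
    private final AtomicInteger nextDevice = new AtomicInteger(0);
    private final AtomicInteger initCount = new AtomicInteger(0);
    // device id bound to each thread; null until that thread calls touch()
    private final ThreadLocal<Integer> deviceOfThread = new ThreadLocal<>();

    TouchSketch(int devices) {
        this.devices = devices;
    }

    /** First call on a thread assigns a device round-robin; later calls are no-ops. */
    int touch() {
        Integer id = deviceOfThread.get();
        if (id == null) {
            id = nextDevice.getAndIncrement() % devices;
            deviceOfThread.set(id);
            initCount.incrementAndGet();
        }
        return id;
    }

    int initializations() {
        return initCount.get();
    }

    /** Helper: calls touch() on a fresh thread and returns the id that thread received. */
    static int touchOnNewThread(TouchSketch s) {
        int[] result = new int[1];
        Thread t = new Thread(() -> result[0] = s.touch());
        t.start();
        try {
            t.join(); // join makes result[0] visible to this thread
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return result[0];
    }
}
```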
public void setExternalSource(Queue<org.nd4j.linalg.api.ndarray.INDArray> source)
Description copied from interface: GradientsAccumulator
Specified by:
setExternalSource in interface GradientsAccumulator
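setExternalSource lets updates that originate outside the workers be propagated to every worker sharing this accumulator. One way to picture this is sketched below, assuming each worker owns a queue and a pump step broadcasts every external update to all of them; ExternalSourceSketch and pumpExternal are hypothetical names, and double[] stands in for INDArray.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

/** Hypothetical sketch: broadcast updates from an external queue to every worker. */
class ExternalSourceSketch {
    private final List<Queue<double[]>> workerQueues = new ArrayList<>();
    private volatile Queue<double[]> externalSource; // set once, polled by pumpExternal()

    ExternalSourceSketch(int parties) {
        for (int i = 0; i < parties; i++) {
            workerQueues.add(new ConcurrentLinkedQueue<>());
        }
    }

    void setExternalSource(Queue<double[]> source) {
        this.externalSource = source;
    }

    /** Moves each pending external update into every worker's queue; returns how many moved. */
    int pumpExternal() {
        Queue<double[]> source = externalSource;
        if (source == null) {
            return 0;
        }
        int moved = 0;
        double[] update;
        while ((update = source.poll()) != null) {
            for (Queue<double[]> q : workerQueues) {
                q.add(update.clone()); // copy so workers cannot modify each other's arrays
            }
            moved++;
        }
        return moved;
    }

    int pendingFor(int worker) {
        return workerQueues.get(worker).size();
    }
}
```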
Copyright © 2017. All rights reserved.