public class EncodedGradientsAccumulator extends Object implements GradientsAccumulator, Registerable
Modifier and Type | Class and Description |
---|---|
static class |
EncodedGradientsAccumulator.Builder |
Modifier and Type | Field and Description |
---|---|
protected ThreadLocal<org.nd4j.linalg.api.ndarray.INDArray> |
accumulator |
protected org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration |
appliedConfiguration |
protected AtomicInteger |
barrier |
protected Double |
boundary |
protected AtomicBoolean |
bypassMode |
protected AtomicInteger |
currentConsumers |
protected Queue<org.nd4j.linalg.api.ndarray.INDArray> |
externalSource |
protected MessageHandler |
handler |
protected ThreadLocal<Integer> |
index |
protected long |
initialMemory |
protected boolean |
isDebug |
protected AtomicBoolean |
isDone |
protected AtomicBoolean |
isFirst |
protected List<ReentrantLock> |
locks |
protected List<BlockingQueue<org.nd4j.linalg.api.ndarray.INDArray>> |
messages |
protected int |
parties |
protected int |
queueSize |
protected AtomicBoolean |
registered |
protected boolean |
relocatable |
protected AtomicInteger |
secondary |
protected org.nd4j.linalg.util.AtomicThrowable |
throwable |
protected AtomicInteger |
workersCounter |
protected List<org.nd4j.linalg.api.memory.MemoryWorkspace> |
workspaces |
Modifier | Constructor and Description |
---|---|
|
EncodedGradientsAccumulator(double parties) |
|
EncodedGradientsAccumulator(int parties) |
|
EncodedGradientsAccumulator(int parties,
double threshold) |
protected |
EncodedGradientsAccumulator(int parties,
MessageHandler handler,
long initialMemory,
int queueSize,
Double boundary) |
Modifier and Type | Method and Description |
---|---|
void |
applyUpdate(StepFunction function,
org.nd4j.linalg.api.ndarray.INDArray params,
org.nd4j.linalg.api.ndarray.INDArray updates)
This method applies accumulated updates via the given StepFunction
|
void |
applyUpdate(StepFunction function,
org.nd4j.linalg.api.ndarray.INDArray params,
org.nd4j.linalg.api.ndarray.INDArray updates,
double alpha)
This method applies accumulated updates via the given StepFunction
|
void |
fallbackToSingleConsumerMode(boolean reallyFallback)
This method enables/disables bypass mode
|
static int |
getOptimalBufferSize(int paramsLength,
int numWorkers,
int queueSize)
This method returns the optimal bufferSize for a given model.
Updates are guaranteed to have a maximum size of params / 16.
|
static int |
getOptimalBufferSize(Model model,
int numWorkers,
int queueSize) |
void |
receiveUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop
|
void |
registerConsumers(int numConsumers)
This method notifies the producer about the number of consumers for the current consumption cycle
|
void |
reset()
This method resets all accumulated updates (if any)
|
void |
setExternalSource(Queue<org.nd4j.linalg.api.ndarray.INDArray> source)
This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance
|
void |
storeUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
This method accepts updates suitable for StepFunction, and accumulates/propagates them across all workers
|
protected void |
synchronize(int consumers) |
protected void |
synchronize(int consumers,
boolean finalLock) |
void |
touch()
This method initializes the given worker with respect to Thread-Device Affinity
|
protected ThreadLocal<org.nd4j.linalg.api.ndarray.INDArray> accumulator
protected int parties
protected MessageHandler handler
protected List<BlockingQueue<org.nd4j.linalg.api.ndarray.INDArray>> messages
protected List<org.nd4j.linalg.api.memory.MemoryWorkspace> workspaces
protected List<ReentrantLock> locks
protected AtomicInteger workersCounter
protected ThreadLocal<Integer> index
protected long initialMemory
protected int queueSize
protected Double boundary
protected Queue<org.nd4j.linalg.api.ndarray.INDArray> externalSource
protected AtomicBoolean isFirst
protected AtomicBoolean isDone
protected AtomicInteger barrier
protected AtomicInteger secondary
protected AtomicBoolean registered
protected AtomicBoolean bypassMode
protected final AtomicInteger currentConsumers
protected final org.nd4j.linalg.util.AtomicThrowable throwable
protected boolean isDebug
protected final boolean relocatable
protected org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration appliedConfiguration
public EncodedGradientsAccumulator(double parties)
public EncodedGradientsAccumulator(int parties)
public EncodedGradientsAccumulator(int parties, double threshold)
protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Double boundary)
public static int getOptimalBufferSize(int paramsLength, int numWorkers, int queueSize)
Parameters:
paramsLength -
numWorkers -
queueSize -
public static int getOptimalBufferSize(Model model, int numWorkers, int queueSize)
public void fallbackToSingleConsumerMode(boolean reallyFallback)
Registerable
fallbackToSingleConsumerMode
in interface Registerable
public void registerConsumers(int numConsumers)
Registerable
registerConsumers
in interface Registerable
protected void synchronize(int consumers)
protected void synchronize(int consumers, boolean finalLock)
public void applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray updates)
applyUpdate
in interface GradientsAccumulator
Parameters:
function -
params -
public void applyUpdate(StepFunction function, org.nd4j.linalg.api.ndarray.INDArray params, org.nd4j.linalg.api.ndarray.INDArray updates, double alpha)
applyUpdate
in interface GradientsAccumulator
Parameters:
function -
params -
alpha -
public void setExternalSource(Queue<org.nd4j.linalg.api.ndarray.INDArray> source)
setExternalSource
in interface GradientsAccumulator
Parameters:
source -
public void touch()
touch
in interface GradientsAccumulator
public void storeUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
storeUpdate
in interface GradientsAccumulator
Parameters:
array -
public void receiveUpdate(org.nd4j.linalg.api.ndarray.INDArray array)
Note: the array is expected to be ready for use and to match the dimensionality of params.
receiveUpdate
in interface GradientsAccumulator
Parameters:
array -
public void reset()
reset
in interface GradientsAccumulator
Copyright © 2018. All rights reserved.