Package org.deeplearning4j.parallelism
Class ParallelWrapper
- java.lang.Object
-
- org.deeplearning4j.parallelism.ParallelWrapper
-
- All Implemented Interfaces:
AutoCloseable
public class ParallelWrapper extends Object implements AutoCloseable
-
-
Nested Class Summary
Nested Classes
- static class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model>
- static class ParallelWrapper.TrainingMode
-
Field Summary
Fields
- protected boolean averageUpdaters
- protected int averagingFrequency
- protected boolean debug
- protected Throwable exception
- protected AtomicBoolean exceptionEncountered
- protected ThreadPoolExecutor executorService
- protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
- protected boolean isMQ
- protected AtomicLong iterationsCounter
- protected boolean legacyAveraging
- protected List<org.deeplearning4j.optimize.api.TrainingListener> listeners
- protected org.deeplearning4j.nn.api.Model model
- protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
- protected int prefetchSize
- protected boolean reportScore
- protected AtomicBoolean stopFit
- protected StatsStorageRouter storageRouter
- protected TrainerContext trainerContext
- protected Object[] trainerContextArgs
- protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
- protected String uuid
- protected boolean wasAveraged
- protected AtomicInteger workerCounter
- protected int workers
- protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
- protected Trainer[] zoo
-
Constructor Summary
Constructors
- protected ParallelWrapper(org.deeplearning4j.nn.api.Model model, int workers, int prefetchSize)
-
Method Summary
Methods
- void broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
  Propagates gradients across all workers.
- void close()
- void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
  Takes a DataSetIterator and starts training over it by scheduling DataSets to different executors.
- void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)
- protected void init()
- void setListeners(@NonNull Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners)
  Specifies training listeners for this model.
- void setListeners(@NonNull org.deeplearning4j.optimize.api.TrainingListener... listeners)
  Specifies training listeners for this model.
- void setListeners(StatsStorageRouter statsStorage, Collection<? extends org.deeplearning4j.optimize.api.TrainingListener> listeners)
  Sets the listeners, along with a StatsStorageRouter that results will be routed to (for any listeners that implement the RoutingIterationListener interface).
- void setListeners(StatsStorageRouter statsStorage, org.deeplearning4j.optimize.api.TrainingListener... listeners)
  Sets the listeners, along with a StatsStorageRouter that results will be routed to (for any listeners that implement the RoutingIterationListener interface).
- void shutdown()
  Stops all threads used for parallel training.
- void stopFit()
  Stops a fit operation from continuing to iterate.
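The methods above map to a typical usage pattern: instances are created through ParallelWrapper.Builder (the constructor is protected) and trained via fit(DataSetIterator). Below is a minimal sketch, assuming an already-configured MultiLayerNetwork and a DataSetIterator are available; the builder option names (prefetchBuffer, workers, averagingFrequency, reportScoreAfterAveraging) follow common DL4J examples and may differ slightly between versions:

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class ParallelWrapperSketch {
    public static void train(MultiLayerNetwork model, DataSetIterator trainData) {
        // Wrap an already-configured model. ParallelWrapper clones it once
        // per worker and averages parameters periodically.
        ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
                .prefetchBuffer(24)              // DataSets buffered ahead of the workers
                .workers(4)                      // number of concurrent training threads
                .averagingFrequency(3)           // average parameters every 3 iterations
                .reportScoreAfterAveraging(true) // log score after each averaging step
                .build();

        wrapper.fit(trainData); // schedules DataSets across the worker threads
        wrapper.shutdown();     // stop all parallel-training threads when done
    }
}
```

After fit() returns, the averaged parameters are reflected in the original model instance passed to the builder.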
-
-
-
Field Detail
-
modelParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
-
updaterParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
-
exceptionEncountered
protected AtomicBoolean exceptionEncountered
-
exception
protected Throwable exception
-
uuid
protected final String uuid
-
model
protected org.deeplearning4j.nn.api.Model model
-
workers
protected int workers
-
prefetchSize
protected int prefetchSize
-
averagingFrequency
protected int averagingFrequency
-
zoo
protected Trainer[] zoo
-
trainerContext
protected TrainerContext trainerContext
-
iterationsCounter
protected AtomicLong iterationsCounter
-
reportScore
protected boolean reportScore
-
averageUpdaters
protected boolean averageUpdaters
-
legacyAveraging
protected boolean legacyAveraging
-
wasAveraged
protected boolean wasAveraged
-
stopFit
protected AtomicBoolean stopFit
-
listeners
protected List<org.deeplearning4j.optimize.api.TrainingListener> listeners
-
storageRouter
protected StatsStorageRouter storageRouter
-
isMQ
protected boolean isMQ
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
trainerContextArgs
protected Object[] trainerContextArgs
-
debug
protected boolean debug
-
executorService
protected ThreadPoolExecutor executorService
-
workerCounter
protected final AtomicInteger workerCounter
-
gradientsAccumulator
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
-
-
Method Detail
-
init
protected void init()
-
close
public void close() throws Exception
- Specified by: close in interface AutoCloseable
- Throws: Exception
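Because ParallelWrapper implements AutoCloseable, close() can be invoked automatically with try-with-resources, so worker threads are released even if training throws. A sketch, assuming a configured model and iterator are in scope:

```java
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class CloseSketch {
    // Declared "throws Exception" because close() itself may throw Exception.
    public static void train(MultiLayerNetwork model, DataSetIterator trainData) throws Exception {
        // try-with-resources calls close() on exit, shutting down worker
        // threads whether fit() completes normally or throws.
        try (ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
                .workers(2)
                .build()) {
            wrapper.fit(trainData);
        }
    }
}
```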
-
shutdown
public void shutdown()
This method causes all threads used for parallel training to stop
-
stopFit
public void stopFit()
Will stop a fit operation from continuing to iterate.
-
fit
public void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)
- Parameters:
  source - the MultiDataSetIterator to train on
-
setListeners
public void setListeners(@NonNull Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners)
This method allows you to specify training listeners for this model. Note that for listeners with state that will be sent somewhere (such as StatsListener), consider using setListeners(StatsStorageRouter, Collection) instead.
- Parameters:
  listeners - Listeners to set
-
setListeners
public void setListeners(@NonNull org.deeplearning4j.optimize.api.TrainingListener... listeners)
This method allows you to specify training listeners for this model. Note that for listeners with state that will be sent somewhere (such as StatsListener), consider using setListeners(StatsStorageRouter, Collection) instead.
- Parameters:
  listeners - Listeners to set
-
setListeners
public void setListeners(StatsStorageRouter statsStorage, org.deeplearning4j.optimize.api.TrainingListener... listeners)
Sets the listeners, along with a StatsStorageRouter that results will be routed to (for any listeners that implement the RoutingIterationListener interface).
- Parameters:
  statsStorage - Stats storage router to place the results into
  listeners - Listeners to set
-
setListeners
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends org.deeplearning4j.optimize.api.TrainingListener> listeners)
Sets the listeners, along with a StatsStorageRouter that results will be routed to (for any listeners that implement the RoutingIterationListener interface).
- Parameters:
  statsStorage - Stats storage router to place the results into
  listeners - Listeners to set
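For example, this overload is commonly paired with the DL4J UI stats classes. A sketch, assuming deeplearning4j-ui is on the classpath (package names for InMemoryStatsStorage and StatsListener vary across DL4J versions) and that wrapper is an already-built ParallelWrapper:

```java
// InMemoryStatsStorage implements StatsStorageRouter, so it can receive
// results from any RoutingIterationListener such as StatsListener.
InMemoryStatsStorage statsStorage = new InMemoryStatsStorage();

// The StatsListener router argument is left null here because ParallelWrapper
// supplies the router (statsStorage) to routing listeners itself.
wrapper.setListeners(statsStorage, new StatsListener(null));
```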
-
broadcastGradients
public void broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
This method propagates gradients across all workers.
- Parameters:
  gradients - gradients to broadcast
-
fit
public void fit(@NonNull @NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors- Parameters:
source-
-
-