Package org.deeplearning4j.parallelism
Class ParallelWrapper

java.lang.Object
    org.deeplearning4j.parallelism.ParallelWrapper

All Implemented Interfaces:
AutoCloseable

public class ParallelWrapper extends Object implements AutoCloseable
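ParallelWrapper trains the wrapped Model in data-parallel fashion: DataSets are scheduled across several worker threads from a prefetch queue, and parameters (and, when enabled, updater state) are periodically averaged. A minimal usage sketch follows; 'model' and 'trainData' are placeholders, and the Builder option names follow the common DL4J examples and may differ slightly between versions:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class ParallelTrainingExample {

    // 'model' is assumed to be a configured, initialized network and
    // 'trainData' any DataSetIterator; both are hypothetical here.
    public static void train(MultiLayerNetwork model, DataSetIterator trainData) {
        ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model)
                .workers(4)                       // number of concurrent training threads
                .prefetchBuffer(16)               // DataSets queued ahead for the workers
                .averagingFrequency(3)            // average parameters every 3 iterations
                .reportScoreAfterAveraging(true)  // log the score after each averaging step
                .build();

        wrapper.fit(trainData);  // blocks until the iterator is exhausted
        wrapper.shutdown();      // stop the worker threads
    }
}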
Nested Class Summary

static class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model>
static class ParallelWrapper.TrainingMode
-
Field Summary

protected boolean averageUpdaters
protected int averagingFrequency
protected boolean debug
protected Throwable exception
protected AtomicBoolean exceptionEncountered
protected ThreadPoolExecutor executorService
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
protected boolean isMQ
protected AtomicLong iterationsCounter
protected boolean legacyAveraging
protected List<org.deeplearning4j.optimize.api.TrainingListener> listeners
protected org.deeplearning4j.nn.api.Model model
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
protected int prefetchSize
protected boolean reportScore
protected AtomicBoolean stopFit
protected StatsStorageRouter storageRouter
protected TrainerContext trainerContext
protected Object[] trainerContextArgs
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
protected String uuid
protected boolean wasAveraged
protected AtomicInteger workerCounter
protected int workers
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
protected Trainer[] zoo
-
Constructor Summary

protected ParallelWrapper(org.deeplearning4j.nn.api.Model model, int workers, int prefetchSize)
-
Method Summary

void broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
    This method will propagate gradients across all workers.

void close()

void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
    This method takes a DataSetIterator and starts training over it by scheduling DataSets to different executors.

void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)

protected void init()

void setListeners(@NonNull Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners)
    This method allows you to specify TrainingListeners for this model.

void setListeners(@NonNull org.deeplearning4j.optimize.api.TrainingListener... listeners)
    This method allows you to specify TrainingListeners for this model.

void setListeners(StatsStorageRouter statsStorage, Collection<? extends org.deeplearning4j.optimize.api.TrainingListener> listeners)
    Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).

void setListeners(StatsStorageRouter statsStorage, org.deeplearning4j.optimize.api.TrainingListener... listeners)
    Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).

void shutdown()
    This method causes all threads used for parallel training to stop.

void stopFit()
    Will stop a fit operation from continuing to iterate.
-
Field Detail
-
modelParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
-
updaterParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
-
exceptionEncountered
protected AtomicBoolean exceptionEncountered
-
exception
protected Throwable exception
-
uuid
protected final String uuid
-
model
protected org.deeplearning4j.nn.api.Model model
-
workers
protected int workers
-
prefetchSize
protected int prefetchSize
-
averagingFrequency
protected int averagingFrequency
-
zoo
protected Trainer[] zoo
-
trainerContext
protected TrainerContext trainerContext
-
iterationsCounter
protected AtomicLong iterationsCounter
-
reportScore
protected boolean reportScore
-
averageUpdaters
protected boolean averageUpdaters
-
legacyAveraging
protected boolean legacyAveraging
-
wasAveraged
protected boolean wasAveraged
-
stopFit
protected AtomicBoolean stopFit
-
listeners
protected List<org.deeplearning4j.optimize.api.TrainingListener> listeners
-
storageRouter
protected StatsStorageRouter storageRouter
-
isMQ
protected boolean isMQ
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
trainerContextArgs
protected Object[] trainerContextArgs
-
debug
protected boolean debug
-
executorService
protected ThreadPoolExecutor executorService
-
workerCounter
protected final AtomicInteger workerCounter
-
gradientsAccumulator
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
-
Method Detail
-
init
protected void init()
-
close
public void close() throws Exception
- Specified by:
close in interface AutoCloseable
- Throws:
Exception
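Because ParallelWrapper implements AutoCloseable, it can be managed with try-with-resources so the worker threads are released even if fit(...) throws. A sketch, reusing the hypothetical 'model' and 'trainData' from the class-level example:

try (ParallelWrapper wrapper = new ParallelWrapper.Builder<>(model).workers(4).build()) {
    wrapper.fit(trainData);
}   // close() runs here, shutting the workers down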
-
shutdown
public void shutdown()
This method causes all threads used for parallel training to stop.
-
stopFit
public void stopFit()
Will stop a fit operation from continuing to iterate.
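One hedged sketch of how stopFit() might be used: a second thread (here a watchdog timer) requests an early stop while fit(...) blocks on the calling thread. The scheduling code is standard java.util.concurrent; 'wrapper' and 'trainData' are assumed from the earlier examples:

ScheduledExecutorService watchdog = Executors.newSingleThreadScheduledExecutor();
watchdog.schedule(wrapper::stopFit, 10, TimeUnit.MINUTES);  // abort training after 10 minutes
wrapper.fit(trainData);   // returns early once stopFit() takes effect
watchdog.shutdownNow();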
-
fit
public void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)
- Parameters:
source - MultiDataSetIterator to fetch training data from
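This overload is the entry point for models trained on MultiDataSets, e.g. a multi-input/multi-output ComputationGraph. A sketch, assuming 'graph' is an initialized ComputationGraph (which implements Model) and 'multiIter' is any MultiDataSetIterator:

ParallelWrapper wrapper = new ParallelWrapper.Builder<>(graph)
        .workers(2)
        .build();
wrapper.fit(multiIter);  // schedules MultiDataSets across the workers
wrapper.shutdown();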
-
setListeners
public void setListeners(@NonNull Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners)
This method allows you to specify TrainingListeners for this model. Note that for listeners like StatsListener (which have state that will be sent somewhere), consider instead using setListeners(StatsStorageRouter, Collection).
- Parameters:
listeners - Listeners to set
-
setListeners
public void setListeners(@NonNull org.deeplearning4j.optimize.api.TrainingListener... listeners)
This method allows you to specify TrainingListeners for this model. Note that for listeners like StatsListener (which have state that will be sent somewhere), consider instead using setListeners(StatsStorageRouter, Collection).
- Parameters:
listeners - Listeners to set
-
setListeners
public void setListeners(StatsStorageRouter statsStorage, org.deeplearning4j.optimize.api.TrainingListener... listeners)
Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).
- Parameters:
statsStorage - Stats storage router to place the results into
listeners - Listeners to set
-
setListeners
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends org.deeplearning4j.optimize.api.TrainingListener> listeners)
Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).
- Parameters:
statsStorage - Stats storage router to place the results into
listeners - Listeners to set
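For example, training statistics can be collected across all workers by routing a StatsListener into an in-memory store. This is a sketch assuming the deeplearning4j-ui artifacts are on the classpath; the StatsListener and InMemoryStatsStorage package names have moved between DL4J versions:

StatsStorage statsStorage = new InMemoryStatsStorage();
// StatsListener implements RoutingIterationListener, so each worker's
// copy reports into 'statsStorage' via the router set here.
wrapper.setListeners(statsStorage, new StatsListener(null));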
-
broadcastGradients
public void broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
This method will propagate gradients across all workers.
- Parameters:
gradients - Gradients to propagate
-
fit
public void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
This method takes a DataSetIterator and starts training over it by scheduling DataSets to different executors.
- Parameters:
source - DataSetIterator to fetch training data from
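A fit(...) call consumes the iterator once, so multi-epoch training is typically a loop with a reset between passes. A sketch, reusing the hypothetical 'wrapper' and 'trainData' from the earlier examples:

for (int epoch = 0; epoch < 10; epoch++) {
    trainData.reset();       // rewind the iterator for the next pass
    wrapper.fit(trainData);  // one full pass, scheduled across the workers
}
wrapper.shutdown();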