Class DefaultTrainer
- java.lang.Object
  - java.lang.Thread
    - org.deeplearning4j.parallelism.trainer.DefaultTrainer
- Direct Known Subclasses: SymmetricTrainer
public class DefaultTrainer extends Thread implements Trainer
-
Nested Class Summary
Modifier and Type | Class
static class | DefaultTrainer.DefaultTrainerBuilder
-
Nested classes/interfaces inherited from class java.lang.Thread
Thread.State, Thread.UncaughtExceptionHandler
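The summary lists a nested DefaultTrainerBuilder but does not show how it is used. Its name follows the convention of Lombok's `@Builder` (a nested `<Type>Builder` class), which suggests it is generated, though this page does not confirm that. The sketch below hand-writes the same fluent-builder pattern for a simplified, hypothetical `WorkerConfig`. The field names `threadId`, `averagingFrequency`, and `uuid` come from the Field Summary; the class, its methods, and its validation rules are illustrative assumptions, not the DL4J API.

```java
import java.util.Objects;

// Hypothetical, simplified stand-in for a trainer configuration; NOT the DL4J class.
final class WorkerConfig {
    final int threadId;
    final int averagingFrequency;
    final String uuid;

    private WorkerConfig(Builder b) {
        this.threadId = b.threadId;
        this.averagingFrequency = b.averagingFrequency;
        this.uuid = b.uuid;
    }

    static Builder builder() {
        return new Builder();
    }

    // Fluent builder: each setter returns the builder so calls can be chained.
    static final class Builder {
        private int threadId;
        private int averagingFrequency = 1; // assumed default for this sketch
        private String uuid;

        Builder threadId(int threadId) {
            this.threadId = threadId;
            return this;
        }

        Builder averagingFrequency(int f) {
            if (f < 1) throw new IllegalArgumentException("averagingFrequency must be >= 1");
            this.averagingFrequency = f;
            return this;
        }

        Builder uuid(String uuid) {
            this.uuid = uuid;
            return this;
        }

        WorkerConfig build() {
            // In this sketch uuid is required; the other fields have defaults.
            Objects.requireNonNull(uuid, "uuid");
            return new WorkerConfig(this);
        }
    }
}
```

Usage: `WorkerConfig.builder().threadId(2).averagingFrequency(5).uuid("worker_2").build()`. A builder keeps construction readable when a class has many optional fields, as the Field Summary suggests this one does.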
-
Field Summary
Modifier and Type | Field
protected int | averagingFrequency
protected AtomicBoolean | isStopped
protected AtomicLong | lastEtlTime
protected ReentrantReadWriteLock | modelLock
protected org.nd4j.linalg.dataset.api.DataSet | nullDataSet
protected AtomicBoolean | nullMode
protected boolean | onRootModel
protected org.deeplearning4j.nn.api.Model | originalModel
protected ParallelWrapper | parallelWrapper
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> | queue
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> | queueMDS
protected org.deeplearning4j.nn.api.Model | replicatedModel
protected AtomicInteger | running
protected AtomicBoolean | shouldStop
protected AtomicBoolean | shouldUpdate
protected int | threadId
protected Exception | thrownException
protected boolean | useMDS
protected String | uuid
protected org.deeplearning4j.nn.conf.WorkspaceMode | workspaceMode
-
Fields inherited from class java.lang.Thread
MAX_PRIORITY, MIN_PRIORITY, NORM_PRIORITY
-
Constructor Summary
Constructor
DefaultTrainer()
-
Method Summary
Modifier and Type | Method | Description
boolean | averagingRequired() | This method returns TRUE if this Trainer implementation assumes periodic averaging
protected static org.deeplearning4j.optimize.api.TrainingListener | cloneListener(org.deeplearning4j.optimize.api.TrainingListener original) |
protected void | configureListeners(String workerUUID, Collection<org.deeplearning4j.optimize.api.TrainingListener> oldListeners, Collection<org.deeplearning4j.optimize.api.TrainingListener> replicatedListeners) |
void | feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime) | Train on a DataSet
void | feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime) | Train on a MultiDataSet
protected void | fit(org.nd4j.linalg.dataset.api.DataSet dataSet) |
protected void | fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet) |
org.deeplearning4j.nn.api.Model | getModel() | The current model for the trainer
boolean | isRunning() |
protected void | postInit() | This method does post-initialization configuration of the Model.
void | run() |
protected void | setupIfNeccessary() |
void | shutdown() | Shut down this worker
void | updateModel(@NonNull org.deeplearning4j.nn.api.Model model) | Update the current Model for the worker
void | updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the params of the replicated model
void | updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the updater params of the replicated model
void | waitTillRunning() | Block the main thread until the trainer is up and running.
-
Methods inherited from class java.lang.Thread
activeCount, checkAccess, clone, countStackFrames, currentThread, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, onSpinWait, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, suspend, toString, yield
-
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.parallelism.trainer.Trainer
getUuid, setUncaughtExceptionHandler, start
-
Field Detail
-
replicatedModel
protected org.deeplearning4j.nn.api.Model replicatedModel
-
queue
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
-
queueMDS
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
-
running
protected AtomicInteger running
-
shouldUpdate
protected AtomicBoolean shouldUpdate
-
shouldStop
protected AtomicBoolean shouldStop
-
thrownException
protected Exception thrownException
-
useMDS
protected volatile boolean useMDS
-
uuid
protected String uuid
-
onRootModel
protected boolean onRootModel
-
lastEtlTime
protected volatile AtomicLong lastEtlTime
-
nullMode
protected AtomicBoolean nullMode
-
nullDataSet
protected org.nd4j.linalg.dataset.api.DataSet nullDataSet
-
isStopped
protected AtomicBoolean isStopped
-
parallelWrapper
protected ParallelWrapper parallelWrapper
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
averagingFrequency
protected int averagingFrequency
-
threadId
protected int threadId
-
originalModel
protected org.deeplearning4j.nn.api.Model originalModel
-
modelLock
protected final ReentrantReadWriteLock modelLock
-
Method Detail
-
feedMultiDataSet
public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime)
Description copied from interface: Trainer
Train on a MultiDataSet
- Specified by:
feedMultiDataSet in interface Trainer
- Parameters:
dataSet - the data set to train on
-
feedDataSet
public void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime)
Description copied from interface: Trainer
Train on a DataSet
- Specified by:
feedDataSet in interface Trainer
- Parameters:
dataSet - the data set to train on
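feedDataSet and feedMultiDataSet hand a batch to the worker together with `etlTime`. The fields `queue`, `queueMDS`, `useMDS`, and `lastEtlTime` suggest a producer/consumer handoff, but this page does not describe the exact behaviour. A minimal sketch under that assumption: the caller's thread records the ETL time and puts the batch on a bounded queue, blocking when the worker falls behind. `FeedSketch` is hypothetical, `Object` stands in for `DataSet`/`MultiDataSet`, and the explicit null check imitates the check that a Lombok `@NonNull` parameter annotation generates.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical producer side of a trainer; Object stands in for DataSet / MultiDataSet.
class FeedSketch {
    final LinkedBlockingQueue<Object> queue = new LinkedBlockingQueue<>(8);
    final LinkedBlockingQueue<Object> queueMDS = new LinkedBlockingQueue<>(8);
    final AtomicLong lastEtlTime = new AtomicLong(0);
    volatile boolean useMDS = false;

    void feedDataSet(Object dataSet, long etlTime) {
        useMDS = false;
        put(queue, dataSet, etlTime);
    }

    void feedMultiDataSet(Object dataSet, long etlTime) {
        // Mimics a @NonNull parameter: fail fast before touching any state.
        if (dataSet == null) throw new NullPointerException("dataSet");
        useMDS = true;
        put(queueMDS, dataSet, etlTime);
    }

    private void put(LinkedBlockingQueue<Object> q, Object item, long etlTime) {
        lastEtlTime.set(etlTime); // remember how long this batch took to load
        try {
            q.put(item); // blocks while the bounded queue is full (back-pressure)
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while feeding", e);
        }
    }
}
```

A bounded queue keeps a fast data pipeline from piling up batches in memory faster than the worker can fit them.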
-
getModel
public org.deeplearning4j.nn.api.Model getModel()
Description copied from interface: Trainer
The current model for the trainer
-
updateModel
public void updateModel(@NonNull org.deeplearning4j.nn.api.Model model)
Description copied from interface: Trainer
Update the current Model for the worker
- Specified by:
updateModel in interface Trainer
- Parameters:
model - the new model for this worker
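The `modelLock` field is a `ReentrantReadWriteLock`, which suggests that model swaps and model reads are coordinated through it. A hedged sketch of that pattern with a hypothetical `ModelHolder<M>`, where the type parameter stands in for `Model`: readers such as `getModel()` take the shared read lock, and `updateModel` takes the exclusive write lock. Which operations DefaultTrainer actually guards with this lock is an assumption here.

```java
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical holder; M is a placeholder for the model type.
class ModelHolder<M> {
    private final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
    private M replicatedModel;

    ModelHolder(M initial) {
        this.replicatedModel = initial;
    }

    // Several threads may read at once, as long as no writer holds the lock.
    M getModel() {
        modelLock.readLock().lock();
        try {
            return replicatedModel;
        } finally {
            modelLock.readLock().unlock();
        }
    }

    // The writer gets exclusive access, so no reader sees a half-finished update.
    void updateModel(M model) {
        if (model == null) throw new NullPointerException("model");
        modelLock.writeLock().lock();
        try {
            replicatedModel = model;
        } finally {
            modelLock.writeLock().unlock();
        }
    }
}
```

For a single reference assignment a `volatile` field would already suffice; the lock pays off when an update spans several steps that readers must never observe half done.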
-
setupIfNeccessary
protected void setupIfNeccessary()
-
shutdown
public void shutdown()
Description copied from interface: Trainer
Shut down this worker
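shutdown() only says that it stops the worker. Because DefaultTrainer extends Thread and carries `shouldStop` and `isStopped` flags, one plausible shape for `run()` and `shutdown()` is a polling loop that re-checks a stop flag. The sketch below assumes that design; `LoopWorker`, its `Integer` batches, and the 10 ms poll timeout are hypothetical, and its `fit` merely counts batches.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical worker thread: consumes batches until shutdown() is called.
class LoopWorker extends Thread {
    final LinkedBlockingQueue<Integer> queue = new LinkedBlockingQueue<>();
    final AtomicBoolean shouldStop = new AtomicBoolean(false);
    final AtomicBoolean isStopped = new AtomicBoolean(false);
    final AtomicInteger processed = new AtomicInteger(0);

    @Override
    public void run() {
        try {
            while (!shouldStop.get()) {
                // A timed poll (not take()) lets the loop re-check the stop flag when idle.
                Integer batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch != null) {
                    fit(batch);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            isStopped.set(true); // record that the loop has exited
        }
    }

    // Stand-in for fit(DataSet): here it only counts the batch.
    protected void fit(Integer batch) {
        processed.incrementAndGet();
    }

    // Ask the loop to exit; callers can join() to wait for it.
    void shutdown() {
        shouldStop.set(true);
    }
}
```

The alternative to a timed poll is to `interrupt()` a thread blocked in `take()`. Whether DefaultTrainer drains its queue before stopping is not stated on this page.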
-
fit
protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
-
fit
protected void fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
-
postInit
protected void postInit()
This method does post-initialization configuration of the Model. It is a good place to configure listeners and similar settings.
-
waitTillRunning
public void waitTillRunning()
Description copied from interface: Trainer
Block the main thread until the trainer is up and running.
- Specified by:
waitTillRunning in interface Trainer
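waitTillRunning() blocks the caller until the worker is ready. This page does not say how the readiness signal is implemented, so the sketch below swaps in a `java.util.concurrent.CountDownLatch` that `run()` counts down after its setup. `StartupWorker` is hypothetical, and unlike the documented no-argument method it takes a timeout, so a worker that never starts cannot hang the caller.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical startup handshake using a latch (a substituted technique).
class StartupWorker extends Thread {
    private final CountDownLatch started = new CountDownLatch(1);
    final AtomicBoolean initialized = new AtomicBoolean(false);

    @Override
    public void run() {
        initialized.set(true); // per-thread setup would happen here
        started.countDown();   // signal that the worker is up
        // ...a real worker would enter its training loop here
    }

    // Blocks until run() has signalled; returns false on timeout or interrupt.
    boolean waitTillRunning(long timeoutMs) {
        try {
            return started.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }
}
```

`countDown()` happens-before a successful `await()`, so anything `run()` set up before signalling is visible to the waiting thread.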
-
updateModelParams
public void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params)
Description copied from interface: Trainer
This method updates the params of the replicated model
- Specified by:
updateModelParams in interface Trainer
-
updateUpdaterParams
public void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params)
Description copied from interface: Trainer
This method updates the updater params of the replicated model
- Specified by:
updateUpdaterParams in interface Trainer
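updateModelParams and updateUpdaterParams push new values, for example freshly averaged ones, into the replica: one for the model parameters, the other for the updater's internal state (such as momentum buffers). A minimal sketch, assuming both are held as flat vectors and overwritten in place; `ReplicaState` is hypothetical and `double[]` stands in for `INDArray`.

```java
// Hypothetical replica holding flattened parameter and updater-state vectors.
class ReplicaState {
    final double[] params;
    final double[] updaterState;

    ReplicaState(int nParams, int nUpdater) {
        params = new double[nParams];
        updaterState = new double[nUpdater];
    }

    // Overwrite the replica's parameters in place with the supplied vector.
    void updateModelParams(double[] newParams) {
        copyInto(params, newParams, "params");
    }

    // Same idea for the updater (optimizer) state.
    void updateUpdaterParams(double[] newState) {
        copyInto(updaterState, newState, "updater state");
    }

    private static void copyInto(double[] target, double[] source, String what) {
        if (source.length != target.length) {
            throw new IllegalArgumentException(what + " length " + source.length
                    + " != expected " + target.length);
        }
        System.arraycopy(source, 0, target, 0, source.length); // copy, do not alias
    }
}
```

Copying into the existing buffer, rather than swapping references, keeps other references to that buffer valid, and the length check catches a vector taken from a differently shaped model.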
-
averagingRequired
public boolean averagingRequired()
Description copied from interface: Trainer
This method returns TRUE if this Trainer implementation assumes periodic averaging
- Specified by:
averagingRequired in interface Trainer
- Returns:
whether periodic averaging is required
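averagingRequired() reports whether the trainer relies on periodic averaging, and the class carries an `averagingFrequency` field. In parameter averaging, every k-th iteration each parameter of the K workers is replaced by the sum of their values divided by K. The sketch shows that step with hypothetical helpers `shouldAverage` and `average` over `double[]` vectors; when exactly DefaultTrainer triggers averaging, and which component performs it, is not stated on this page.

```java
// Hypothetical helpers illustrating periodic parameter averaging across workers.
final class ParameterAveraging {
    private ParameterAveraging() {}

    // True on every averagingFrequency-th iteration (iteration 0 is skipped).
    static boolean shouldAverage(long iteration, int averagingFrequency) {
        if (averagingFrequency < 1) {
            throw new IllegalArgumentException("averagingFrequency must be >= 1");
        }
        return iteration > 0 && iteration % averagingFrequency == 0;
    }

    // Element-wise arithmetic mean of equally sized parameter vectors.
    static double[] average(double[][] workerParams) {
        if (workerParams.length == 0) throw new IllegalArgumentException("no workers");
        int n = workerParams[0].length;
        double[] mean = new double[n];
        for (double[] p : workerParams) {
            if (p.length != n) throw new IllegalArgumentException("length mismatch");
            for (int i = 0; i < n; i++) {
                mean[i] += p[i];
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= workerParams.length;
        }
        return mean;
    }
}
```

The averaged vector would then be pushed back into each replica, which is the role updateModelParams plays. A trainer that returns FALSE here presumably keeps its replicas in sync by some other means.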
-
cloneListener
protected static org.deeplearning4j.optimize.api.TrainingListener cloneListener(org.deeplearning4j.optimize.api.TrainingListener original)
-
configureListeners
protected void configureListeners(String workerUUID, Collection<org.deeplearning4j.optimize.api.TrainingListener> oldListeners, Collection<org.deeplearning4j.optimize.api.TrainingListener> replicatedListeners)
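cloneListener and configureListeners are undocumented. Their signatures (a worker UUID, the original listeners, and a collection to fill for the replica) suggest that each worker gets its own listener list, with listeners copied where sharing state across threads would be unsafe. A hedged sketch of that idea: `Listener`, `PerWorkerListener`, `ListenerSetup`, and `CountingListener` are hypothetical types, not DL4J's `TrainingListener`, and the copy-or-share rule is an assumption.

```java
import java.util.Collection;

// Hypothetical listener contract for illustration only.
interface Listener {
    void iterationDone(int iteration, double score);
}

// Listeners whose state must not be shared between workers provide a per-worker copy.
interface PerWorkerListener extends Listener {
    PerWorkerListener copyFor(String workerUUID);
}

final class ListenerSetup {
    private ListenerSetup() {}

    // Fills the replica's listener list: per-worker listeners are copied and tagged
    // with the worker UUID; all other listeners are shared as-is.
    static void configureListeners(String workerUUID,
                                   Collection<? extends Listener> oldListeners,
                                   Collection<Listener> replicatedListeners) {
        for (Listener l : oldListeners) {
            if (l instanceof PerWorkerListener) {
                replicatedListeners.add(((PerWorkerListener) l).copyFor(workerUUID));
            } else {
                replicatedListeners.add(l);
            }
        }
    }
}

// Example per-worker listener: counts iterations for one worker.
class CountingListener implements PerWorkerListener {
    final String workerId;
    int count;

    CountingListener(String workerId) {
        this.workerId = workerId;
    }

    @Override
    public void iterationDone(int iteration, double score) {
        count++;
    }

    @Override
    public PerWorkerListener copyFor(String workerUUID) {
        return new CountingListener(workerUUID);
    }
}
```

Sharing a stateless listener is cheap, while a stateful one (a counter, a score buffer) would otherwise be updated concurrently by every worker thread.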