Class DefaultTrainer
java.lang.Object
  java.lang.Thread
    org.deeplearning4j.parallelism.trainer.DefaultTrainer

Direct Known Subclasses: SymmetricTrainer
public class DefaultTrainer extends Thread implements Trainer
-
Nested Class Summary
- static class DefaultTrainer.DefaultTrainerBuilder
Nested classes/interfaces inherited from class java.lang.Thread
Thread.State, Thread.UncaughtExceptionHandler
-
Field Summary
- protected int averagingFrequency
- protected AtomicBoolean isStopped
- protected AtomicLong lastEtlTime
- protected ReentrantReadWriteLock modelLock
- protected org.nd4j.linalg.dataset.api.DataSet nullDataSet
- protected AtomicBoolean nullMode
- protected boolean onRootModel
- protected org.deeplearning4j.nn.api.Model originalModel
- protected ParallelWrapper parallelWrapper
- protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
- protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
- protected org.deeplearning4j.nn.api.Model replicatedModel
- protected AtomicInteger running
- protected AtomicBoolean shouldStop
- protected AtomicBoolean shouldUpdate
- protected int threadId
- protected Exception thrownException
- protected boolean useMDS
- protected String uuid
- protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
Fields inherited from class java.lang.Thread
MAX_PRIORITY, MIN_PRIORITY, NORM_PRIORITY
-
Constructor Summary
- DefaultTrainer()
-
Method Summary
- boolean averagingRequired(): This method returns TRUE if this Trainer implementation assumes periodic averaging.
- protected static org.deeplearning4j.optimize.api.TrainingListener cloneListener(org.deeplearning4j.optimize.api.TrainingListener original)
- protected void configureListeners(String workerUUID, Collection<org.deeplearning4j.optimize.api.TrainingListener> oldListeners, Collection<org.deeplearning4j.optimize.api.TrainingListener> replicatedListeners)
- void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime): Train on a DataSet.
- void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime): Train on a MultiDataSet.
- protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
- protected void fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
- org.deeplearning4j.nn.api.Model getModel(): The current model for the trainer.
- boolean isRunning()
- protected void postInit(): This method does post-initialization configuration of Model.
- void run()
- protected void setupIfNeccessary()
- void shutdown(): Shut down this worker.
- void updateModel(@NonNull org.deeplearning4j.nn.api.Model model): Update the current Model for the worker.
- void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params): This method updates the replicated model's params.
- void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params): This method updates the updater params of the replicated model.
- void waitTillRunning(): Block the main thread until the trainer is up and running.
Methods inherited from class java.lang.Thread
activeCount, checkAccess, clone, countStackFrames, currentThread, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, onSpinWait, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, suspend, toString, yield
-
Methods inherited from class java.lang.Object
equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.parallelism.trainer.Trainer
getUuid, setUncaughtExceptionHandler, start
-
Field Detail
-
replicatedModel
protected org.deeplearning4j.nn.api.Model replicatedModel
-
queue
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
-
queueMDS
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
-
running
protected AtomicInteger running
-
shouldUpdate
protected AtomicBoolean shouldUpdate
-
shouldStop
protected AtomicBoolean shouldStop
-
thrownException
protected Exception thrownException
-
useMDS
protected volatile boolean useMDS
-
uuid
protected String uuid
-
onRootModel
protected boolean onRootModel
-
lastEtlTime
protected volatile AtomicLong lastEtlTime
-
nullMode
protected AtomicBoolean nullMode
-
nullDataSet
protected org.nd4j.linalg.dataset.api.DataSet nullDataSet
-
isStopped
protected AtomicBoolean isStopped
-
parallelWrapper
protected ParallelWrapper parallelWrapper
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
averagingFrequency
protected int averagingFrequency
-
threadId
protected int threadId
-
originalModel
protected org.deeplearning4j.nn.api.Model originalModel
-
modelLock
protected final ReentrantReadWriteLock modelLock
-
Method Detail
-
feedMultiDataSet
public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime)
Description copied from interface: Trainer
Train on a MultiDataSet.
- Specified by: feedMultiDataSet in interface Trainer
- Parameters: dataSet - the data set to train on
-
feedDataSet
public void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime)
Description copied from interface: Trainer
Train on a DataSet.
- Specified by: feedDataSet in interface Trainer
- Parameters: dataSet - the data set to train on
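DefaultTrainer extends Thread and keeps LinkedBlockingQueue fields (queue, queueMDS) next to AtomicBoolean flags such as shouldStop, which suggests that the feed methods hand each batch to the worker thread rather than training on the caller's thread. The sketch below illustrates that producer/consumer pattern in plain Java under that assumption; QueueFedWorker, its methods, and the use of double[] in place of DataSet are hypothetical and not part of DL4J.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical queue-fed worker thread. double[] stands in for a DataSet and
// "fitting" a batch only increments a counter.
class QueueFedWorker extends Thread {
    private final LinkedBlockingQueue<double[]> queue = new LinkedBlockingQueue<>();
    private final AtomicBoolean shouldStop = new AtomicBoolean(false);
    private final AtomicInteger batchesFitted = new AtomicInteger();
    private volatile long lastEtlTime;

    // Called from the producer side: record the ETL time, enqueue, and return.
    void feed(double[] batch, long etlTime) {
        lastEtlTime = etlTime;
        queue.add(batch);
    }

    @Override
    public void run() {
        try {
            while (!shouldStop.get()) {
                // A timed poll lets the loop re-check the stop flag regularly.
                double[] batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch != null) {
                    batchesFitted.incrementAndGet(); // placeholder for fit(batch)
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Polls until at least n batches have been fitted or the timeout expires.
    boolean awaitBatches(int n, long timeoutMs) {
        long deadline = System.currentTimeMillis() + timeoutMs;
        while (batchesFitted.get() < n && System.currentTimeMillis() < deadline) {
            try {
                Thread.sleep(5);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            }
        }
        return batchesFitted.get() >= n;
    }

    // Asks the loop to exit, then waits for the thread to finish.
    void stopAndJoin() {
        shouldStop.set(true);
        try {
            join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    int batchesFitted() { return batchesFitted.get(); }
    long lastEtlTime() { return lastEtlTime; }
}
```

The timed poll is a design choice: a blocking take() would also work, but stopping the worker would then require interrupting the thread instead of just setting the flag.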
-
getModel
public org.deeplearning4j.nn.api.Model getModel()
Description copied from interface: Trainer
The current model for the trainer.
-
updateModel
public void updateModel(@NonNull org.deeplearning4j.nn.api.Model model)
Description copied from interface: Trainer
Update the current Model for the worker.
- Specified by: updateModel in interface Trainer
- Parameters: model - the new model for this worker
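The modelLock field is a ReentrantReadWriteLock. A common way to use such a lock when one thread trains on a model while another may replace it is to hold the read lock for each use and the write lock for the swap. The following sketch shows that pattern; GuardedModel is a hypothetical generic holder (M stands in for a Model type), and whether DefaultTrainer locks in exactly this way is an assumption.

```java
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Function;

// Hypothetical holder: a read/write lock guards a model reference that one
// thread uses while another thread may replace it.
class GuardedModel<M> {
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    private M model;

    GuardedModel(M initial) {
        this.model = initial;
    }

    // Exclusive: waits until no reader is using the current model.
    void update(M newModel) {
        lock.writeLock().lock();
        try {
            model = newModel;
        } finally {
            lock.writeLock().unlock();
        }
    }

    // Shared: many readers may hold the read lock at once.
    <R> R read(Function<M, R> action) {
        lock.readLock().lock();
        try {
            return action.apply(model);
        } finally {
            lock.readLock().unlock();
        }
    }
}
```

Holding the read lock for the whole action, not just while reading the reference, is what keeps a swap from landing in the middle of a training step.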
-
setupIfNeccessary
protected void setupIfNeccessary()
-
shutdown
public void shutdown()
Description copied from interface: Trainer
Shut down this worker.
-
fit
protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
-
fit
protected void fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
-
postInit
protected void postInit()
This method does post-initialization configuration of the Model. It is a good place to configure listeners and similar settings.
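postInit() is a protected hook that runs after initialization, so a subclass can add its own configuration, such as listeners, without rewriting the setup code. The sketch below shows that template-method pattern; HookedWorker, ScoreLoggingWorker, and the use of String in place of TrainingListener are hypothetical, and the exact point at which DefaultTrainer calls postInit() is an assumption.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base worker: init() does the fixed setup, then calls the
// overridable postInit() hook. Strings stand in for TrainingListener objects.
class HookedWorker {
    protected final List<String> listeners = new ArrayList<>();
    private boolean initialized;

    final void init() {
        // A real trainer would create or replicate its model here.
        initialized = true;
        postInit();
    }

    // Default hook: nothing extra to configure.
    protected void postInit() {
    }

    boolean isInitialized() { return initialized; }
    List<String> listeners() { return listeners; }
}

// Hypothetical subclass that attaches a listener once setup is done.
class ScoreLoggingWorker extends HookedWorker {
    @Override
    protected void postInit() {
        listeners.add("score-logger");
    }
}
```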
-
waitTillRunning
public void waitTillRunning()
Description copied from interface: Trainer
Block the main thread until the trainer is up and running.
- Specified by: waitTillRunning in interface Trainer
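waitTillRunning() lets the caller block until the worker thread has actually started. The Field Summary lists an AtomicInteger running field, so DefaultTrainer probably tracks this state with a counter; the sketch below swaps in a java.util.concurrent.CountDownLatch, which gives the same "block until started" behavior without spinning. StartSignalWorker, its timeout parameter, and finish() are hypothetical.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

// Hypothetical worker whose caller can block until run() has begun.
class StartSignalWorker extends Thread {
    private final CountDownLatch started = new CountDownLatch(1);
    private final CountDownLatch release = new CountDownLatch(1);

    @Override
    public void run() {
        started.countDown(); // signal "up and running"
        try {
            release.await(); // stand-in for the training loop
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Blocks the calling thread until run() has started or the timeout expires.
    boolean waitTillRunning(long timeoutMs) {
        try {
            return started.await(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    boolean isRunning() {
        return started.getCount() == 0 && isAlive();
    }

    // Lets the placeholder loop finish so the thread can exit.
    void finish() {
        release.countDown();
    }
}
```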
-
updateModelParams
public void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params)
Description copied from interface: Trainer
This method updates the replicated model's params.
- Specified by: updateModelParams in interface Trainer
-
updateUpdaterParams
public void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params)
Description copied from interface: Trainer
This method updates the updater params of the replicated model.
- Specified by: updateUpdaterParams in interface Trainer
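updateModelParams and updateUpdaterParams overwrite, respectively, the replica's parameters and its updater (optimizer) state. A minimal sketch, assuming flat double[] vectors in place of INDArray and in-place copies guarded by synchronized methods; ReplicaState and its snapshot methods are hypothetical.

```java
import java.util.Arrays;

// Hypothetical replica holding a flat parameter vector and a flat updater-state
// vector (e.g. momentum); double[] stands in for ND4J's INDArray.
class ReplicaState {
    private final double[] params;
    private final double[] updaterState;

    ReplicaState(int numParams, int updaterSize) {
        params = new double[numParams];
        updaterState = new double[updaterSize];
    }

    // Overwrite the replica's parameters in place with values from another copy.
    synchronized void updateModelParams(double[] source) {
        copyInto(params, source);
    }

    // Same idea for the optimizer state.
    synchronized void updateUpdaterParams(double[] source) {
        copyInto(updaterState, source);
    }

    private static void copyInto(double[] target, double[] source) {
        if (source.length != target.length) {
            throw new IllegalArgumentException(
                "expected " + target.length + " values, got " + source.length);
        }
        System.arraycopy(source, 0, target, 0, source.length);
    }

    synchronized double[] paramsSnapshot() {
        return Arrays.copyOf(params, params.length);
    }

    synchronized double[] updaterSnapshot() {
        return Arrays.copyOf(updaterState, updaterState.length);
    }
}
```

Copying into existing buffers, rather than replacing the arrays, keeps any other references to those buffers valid and avoids a fresh allocation on every update.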
-
averagingRequired
public boolean averagingRequired()
Description copied from interface: Trainer
This method returns TRUE if this Trainer implementation assumes periodic averaging.
- Specified by: averagingRequired in interface Trainer
- Returns: true if periodic averaging is required
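averagingRequired() reports whether the trainer expects periodic averaging, and the averagingFrequency field presumably controls how often it happens. As an illustration, assuming parameters are flat double[] vectors, the sketch below shows the two pieces of periodic parameter averaging: deciding when to average, and computing the element-wise mean across replicas. ParamAveraging is a hypothetical helper, not the ParallelWrapper implementation.

```java
// Hypothetical helper illustrating periodic parameter averaging across replicas.
final class ParamAveraging {
    private ParamAveraging() {
    }

    // True on every averagingFrequency-th iteration (iterations counted from 1).
    static boolean shouldAverage(long iteration, int averagingFrequency) {
        if (averagingFrequency <= 0) {
            throw new IllegalArgumentException("averagingFrequency must be positive");
        }
        return iteration > 0 && iteration % averagingFrequency == 0;
    }

    // Element-wise mean of equally sized parameter vectors, one per replica.
    static double[] average(double[][] replicas) {
        if (replicas.length == 0) {
            throw new IllegalArgumentException("no replicas");
        }
        int n = replicas[0].length;
        double[] mean = new double[n];
        for (double[] r : replicas) {
            if (r.length != n) {
                throw new IllegalArgumentException("length mismatch");
            }
            for (int i = 0; i < n; i++) {
                mean[i] += r[i];
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= replicas.length;
        }
        return mean;
    }
}
```

The averaged vector would then be pushed back to each replica, which is the role updateModelParams plays in the method list above.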
-
cloneListener
protected static org.deeplearning4j.optimize.api.TrainingListener cloneListener(org.deeplearning4j.optimize.api.TrainingListener original)
-
configureListeners
protected void configureListeners(String workerUUID, Collection<org.deeplearning4j.optimize.api.TrainingListener> oldListeners, Collection<org.deeplearning4j.optimize.api.TrainingListener> replicatedListeners)