Class SymmetricTrainerContext
- java.lang.Object
  - org.deeplearning4j.parallelism.factory.SymmetricTrainerContext
- All Implemented Interfaces: TrainerContext
public class SymmetricTrainerContext extends Object implements TrainerContext
-
Constructor Summary
Constructor                  Description
SymmetricTrainerContext()
-
Method Summary
- Trainer create(String uuid, int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency)
  Create a Trainer based on the given parameters
- void finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
  This method is called at averagingFrequency
- void finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
  This method is called
- void init(org.deeplearning4j.nn.api.Model model, Object... args)
  Initialize the context
-
Method Detail
-
init
public void init(org.deeplearning4j.nn.api.Model model, Object... args)
Initialize the context
- Specified by:
init in interface TrainerContext
- Parameters:
model -
args - the arguments to initialize with (may be null)
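Because args is a varargs parameter that may be null, an implementation has to tolerate three call shapes: no extra arguments (an empty array), ordinary arguments, and an explicitly null array. The following is a minimal sketch of such a guard. InitArgsSketch and countArgs are hypothetical names used only for illustration; they stand in for the shape of init(model, Object... args) and are not part of the DL4J API.

```java
// Hypothetical helper, not a DL4J class: mirrors the varargs shape of init(model, Object... args).
class InitArgsSketch {
    // Returns how many extra arguments were supplied, treating a null array as "none".
    static int countArgs(Object... args) {
        // A caller can pass (Object[]) null, in which case args itself is null.
        return args == null ? 0 : args.length;
    }

    public static void main(String[] unused) {
        System.out.println(countArgs());                 // empty array
        System.out.println(countArgs("a", 2));           // two arguments
        System.out.println(countArgs((Object[]) null));  // null array
    }
}
```

Checking for null before reading args.length avoids a NullPointerException in the third call shape.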
-
create
public Trainer create(String uuid, int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency)
Create a Trainer based on the given parameters
- Specified by:
create in interface TrainerContext
- Parameters:
threadId - the thread id to use for this worker
model - the model to start the trainer with
rootDevice - the root device id
useMDS - whether to use MultiDataSet instead of DataSet
wrapper - the wrapper instance to use with this trainer (this reference is needed for coordination with the ParallelWrapper's TrainingListener)
- Returns:
the created training instance
-
finalizeRound
public void finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
Description copied from interface: TrainerContext
This method is called at averagingFrequency
- Specified by:
finalizeRound in interface TrainerContext
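The phrase "called at averagingFrequency" can be read as a cadence: one call per averagingFrequency training iterations. The sketch below illustrates that cadence only and does not use the DL4J classes. RoundHook, CadenceSketch, and run are hypothetical names, and the modulo-based trigger is an assumption about how a caller might schedule the hook, not a description of ParallelWrapper's internals.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the finalizeRound/finalizeTraining pair; not the DL4J TrainerContext.
interface RoundHook {
    void finalizeRound(int iteration);
    void finalizeTraining(int iteration);
}

class CadenceSketch {
    // Simulates `iterations` training steps and fires the round hook every
    // `averagingFrequency` steps, then fires the training hook once at the end.
    // Assumption: the caller schedules the round hook with a simple modulo check.
    static List<String> run(int iterations, int averagingFrequency) {
        if (averagingFrequency < 1) {
            throw new IllegalArgumentException("averagingFrequency must be >= 1");
        }
        final List<String> events = new ArrayList<>();
        RoundHook hook = new RoundHook() {
            @Override public void finalizeRound(int iteration) { events.add("round@" + iteration); }
            @Override public void finalizeTraining(int iteration) { events.add("done@" + iteration); }
        };
        for (int i = 1; i <= iterations; i++) {
            // ... one training step on every worker would happen here ...
            if (i % averagingFrequency == 0) {
                hook.finalizeRound(i);
            }
        }
        hook.finalizeTraining(iterations);
        return events;
    }

    public static void main(String[] args) {
        System.out.println(run(7, 3)); // prints [round@3, round@6, done@7]
    }
}
```

With 7 iterations and a frequency of 3, the round hook fires after iterations 3 and 6, and the training hook fires once after the last iteration.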
-
finalizeTraining
public void finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
Description copied from interface: TrainerContext
This method is called
- Specified by:
finalizeTraining in interface TrainerContext